axdriver_virtio/net.rs

use alloc::{sync::Arc, vec::Vec};

use axdriver_base::{BaseDriverOps, DevError, DevResult, DeviceType};
use axdriver_net::{EthernetAddress, NetBuf, NetBufBox, NetBufPool, NetBufPtr, NetDriverOps};
use virtio_drivers::{Hal, device::net::VirtIONetRaw as InnerDev, transport::Transport};

use crate::as_dev_err;

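// Assumed size breakdown (not stated in the original source): 1526 bytes =
// 12-byte virtio-net header + 14-byte Ethernet header + 1500-byte MTU payload.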
const NET_BUF_LEN: usize = 1526;

/// The VirtIO network device driver.
///
/// `QS` is the VirtIO queue size.
pub struct VirtIoNetDev<H: Hal, T: Transport, const QS: usize> {
    /// RX buffers currently posted to the device, indexed by descriptor token.
    rx_buffers: [Option<NetBufBox>; QS],
    /// TX buffers currently in flight, indexed by descriptor token.
    tx_buffers: [Option<NetBufBox>; QS],
    /// Pre-allocated TX buffers with their headers already filled in.
    free_tx_bufs: Vec<NetBufBox>,
    /// The pool that backs all RX and TX buffers.
    buf_pool: Arc<NetBufPool>,
    /// The underlying raw virtio-net device.
    inner: InnerDev<H, T, QS>,
}

// SAFETY: assumed sound because the driver has exclusive ownership of the
// device and all of its DMA buffers.
unsafe impl<H: Hal, T: Transport, const QS: usize> Send for VirtIoNetDev<H, T, QS> {}
unsafe impl<H: Hal, T: Transport, const QS: usize> Sync for VirtIoNetDev<H, T, QS> {}

impl<H: Hal, T: Transport, const QS: usize> VirtIoNetDev<H, T, QS> {
    /// Creates a new driver instance and initializes the device, or returns
    /// an error if any step fails.
    pub fn try_new(transport: T) -> DevResult<Self> {
        // 0. Create a new driver instance.
        const NONE_BUF: Option<NetBufBox> = None;
        let inner = InnerDev::new(transport).map_err(as_dev_err)?;
        let rx_buffers = [NONE_BUF; QS];
        let tx_buffers = [NONE_BUF; QS];
        // The pool holds `2 * QS` buffers: one set for RX and one for TX.
        let buf_pool = NetBufPool::new(2 * QS, NET_BUF_LEN)?;
        let free_tx_bufs = Vec::with_capacity(QS);

        let mut dev = Self {
            rx_buffers,
            inner,
            tx_buffers,
            free_tx_bufs,
            buf_pool,
        };

        // 1. Fill all RX buffers and post them to the device.
        for (i, rx_buf_place) in dev.rx_buffers.iter_mut().enumerate() {
            let mut rx_buf = dev.buf_pool.alloc_boxed().ok_or(DevError::NoMemory)?;
            // SAFETY: the buffer lives as long as the queue.
            let token = unsafe {
                dev.inner
                    .receive_begin(rx_buf.raw_buf_mut())
                    .map_err(as_dev_err)?
            };
            // Tokens are expected to be allocated sequentially on a fresh queue.
            assert_eq!(token, i as u16);
            *rx_buf_place = Some(rx_buf);
        }

        // 2. Allocate all TX buffers and pre-fill their virtio-net headers.
        for _ in 0..QS {
            let mut tx_buf = dev.buf_pool.alloc_boxed().ok_or(DevError::NoMemory)?;
            let hdr_len = dev
                .inner
                .fill_buffer_header(tx_buf.raw_buf_mut())
                .map_err(|_| DevError::InvalidParam)?;
            tx_buf.set_header_len(hdr_len);
            dev.free_tx_bufs.push(tx_buf);
        }

        // 3. Return the driver instance.
        Ok(dev)
    }
}
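
// A minimal construction sketch (hypothetical caller code; `HalImpl`,
// `mmio_base`, and the transport setup are assumptions, not part of this file):
//
//     use virtio_drivers::transport::mmio::{MmioTransport, VirtIOHeader};
//
//     let header = core::ptr::NonNull::new(mmio_base as *mut VirtIOHeader).unwrap();
//     let transport = unsafe { MmioTransport::new(header) }.unwrap();
//     let dev = VirtIoNetDev::<HalImpl, _, 64>::try_new(transport).unwrap();
//     assert_eq!(dev.device_name(), "virtio-net");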

impl<H: Hal, T: Transport, const QS: usize> BaseDriverOps for VirtIoNetDev<H, T, QS> {
    fn device_name(&self) -> &str {
        "virtio-net"
    }

    fn device_type(&self) -> DeviceType {
        DeviceType::Net
    }
}

impl<H: Hal, T: Transport, const QS: usize> NetDriverOps for VirtIoNetDev<H, T, QS> {
    #[inline]
    fn mac_address(&self) -> EthernetAddress {
        EthernetAddress(self.inner.mac_address())
    }

    #[inline]
    fn can_transmit(&self) -> bool {
        !self.free_tx_bufs.is_empty() && self.inner.can_send()
    }

    #[inline]
    fn can_receive(&self) -> bool {
        self.inner.poll_receive().is_some()
    }

    #[inline]
    fn rx_queue_size(&self) -> usize {
        QS
    }

    #[inline]
    fn tx_queue_size(&self) -> usize {
        QS
    }

    fn recycle_rx_buffer(&mut self, rx_buf: NetBufPtr) -> DevResult {
        // SAFETY: `rx_buf` was created by `NetBuf::into_buf_ptr` in `Self::receive`.
        let mut rx_buf = unsafe { NetBuf::from_buf_ptr(rx_buf) };
        // SAFETY: ownership of `rx_buf` is moved back into `rx_buffers`, so the
        // buffer lives as long as the queue.
        let new_token = unsafe {
            self.inner
                .receive_begin(rx_buf.raw_buf_mut())
                .map_err(as_dev_err)?
        };
        // `rx_buffers[new_token]` is expected to be `None`, since it was taken
        // out in `Self::receive` and has not been put back yet.
        if self.rx_buffers[new_token as usize].is_some() {
            return Err(DevError::BadState);
        }
        self.rx_buffers[new_token as usize] = Some(rx_buf);
        Ok(())
    }

    fn recycle_tx_buffers(&mut self) -> DevResult {
        while let Some(token) = self.inner.poll_transmit() {
            let tx_buf = self.tx_buffers[token as usize]
                .take()
                .ok_or(DevError::BadState)?;
            // SAFETY: `token` was returned by `transmit_begin` for this same
            // buffer in `Self::transmit`.
            unsafe {
                self.inner
                    .transmit_complete(token, tx_buf.packet_with_header())
                    .map_err(as_dev_err)?;
            }
            // Recycle the buffer back into the free list.
            self.free_tx_bufs.push(tx_buf);
        }
        Ok(())
    }

    fn transmit(&mut self, tx_buf: NetBufPtr) -> DevResult {
        // 0. Prepare the TX buffer.
        // SAFETY: `tx_buf` was created by `NetBuf::into_buf_ptr` in
        // `Self::alloc_tx_buffer`.
        let tx_buf = unsafe { NetBuf::from_buf_ptr(tx_buf) };
        // 1. Transmit the packet.
        // SAFETY: the buffer is kept alive in `tx_buffers` until the
        // transmission completes.
        let token = unsafe {
            self.inner
                .transmit_begin(tx_buf.packet_with_header())
                .map_err(as_dev_err)?
        };
        self.tx_buffers[token as usize] = Some(tx_buf);
        Ok(())
    }

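// TX buffer lifecycle sketch (hypothetical caller code; `dev` and the frame
// contents are assumptions for illustration):
//
//     let mut buf = dev.alloc_tx_buffer(64).unwrap();   // pop a free buffer
//     // ... write 64 bytes of frame data into the buffer ...
//     dev.transmit(buf).unwrap();                       // post it to the device
//     dev.recycle_tx_buffers().unwrap();                // reclaim completed buffers
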
    fn receive(&mut self) -> DevResult<NetBufPtr> {
        if let Some(token) = self.inner.poll_receive() {
            let mut rx_buf = self.rx_buffers[token as usize]
                .take()
                .ok_or(DevError::BadState)?;
            // SAFETY: the buffer lives as long as the queue.
            let (hdr_len, pkt_len) = unsafe {
                self.inner
                    .receive_complete(token, rx_buf.raw_buf_mut())
                    .map_err(as_dev_err)?
            };
            rx_buf.set_header_len(hdr_len);
            rx_buf.set_packet_len(pkt_len);

            Ok(rx_buf.into_buf_ptr())
        } else {
            Err(DevError::Again)
        }
    }

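// RX buffer lifecycle sketch (hypothetical caller code; the poll loop is an
// assumption for illustration):
//
//     if dev.can_receive() {
//         let buf = dev.receive().unwrap();     // take a filled buffer out
//         // ... process the received packet ...
//         dev.recycle_rx_buffer(buf).unwrap();  // repost it to the device
//     }
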
    fn alloc_tx_buffer(&mut self, size: usize) -> DevResult<NetBufPtr> {
        // 0. Pop a pre-allocated buffer from the free list (its virtio-net
        //    header was already filled in `Self::try_new`).
        let mut net_buf = self.free_tx_bufs.pop().ok_or(DevError::NoMemory)?;
        let pkt_len = size;

        // 1. Check that the buffer is large enough for the packet.
        let hdr_len = net_buf.header_len();
        if hdr_len + pkt_len > net_buf.capacity() {
            return Err(DevError::InvalidParam);
        }
        net_buf.set_packet_len(pkt_len);

        // 2. Return the buffer.
        Ok(net_buf.into_buf_ptr())
    }
}