1use alloc::{sync::Arc, vec::Vec};
2
3use axdriver_base::{BaseDriverOps, DevError, DevResult, DeviceType};
4use axdriver_net::{EthernetAddress, NetBuf, NetBufBox, NetBufPool, NetBufPtr, NetDriverOps};
5use virtio_drivers::{Hal, device::net::VirtIONetRaw as InnerDev, transport::Transport};
6
7use crate::as_dev_err;
8
/// Buffer length handed to `NetBufPool::new` for the pool that backs every
/// RX and TX buffer of the driver.
///
/// NOTE(review): 1526 presumably is a 1514-byte maximum Ethernet frame plus a
/// 12-byte virtio-net header — confirm against the device configuration.
const NET_BUF_LEN: usize = 1526;
10
11pub struct VirtIoNetDev<H: Hal, T: Transport, const QS: usize> {
15 rx_buffers: [Option<NetBufBox>; QS],
16 tx_buffers: [Option<NetBufBox>; QS],
17 free_tx_bufs: Vec<NetBufBox>,
18 buf_pool: Arc<NetBufPool>,
19 inner: InnerDev<H, T, QS>,
20}
21
22unsafe impl<H: Hal, T: Transport, const QS: usize> Send for VirtIoNetDev<H, T, QS> {}
23unsafe impl<H: Hal, T: Transport, const QS: usize> Sync for VirtIoNetDev<H, T, QS> {}
24
25impl<H: Hal, T: Transport, const QS: usize> VirtIoNetDev<H, T, QS> {
26 pub fn try_new(transport: T) -> DevResult<Self> {
29 const NONE_BUF: Option<NetBufBox> = None;
31 let inner = InnerDev::new(transport).map_err(as_dev_err)?;
32 let rx_buffers = [NONE_BUF; QS];
33 let tx_buffers = [NONE_BUF; QS];
34 let buf_pool = NetBufPool::new(2 * QS, NET_BUF_LEN)?;
35 let free_tx_bufs = Vec::with_capacity(QS);
36
37 let mut dev = Self {
38 rx_buffers,
39 inner,
40 tx_buffers,
41 free_tx_bufs,
42 buf_pool,
43 };
44
45 for (i, rx_buf_place) in dev.rx_buffers.iter_mut().enumerate() {
47 let mut rx_buf = dev.buf_pool.alloc_boxed().ok_or(DevError::NoMemory)?;
48 let token = unsafe {
50 dev.inner
51 .receive_begin(rx_buf.raw_buf_mut())
52 .map_err(as_dev_err)?
53 };
54 assert_eq!(token, i as u16);
55 *rx_buf_place = Some(rx_buf);
56 }
57
58 for _ in 0..QS {
60 let mut tx_buf = dev.buf_pool.alloc_boxed().ok_or(DevError::NoMemory)?;
61 let hdr_len = dev
63 .inner
64 .fill_buffer_header(tx_buf.raw_buf_mut())
65 .or(Err(DevError::InvalidParam))?;
66 tx_buf.set_header_len(hdr_len);
67 dev.free_tx_bufs.push(tx_buf);
68 }
69
70 Ok(dev)
72 }
73}
74
75impl<H: Hal, T: Transport, const QS: usize> BaseDriverOps for VirtIoNetDev<H, T, QS> {
76 fn device_name(&self) -> &str {
77 "virtio-net"
78 }
79
80 fn device_type(&self) -> DeviceType {
81 DeviceType::Net
82 }
83}
84
85impl<H: Hal, T: Transport, const QS: usize> NetDriverOps for VirtIoNetDev<H, T, QS> {
86 #[inline]
87 fn mac_address(&self) -> EthernetAddress {
88 EthernetAddress(self.inner.mac_address())
89 }
90
91 #[inline]
92 fn can_transmit(&self) -> bool {
93 !self.free_tx_bufs.is_empty() && self.inner.can_send()
94 }
95
96 #[inline]
97 fn can_receive(&self) -> bool {
98 self.inner.poll_receive().is_some()
99 }
100
101 #[inline]
102 fn rx_queue_size(&self) -> usize {
103 QS
104 }
105
106 #[inline]
107 fn tx_queue_size(&self) -> usize {
108 QS
109 }
110
111 fn recycle_rx_buffer(&mut self, rx_buf: NetBufPtr) -> DevResult {
112 let mut rx_buf = unsafe { NetBuf::from_buf_ptr(rx_buf) };
113 let new_token = unsafe {
116 self.inner
117 .receive_begin(rx_buf.raw_buf_mut())
118 .map_err(as_dev_err)?
119 };
120 if self.rx_buffers[new_token as usize].is_some() {
123 return Err(DevError::BadState);
124 }
125 self.rx_buffers[new_token as usize] = Some(rx_buf);
126 Ok(())
127 }
128
129 fn recycle_tx_buffers(&mut self) -> DevResult {
130 while let Some(token) = self.inner.poll_transmit() {
131 let tx_buf = self.tx_buffers[token as usize]
132 .take()
133 .ok_or(DevError::BadState)?;
134 unsafe {
135 self.inner
136 .transmit_complete(token, tx_buf.packet_with_header())
137 .map_err(as_dev_err)?;
138 }
139 self.free_tx_bufs.push(tx_buf);
141 }
142 Ok(())
143 }
144
145 fn transmit(&mut self, tx_buf: NetBufPtr) -> DevResult {
146 let tx_buf = unsafe { NetBuf::from_buf_ptr(tx_buf) };
148 let token = unsafe {
150 self.inner
151 .transmit_begin(tx_buf.packet_with_header())
152 .map_err(as_dev_err)?
153 };
154 self.tx_buffers[token as usize] = Some(tx_buf);
155 Ok(())
156 }
157
158 fn receive(&mut self) -> DevResult<NetBufPtr> {
159 if let Some(token) = self.inner.poll_receive() {
160 let mut rx_buf = self.rx_buffers[token as usize]
161 .take()
162 .ok_or(DevError::BadState)?;
163 let (hdr_len, pkt_len) = unsafe {
165 self.inner
166 .receive_complete(token, rx_buf.raw_buf_mut())
167 .map_err(as_dev_err)?
168 };
169 rx_buf.set_header_len(hdr_len);
170 rx_buf.set_packet_len(pkt_len);
171
172 Ok(rx_buf.into_buf_ptr())
173 } else {
174 Err(DevError::Again)
175 }
176 }
177
178 fn alloc_tx_buffer(&mut self, size: usize) -> DevResult<NetBufPtr> {
179 let mut net_buf = self.free_tx_bufs.pop().ok_or(DevError::NoMemory)?;
181 let pkt_len = size;
182
183 let hdr_len = net_buf.header_len();
185 if hdr_len + pkt_len > net_buf.capacity() {
186 return Err(DevError::InvalidParam);
187 }
188 net_buf.set_packet_len(pkt_len);
189
190 Ok(net_buf.into_buf_ptr())
192 }
193}