axnet/smoltcp_impl/mod.rs

1mod addr;
2mod bench;
3mod dns;
4mod listen_table;
5mod tcp;
6mod udp;
7
8use alloc::vec;
9use core::cell::RefCell;
10use core::ops::DerefMut;
11
12use axdriver::prelude::*;
13use axdriver_net::{DevError, NetBufPtr};
14use axhal::time::{NANOS_PER_MICROS, wall_time_nanos};
15use axsync::Mutex;
16use lazyinit::LazyInit;
17use smoltcp::iface::{Config, Interface, SocketHandle, SocketSet};
18use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
19use smoltcp::socket::{self, AnySocket};
20use smoltcp::time::Instant;
21use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr};
22
23use self::listen_table::ListenTable;
24
25pub use self::dns::dns_query;
26pub use self::tcp::TcpSocket;
27pub use self::udp::UdpSocket;
28
/// Expands to the build-time value of the environment variable `$key`, or to `""` when
/// that variable was not set while compiling.
///
/// Only `option_env!` and `if let` are used, so the expansion is valid in `const` items.
macro_rules! env_or_default {
    ($key:literal) => {
        if let Some(value) = option_env!($key) {
            value
        } else {
            ""
        }
    };
}
37
38const IP: &str = env_or_default!("AX_IP");
39const GATEWAY: &str = env_or_default!("AX_GW");
40const DNS_SEVER: &str = "8.8.8.8";
41const IP_PREFIX: u8 = 24;
42
43const STANDARD_MTU: usize = 1500;
44
45const RANDOM_SEED: u64 = 0xA2CE_05A2_CE05_A2CE;
46
47const TCP_RX_BUF_LEN: usize = 64 * 1024;
48const TCP_TX_BUF_LEN: usize = 64 * 1024;
49const UDP_RX_BUF_LEN: usize = 64 * 1024;
50const UDP_TX_BUF_LEN: usize = 64 * 1024;
51const LISTEN_QUEUE_SIZE: usize = 512;
52
53static LISTEN_TABLE: LazyInit<ListenTable> = LazyInit::new();
54static SOCKET_SET: LazyInit<SocketSetWrapper> = LazyInit::new();
55static ETH0: LazyInit<InterfaceWrapper> = LazyInit::new();
56
/// A smoltcp [`SocketSet`] guarded by a mutex so it can be shared as a global.
struct SocketSetWrapper<'a>(Mutex<SocketSet<'a>>);

/// Adapts an [`AxNetDevice`] to smoltcp's [`Device`] trait.
struct DeviceWrapper {
    // A `RefCell` is sufficient because `DeviceWrapper` is only reached through the `Mutex`
    // in `InterfaceWrapper`. The RX/TX tokens handed to smoltcp share a reference to this
    // cell and borrow it mutably when they are consumed.
    inner: RefCell<AxNetDevice>,
}

/// A named network interface: its device, MAC address, and smoltcp interface state.
struct InterfaceWrapper {
    // Interface name reported in logs (e.g. "eth0").
    name: &'static str,
    // MAC address the interface was configured with.
    ether_addr: EthernetAddress,
    // Device and interface are locked separately; `poll` takes `dev` first, then `iface`.
    dev: Mutex<DeviceWrapper>,
    iface: Mutex<Interface>,
}
69
70impl<'a> SocketSetWrapper<'a> {
71    fn new() -> Self {
72        Self(Mutex::new(SocketSet::new(vec![])))
73    }
74
75    pub fn new_tcp_socket() -> socket::tcp::Socket<'a> {
76        let tcp_rx_buffer = socket::tcp::SocketBuffer::new(vec![0; TCP_RX_BUF_LEN]);
77        let tcp_tx_buffer = socket::tcp::SocketBuffer::new(vec![0; TCP_TX_BUF_LEN]);
78        socket::tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer)
79    }
80
81    pub fn new_udp_socket() -> socket::udp::Socket<'a> {
82        let udp_rx_buffer = socket::udp::PacketBuffer::new(
83            vec![socket::udp::PacketMetadata::EMPTY; 8],
84            vec![0; UDP_RX_BUF_LEN],
85        );
86        let udp_tx_buffer = socket::udp::PacketBuffer::new(
87            vec![socket::udp::PacketMetadata::EMPTY; 8],
88            vec![0; UDP_TX_BUF_LEN],
89        );
90        socket::udp::Socket::new(udp_rx_buffer, udp_tx_buffer)
91    }
92
93    pub fn new_dns_socket() -> socket::dns::Socket<'a> {
94        let server_addr = DNS_SEVER.parse().expect("invalid DNS server address");
95        socket::dns::Socket::new(&[server_addr], vec![])
96    }
97
98    pub fn add<T: AnySocket<'a>>(&self, socket: T) -> SocketHandle {
99        let handle = self.0.lock().add(socket);
100        debug!("socket {}: created", handle);
101        handle
102    }
103
104    pub fn with_socket<T: AnySocket<'a>, R, F>(&self, handle: SocketHandle, f: F) -> R
105    where
106        F: FnOnce(&T) -> R,
107    {
108        let set = self.0.lock();
109        let socket = set.get(handle);
110        f(socket)
111    }
112
113    pub fn with_socket_mut<T: AnySocket<'a>, R, F>(&self, handle: SocketHandle, f: F) -> R
114    where
115        F: FnOnce(&mut T) -> R,
116    {
117        let mut set = self.0.lock();
118        let socket = set.get_mut(handle);
119        f(socket)
120    }
121
122    pub fn poll_interfaces(&self) {
123        ETH0.poll(&self.0);
124    }
125
126    pub fn remove(&self, handle: SocketHandle) {
127        self.0.lock().remove(handle);
128        debug!("socket {}: destroyed", handle);
129    }
130}
131
132impl InterfaceWrapper {
133    fn new(name: &'static str, dev: AxNetDevice, ether_addr: EthernetAddress) -> Self {
134        let mut config = Config::new(HardwareAddress::Ethernet(ether_addr));
135        config.random_seed = RANDOM_SEED;
136
137        let mut dev = DeviceWrapper::new(dev);
138        let iface = Mutex::new(Interface::new(config, &mut dev, Self::current_time()));
139        Self {
140            name,
141            ether_addr,
142            dev: Mutex::new(dev),
143            iface,
144        }
145    }
146
147    fn current_time() -> Instant {
148        Instant::from_micros_const((wall_time_nanos() / NANOS_PER_MICROS) as i64)
149    }
150
151    pub fn name(&self) -> &str {
152        self.name
153    }
154
155    pub fn ethernet_address(&self) -> EthernetAddress {
156        self.ether_addr
157    }
158
159    pub fn setup_ip_addr(&self, ip: IpAddress, prefix_len: u8) {
160        let mut iface = self.iface.lock();
161        iface.update_ip_addrs(|ip_addrs| {
162            ip_addrs.push(IpCidr::new(ip, prefix_len)).unwrap();
163        });
164    }
165
166    pub fn setup_gateway(&self, gateway: IpAddress) {
167        let mut iface = self.iface.lock();
168        match gateway {
169            IpAddress::Ipv4(v4) => iface.routes_mut().add_default_ipv4_route(v4).unwrap(),
170            IpAddress::Ipv6(v6) => iface.routes_mut().add_default_ipv6_route(v6).unwrap(),
171        };
172    }
173
174    pub fn poll(&self, sockets: &Mutex<SocketSet>) {
175        let mut dev = self.dev.lock();
176        let mut iface = self.iface.lock();
177        let mut sockets = sockets.lock();
178        let timestamp = Self::current_time();
179        iface.poll(timestamp, dev.deref_mut(), &mut sockets);
180    }
181}
182
183impl DeviceWrapper {
184    fn new(inner: AxNetDevice) -> Self {
185        Self {
186            inner: RefCell::new(inner),
187        }
188    }
189}
190
191impl Device for DeviceWrapper {
192    type RxToken<'a>
193        = AxNetRxToken<'a>
194    where
195        Self: 'a;
196    type TxToken<'a>
197        = AxNetTxToken<'a>
198    where
199        Self: 'a;
200
201    fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
202        let mut dev = self.inner.borrow_mut();
203        if let Err(e) = dev.recycle_tx_buffers() {
204            warn!("recycle_tx_buffers failed: {:?}", e);
205            return None;
206        }
207
208        if !dev.can_transmit() {
209            return None;
210        }
211        let rx_buf = match dev.receive() {
212            Ok(buf) => buf,
213            Err(err) => {
214                if !matches!(err, DevError::Again) {
215                    warn!("receive failed: {:?}", err);
216                }
217                return None;
218            }
219        };
220        Some((AxNetRxToken(&self.inner, rx_buf), AxNetTxToken(&self.inner)))
221    }
222
223    fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
224        let mut dev = self.inner.borrow_mut();
225        if let Err(e) = dev.recycle_tx_buffers() {
226            warn!("recycle_tx_buffers failed: {:?}", e);
227            return None;
228        }
229        if dev.can_transmit() {
230            Some(AxNetTxToken(&self.inner))
231        } else {
232            None
233        }
234    }
235
236    fn capabilities(&self) -> DeviceCapabilities {
237        let mut caps = DeviceCapabilities::default();
238        caps.max_transmission_unit = 1514;
239        caps.max_burst_size = None;
240        caps.medium = Medium::Ethernet;
241        caps
242    }
243}
244
/// Receive token: borrows the device cell and owns one received buffer, which is returned
/// to the device (`recycle_rx_buffer`) when the token is consumed.
struct AxNetRxToken<'a>(&'a RefCell<AxNetDevice>, NetBufPtr);
/// Transmit token: borrows the device cell so a TX buffer can be allocated on consume.
struct AxNetTxToken<'a>(&'a RefCell<AxNetDevice>);
247
248impl RxToken for AxNetRxToken<'_> {
249    fn preprocess(&self, sockets: &mut SocketSet<'_>) {
250        snoop_tcp_packet(self.1.packet(), sockets).ok();
251    }
252
253    fn consume<R, F>(self, f: F) -> R
254    where
255        F: FnOnce(&[u8]) -> R,
256    {
257        let rx_buf = self.1;
258        trace!(
259            "RECV {} bytes: {:02X?}",
260            rx_buf.packet_len(),
261            rx_buf.packet()
262        );
263        let result = f(rx_buf.packet());
264        self.0.borrow_mut().recycle_rx_buffer(rx_buf).unwrap();
265        result
266    }
267}
268
269impl TxToken for AxNetTxToken<'_> {
270    fn consume<R, F>(self, len: usize, f: F) -> R
271    where
272        F: FnOnce(&mut [u8]) -> R,
273    {
274        let mut dev = self.0.borrow_mut();
275        let mut tx_buf = dev.alloc_tx_buffer(len).unwrap();
276        let ret = f(tx_buf.packet_mut());
277        trace!("SEND {} bytes: {:02X?}", len, tx_buf.packet());
278        dev.transmit(tx_buf).unwrap();
279        ret
280    }
281}
282
283fn snoop_tcp_packet(buf: &[u8], sockets: &mut SocketSet<'_>) -> Result<(), smoltcp::wire::Error> {
284    use smoltcp::wire::{EthernetFrame, IpProtocol, Ipv4Packet, TcpPacket};
285
286    let ether_frame = EthernetFrame::new_checked(buf)?;
287    let ipv4_packet = Ipv4Packet::new_checked(ether_frame.payload())?;
288
289    if ipv4_packet.next_header() == IpProtocol::Tcp {
290        let tcp_packet = TcpPacket::new_checked(ipv4_packet.payload())?;
291        let src_addr = (ipv4_packet.src_addr(), tcp_packet.src_port()).into();
292        let dst_addr = (ipv4_packet.dst_addr(), tcp_packet.dst_port()).into();
293        let is_first = tcp_packet.syn() && !tcp_packet.ack();
294        if is_first {
295            // create a socket for the first incoming TCP packet, as the later accept() returns.
296            LISTEN_TABLE.incoming_tcp_packet(src_addr, dst_addr, sockets);
297        }
298    }
299    Ok(())
300}
301
302/// Poll the network stack.
303///
304/// It may receive packets from the NIC and process them, and transmit queued
305/// packets to the NIC.
306pub fn poll_interfaces() {
307    SOCKET_SET.poll_interfaces();
308}
309
310/// Benchmark raw socket transmit bandwidth.
311pub fn bench_transmit() {
312    ETH0.dev.lock().bench_transmit_bandwidth();
313}
314
315/// Benchmark raw socket receive bandwidth.
316pub fn bench_receive() {
317    ETH0.dev.lock().bench_receive_bandwidth();
318}
319
320pub(crate) fn init(net_dev: AxNetDevice) {
321    let ether_addr = EthernetAddress(net_dev.mac_address().0);
322    let eth0 = InterfaceWrapper::new("eth0", net_dev, ether_addr);
323
324    let ip = IP.parse().expect("invalid IP address");
325    let gateway = GATEWAY.parse().expect("invalid gateway IP address");
326    eth0.setup_ip_addr(ip, IP_PREFIX);
327    eth0.setup_gateway(gateway);
328
329    ETH0.init_once(eth0);
330    SOCKET_SET.init_once(SocketSetWrapper::new());
331    LISTEN_TABLE.init_once(ListenTable::new());
332
333    info!("created net interface {:?}:", ETH0.name());
334    info!("  ether:    {}", ETH0.ethernet_address());
335    info!("  ip:       {}/{}", ip, IP_PREFIX);
336    info!("  gateway:  {}", gateway);
337}