axnet/smoltcp_impl/mod.rs

1mod addr;
2mod bench;
3mod dns;
4mod listen_table;
5mod tcp;
6mod udp;
7
8use alloc::vec;
9use core::cell::RefCell;
10use core::ops::DerefMut;
11
12use axdriver::prelude::*;
13use axdriver_net::{DevError, NetBufPtr};
14use axhal::time::{NANOS_PER_MICROS, wall_time_nanos};
15use axsync::Mutex;
16use lazyinit::LazyInit;
17use smoltcp::iface::{Config, Interface, SocketHandle, SocketSet};
18use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
19use smoltcp::socket::{self, AnySocket};
20use smoltcp::time::Instant;
21use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr};
22
23use self::listen_table::ListenTable;
24
25pub use self::dns::dns_query;
26pub use self::tcp::TcpSocket;
27pub use self::udp::UdpSocket;
28
/// Expands to the build-time value of the environment variable `$key`, or to
/// the empty string `""` when that variable was not set during compilation.
///
/// The expansion is a constant expression, so it can initialize `const` items.
macro_rules! env_or_default {
    ($key:literal) => {
        if let Some(value) = option_env!($key) {
            value
        } else {
            ""
        }
    };
}
37
38const IP: &str = env_or_default!("AX_IP");
39const GATEWAY: &str = env_or_default!("AX_GW");
40const DNS_SEVER: &str = "8.8.8.8";
41const IP_PREFIX: u8 = 24;
42
43const STANDARD_MTU: usize = 1500;
44
45const RANDOM_SEED: u64 = 0xA2CE_05A2_CE05_A2CE;
46
47const TCP_RX_BUF_LEN: usize = 64 * 1024;
48const TCP_TX_BUF_LEN: usize = 64 * 1024;
49const UDP_RX_BUF_LEN: usize = 64 * 1024;
50const UDP_TX_BUF_LEN: usize = 64 * 1024;
51const LISTEN_QUEUE_SIZE: usize = 512;
52
53static LISTEN_TABLE: LazyInit<ListenTable> = LazyInit::new();
54static SOCKET_SET: LazyInit<SocketSetWrapper> = LazyInit::new();
55static ETH0: LazyInit<InterfaceWrapper> = LazyInit::new();
56
57struct SocketSetWrapper<'a>(Mutex<SocketSet<'a>>);
58
59struct DeviceWrapper {
60    inner: RefCell<AxNetDevice>, // use `RefCell` is enough since it's wrapped in `Mutex` in `InterfaceWrapper`.
61}
62
63struct InterfaceWrapper {
64    name: &'static str,
65    ether_addr: EthernetAddress,
66    dev: Mutex<DeviceWrapper>,
67    iface: Mutex<Interface>,
68}
69
70impl<'a> SocketSetWrapper<'a> {
71    fn new() -> Self {
72        Self(Mutex::new(SocketSet::new(vec![])))
73    }
74
75    pub fn new_tcp_socket() -> socket::tcp::Socket<'a> {
76        let tcp_rx_buffer = socket::tcp::SocketBuffer::new(vec![0; TCP_RX_BUF_LEN]);
77        let tcp_tx_buffer = socket::tcp::SocketBuffer::new(vec![0; TCP_TX_BUF_LEN]);
78        socket::tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer)
79    }
80
81    pub fn new_udp_socket() -> socket::udp::Socket<'a> {
82        let udp_rx_buffer = socket::udp::PacketBuffer::new(
83            vec![socket::udp::PacketMetadata::EMPTY; 8],
84            vec![0; UDP_RX_BUF_LEN],
85        );
86        let udp_tx_buffer = socket::udp::PacketBuffer::new(
87            vec![socket::udp::PacketMetadata::EMPTY; 8],
88            vec![0; UDP_TX_BUF_LEN],
89        );
90        socket::udp::Socket::new(udp_rx_buffer, udp_tx_buffer)
91    }
92
93    pub fn new_dns_socket() -> socket::dns::Socket<'a> {
94        let server_addr = DNS_SEVER.parse().expect("invalid DNS server address");
95        socket::dns::Socket::new(&[server_addr], vec![])
96    }
97
98    pub fn add<T: AnySocket<'a>>(&self, socket: T) -> SocketHandle {
99        let handle = self.0.lock().add(socket);
100        debug!("socket {}: created", handle);
101        handle
102    }
103
104    pub fn with_socket<T: AnySocket<'a>, R, F>(&self, handle: SocketHandle, f: F) -> R
105    where
106        F: FnOnce(&T) -> R,
107    {
108        let set = self.0.lock();
109        let socket = set.get(handle);
110        f(socket)
111    }
112
113    pub fn with_socket_mut<T: AnySocket<'a>, R, F>(&self, handle: SocketHandle, f: F) -> R
114    where
115        F: FnOnce(&mut T) -> R,
116    {
117        let mut set = self.0.lock();
118        let socket = set.get_mut(handle);
119        f(socket)
120    }
121
122    pub fn poll_interfaces(&self) {
123        ETH0.poll(&self.0);
124    }
125
126    pub fn remove(&self, handle: SocketHandle) {
127        self.0.lock().remove(handle);
128        debug!("socket {}: destroyed", handle);
129    }
130}
131
132impl InterfaceWrapper {
133    fn new(name: &'static str, dev: AxNetDevice, ether_addr: EthernetAddress) -> Self {
134        let mut config = Config::new(HardwareAddress::Ethernet(ether_addr));
135        config.random_seed = RANDOM_SEED;
136
137        let mut dev = DeviceWrapper::new(dev);
138        let iface = Mutex::new(Interface::new(config, &mut dev, Self::current_time()));
139        Self {
140            name,
141            ether_addr,
142            dev: Mutex::new(dev),
143            iface,
144        }
145    }
146
147    fn current_time() -> Instant {
148        Instant::from_micros_const((wall_time_nanos() / NANOS_PER_MICROS) as i64)
149    }
150
151    pub fn name(&self) -> &str {
152        self.name
153    }
154
155    pub fn ethernet_address(&self) -> EthernetAddress {
156        self.ether_addr
157    }
158
159    pub fn setup_ip_addr(&self, ip: IpAddress, prefix_len: u8) {
160        let mut iface = self.iface.lock();
161        iface.update_ip_addrs(|ip_addrs| {
162            ip_addrs.push(IpCidr::new(ip, prefix_len)).unwrap();
163        });
164    }
165
166    pub fn setup_gateway(&self, gateway: IpAddress) {
167        let mut iface = self.iface.lock();
168        match gateway {
169            IpAddress::Ipv4(v4) => iface.routes_mut().add_default_ipv4_route(v4).unwrap(),
170        };
171    }
172
173    pub fn poll(&self, sockets: &Mutex<SocketSet>) {
174        let mut dev = self.dev.lock();
175        let mut iface = self.iface.lock();
176        let mut sockets = sockets.lock();
177        let timestamp = Self::current_time();
178        iface.poll(timestamp, dev.deref_mut(), &mut sockets);
179    }
180}
181
182impl DeviceWrapper {
183    fn new(inner: AxNetDevice) -> Self {
184        Self {
185            inner: RefCell::new(inner),
186        }
187    }
188}
189
190impl Device for DeviceWrapper {
191    type RxToken<'a>
192        = AxNetRxToken<'a>
193    where
194        Self: 'a;
195    type TxToken<'a>
196        = AxNetTxToken<'a>
197    where
198        Self: 'a;
199
200    fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
201        let mut dev = self.inner.borrow_mut();
202        if let Err(e) = dev.recycle_tx_buffers() {
203            warn!("recycle_tx_buffers failed: {:?}", e);
204            return None;
205        }
206
207        if !dev.can_transmit() {
208            return None;
209        }
210        let rx_buf = match dev.receive() {
211            Ok(buf) => buf,
212            Err(err) => {
213                if !matches!(err, DevError::Again) {
214                    warn!("receive failed: {:?}", err);
215                }
216                return None;
217            }
218        };
219        Some((AxNetRxToken(&self.inner, rx_buf), AxNetTxToken(&self.inner)))
220    }
221
222    fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
223        let mut dev = self.inner.borrow_mut();
224        if let Err(e) = dev.recycle_tx_buffers() {
225            warn!("recycle_tx_buffers failed: {:?}", e);
226            return None;
227        }
228        if dev.can_transmit() {
229            Some(AxNetTxToken(&self.inner))
230        } else {
231            None
232        }
233    }
234
235    fn capabilities(&self) -> DeviceCapabilities {
236        let mut caps = DeviceCapabilities::default();
237        caps.max_transmission_unit = 1514;
238        caps.max_burst_size = None;
239        caps.medium = Medium::Ethernet;
240        caps
241    }
242}
243
244struct AxNetRxToken<'a>(&'a RefCell<AxNetDevice>, NetBufPtr);
245struct AxNetTxToken<'a>(&'a RefCell<AxNetDevice>);
246
247impl RxToken for AxNetRxToken<'_> {
248    fn preprocess(&self, sockets: &mut SocketSet<'_>) {
249        snoop_tcp_packet(self.1.packet(), sockets).ok();
250    }
251
252    fn consume<R, F>(self, f: F) -> R
253    where
254        F: FnOnce(&mut [u8]) -> R,
255    {
256        let mut rx_buf = self.1;
257        trace!(
258            "RECV {} bytes: {:02X?}",
259            rx_buf.packet_len(),
260            rx_buf.packet()
261        );
262        let result = f(rx_buf.packet_mut());
263        self.0.borrow_mut().recycle_rx_buffer(rx_buf).unwrap();
264        result
265    }
266}
267
268impl TxToken for AxNetTxToken<'_> {
269    fn consume<R, F>(self, len: usize, f: F) -> R
270    where
271        F: FnOnce(&mut [u8]) -> R,
272    {
273        let mut dev = self.0.borrow_mut();
274        let mut tx_buf = dev.alloc_tx_buffer(len).unwrap();
275        let ret = f(tx_buf.packet_mut());
276        trace!("SEND {} bytes: {:02X?}", len, tx_buf.packet());
277        dev.transmit(tx_buf).unwrap();
278        ret
279    }
280}
281
282fn snoop_tcp_packet(buf: &[u8], sockets: &mut SocketSet<'_>) -> Result<(), smoltcp::wire::Error> {
283    use smoltcp::wire::{EthernetFrame, IpProtocol, Ipv4Packet, TcpPacket};
284
285    let ether_frame = EthernetFrame::new_checked(buf)?;
286    let ipv4_packet = Ipv4Packet::new_checked(ether_frame.payload())?;
287
288    if ipv4_packet.next_header() == IpProtocol::Tcp {
289        let tcp_packet = TcpPacket::new_checked(ipv4_packet.payload())?;
290        let src_addr = (ipv4_packet.src_addr(), tcp_packet.src_port()).into();
291        let dst_addr = (ipv4_packet.dst_addr(), tcp_packet.dst_port()).into();
292        let is_first = tcp_packet.syn() && !tcp_packet.ack();
293        if is_first {
294            // create a socket for the first incoming TCP packet, as the later accept() returns.
295            LISTEN_TABLE.incoming_tcp_packet(src_addr, dst_addr, sockets);
296        }
297    }
298    Ok(())
299}
300
301/// Poll the network stack.
302///
303/// It may receive packets from the NIC and process them, and transmit queued
304/// packets to the NIC.
305pub fn poll_interfaces() {
306    SOCKET_SET.poll_interfaces();
307}
308
309/// Benchmark raw socket transmit bandwidth.
310pub fn bench_transmit() {
311    ETH0.dev.lock().bench_transmit_bandwidth();
312}
313
314/// Benchmark raw socket receive bandwidth.
315pub fn bench_receive() {
316    ETH0.dev.lock().bench_receive_bandwidth();
317}
318
319pub(crate) fn init(net_dev: AxNetDevice) {
320    let ether_addr = EthernetAddress(net_dev.mac_address().0);
321    let eth0 = InterfaceWrapper::new("eth0", net_dev, ether_addr);
322
323    let ip = IP.parse().expect("invalid IP address");
324    let gateway = GATEWAY.parse().expect("invalid gateway IP address");
325    eth0.setup_ip_addr(ip, IP_PREFIX);
326    eth0.setup_gateway(gateway);
327
328    ETH0.init_once(eth0);
329    SOCKET_SET.init_once(SocketSetWrapper::new());
330    LISTEN_TABLE.init_once(ListenTable::new());
331
332    info!("created net interface {:?}:", ETH0.name());
333    info!("  ether:    {}", ETH0.ethernet_address());
334    info!("  ip:       {}/{}", ip, IP_PREFIX);
335    info!("  gateway:  {}", gateway);
336}