1mod addr;
2mod bench;
3mod dns;
4mod listen_table;
5mod tcp;
6mod udp;
7
8use alloc::vec;
9use core::cell::RefCell;
10use core::ops::DerefMut;
11
12use axdriver::prelude::*;
13use axdriver_net::{DevError, NetBufPtr};
14use axhal::time::{NANOS_PER_MICROS, wall_time_nanos};
15use axsync::Mutex;
16use lazyinit::LazyInit;
17use smoltcp::iface::{Config, Interface, SocketHandle, SocketSet};
18use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
19use smoltcp::socket::{self, AnySocket};
20use smoltcp::time::Instant;
21use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr};
22
23use self::listen_table::ListenTable;
24
25pub use self::dns::dns_query;
26pub use self::tcp::TcpSocket;
27pub use self::udp::UdpSocket;
28
/// Expands to the build-time value of the environment variable `$key` as a
/// `&'static str`, or to `""` when that variable was not set while compiling.
///
/// The expansion uses only `option_env!` and `if let`, so it can initialize
/// `const` items.
macro_rules! env_or_default {
    ($key:literal) => {
        if let Some(value) = option_env!($key) {
            value
        } else {
            ""
        }
    };
}
37
38const IP: &str = env_or_default!("AX_IP");
39const GATEWAY: &str = env_or_default!("AX_GW");
40const DNS_SEVER: &str = "8.8.8.8";
41const IP_PREFIX: u8 = 24;
42
43const STANDARD_MTU: usize = 1500;
44
45const RANDOM_SEED: u64 = 0xA2CE_05A2_CE05_A2CE;
46
47const TCP_RX_BUF_LEN: usize = 64 * 1024;
48const TCP_TX_BUF_LEN: usize = 64 * 1024;
49const UDP_RX_BUF_LEN: usize = 64 * 1024;
50const UDP_TX_BUF_LEN: usize = 64 * 1024;
51const LISTEN_QUEUE_SIZE: usize = 512;
52
53static LISTEN_TABLE: LazyInit<ListenTable> = LazyInit::new();
54static SOCKET_SET: LazyInit<SocketSetWrapper> = LazyInit::new();
55static ETH0: LazyInit<InterfaceWrapper> = LazyInit::new();
56
/// A smoltcp `SocketSet` behind a mutex; the global instance is `SOCKET_SET`.
struct SocketSetWrapper<'a>(Mutex<SocketSet<'a>>);
58
/// Adapter that lets smoltcp drive an `AxNetDevice` through its `Device` trait.
///
/// The device lives in a `RefCell` so the RX and TX tokens returned by
/// `receive`/`transmit` can both hold a shared reference to it and borrow it
/// mutably only while they are being consumed.
struct DeviceWrapper {
    inner: RefCell<AxNetDevice>,
}
62
63struct InterfaceWrapper {
64 name: &'static str,
65 ether_addr: EthernetAddress,
66 dev: Mutex<DeviceWrapper>,
67 iface: Mutex<Interface>,
68}
69
70impl<'a> SocketSetWrapper<'a> {
71 fn new() -> Self {
72 Self(Mutex::new(SocketSet::new(vec![])))
73 }
74
75 pub fn new_tcp_socket() -> socket::tcp::Socket<'a> {
76 let tcp_rx_buffer = socket::tcp::SocketBuffer::new(vec![0; TCP_RX_BUF_LEN]);
77 let tcp_tx_buffer = socket::tcp::SocketBuffer::new(vec![0; TCP_TX_BUF_LEN]);
78 socket::tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer)
79 }
80
81 pub fn new_udp_socket() -> socket::udp::Socket<'a> {
82 let udp_rx_buffer = socket::udp::PacketBuffer::new(
83 vec![socket::udp::PacketMetadata::EMPTY; 8],
84 vec![0; UDP_RX_BUF_LEN],
85 );
86 let udp_tx_buffer = socket::udp::PacketBuffer::new(
87 vec![socket::udp::PacketMetadata::EMPTY; 8],
88 vec![0; UDP_TX_BUF_LEN],
89 );
90 socket::udp::Socket::new(udp_rx_buffer, udp_tx_buffer)
91 }
92
93 pub fn new_dns_socket() -> socket::dns::Socket<'a> {
94 let server_addr = DNS_SEVER.parse().expect("invalid DNS server address");
95 socket::dns::Socket::new(&[server_addr], vec![])
96 }
97
98 pub fn add<T: AnySocket<'a>>(&self, socket: T) -> SocketHandle {
99 let handle = self.0.lock().add(socket);
100 debug!("socket {}: created", handle);
101 handle
102 }
103
104 pub fn with_socket<T: AnySocket<'a>, R, F>(&self, handle: SocketHandle, f: F) -> R
105 where
106 F: FnOnce(&T) -> R,
107 {
108 let set = self.0.lock();
109 let socket = set.get(handle);
110 f(socket)
111 }
112
113 pub fn with_socket_mut<T: AnySocket<'a>, R, F>(&self, handle: SocketHandle, f: F) -> R
114 where
115 F: FnOnce(&mut T) -> R,
116 {
117 let mut set = self.0.lock();
118 let socket = set.get_mut(handle);
119 f(socket)
120 }
121
122 pub fn poll_interfaces(&self) {
123 ETH0.poll(&self.0);
124 }
125
126 pub fn remove(&self, handle: SocketHandle) {
127 self.0.lock().remove(handle);
128 debug!("socket {}: destroyed", handle);
129 }
130}
131
132impl InterfaceWrapper {
133 fn new(name: &'static str, dev: AxNetDevice, ether_addr: EthernetAddress) -> Self {
134 let mut config = Config::new(HardwareAddress::Ethernet(ether_addr));
135 config.random_seed = RANDOM_SEED;
136
137 let mut dev = DeviceWrapper::new(dev);
138 let iface = Mutex::new(Interface::new(config, &mut dev, Self::current_time()));
139 Self {
140 name,
141 ether_addr,
142 dev: Mutex::new(dev),
143 iface,
144 }
145 }
146
147 fn current_time() -> Instant {
148 Instant::from_micros_const((wall_time_nanos() / NANOS_PER_MICROS) as i64)
149 }
150
151 pub fn name(&self) -> &str {
152 self.name
153 }
154
155 pub fn ethernet_address(&self) -> EthernetAddress {
156 self.ether_addr
157 }
158
159 pub fn setup_ip_addr(&self, ip: IpAddress, prefix_len: u8) {
160 let mut iface = self.iface.lock();
161 iface.update_ip_addrs(|ip_addrs| {
162 ip_addrs.push(IpCidr::new(ip, prefix_len)).unwrap();
163 });
164 }
165
166 pub fn setup_gateway(&self, gateway: IpAddress) {
167 let mut iface = self.iface.lock();
168 match gateway {
169 IpAddress::Ipv4(v4) => iface.routes_mut().add_default_ipv4_route(v4).unwrap(),
170 IpAddress::Ipv6(v6) => iface.routes_mut().add_default_ipv6_route(v6).unwrap(),
171 };
172 }
173
174 pub fn poll(&self, sockets: &Mutex<SocketSet>) {
175 let mut dev = self.dev.lock();
176 let mut iface = self.iface.lock();
177 let mut sockets = sockets.lock();
178 let timestamp = Self::current_time();
179 iface.poll(timestamp, dev.deref_mut(), &mut sockets);
180 }
181}
182
183impl DeviceWrapper {
184 fn new(inner: AxNetDevice) -> Self {
185 Self {
186 inner: RefCell::new(inner),
187 }
188 }
189}
190
191impl Device for DeviceWrapper {
192 type RxToken<'a>
193 = AxNetRxToken<'a>
194 where
195 Self: 'a;
196 type TxToken<'a>
197 = AxNetTxToken<'a>
198 where
199 Self: 'a;
200
201 fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
202 let mut dev = self.inner.borrow_mut();
203 if let Err(e) = dev.recycle_tx_buffers() {
204 warn!("recycle_tx_buffers failed: {:?}", e);
205 return None;
206 }
207
208 if !dev.can_transmit() {
209 return None;
210 }
211 let rx_buf = match dev.receive() {
212 Ok(buf) => buf,
213 Err(err) => {
214 if !matches!(err, DevError::Again) {
215 warn!("receive failed: {:?}", err);
216 }
217 return None;
218 }
219 };
220 Some((AxNetRxToken(&self.inner, rx_buf), AxNetTxToken(&self.inner)))
221 }
222
223 fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
224 let mut dev = self.inner.borrow_mut();
225 if let Err(e) = dev.recycle_tx_buffers() {
226 warn!("recycle_tx_buffers failed: {:?}", e);
227 return None;
228 }
229 if dev.can_transmit() {
230 Some(AxNetTxToken(&self.inner))
231 } else {
232 None
233 }
234 }
235
236 fn capabilities(&self) -> DeviceCapabilities {
237 let mut caps = DeviceCapabilities::default();
238 caps.max_transmission_unit = 1514;
239 caps.max_burst_size = None;
240 caps.medium = Medium::Ethernet;
241 caps
242 }
243}
244
/// RX token: a shared handle to the device plus the received buffer, which
/// `consume` hands back to the device via `recycle_rx_buffer`.
struct AxNetRxToken<'a>(&'a RefCell<AxNetDevice>, NetBufPtr);
/// TX token: a shared handle to the device; `consume` allocates a buffer,
/// lets smoltcp fill it, and transmits it.
struct AxNetTxToken<'a>(&'a RefCell<AxNetDevice>);
247
248impl RxToken for AxNetRxToken<'_> {
249 fn preprocess(&self, sockets: &mut SocketSet<'_>) {
250 snoop_tcp_packet(self.1.packet(), sockets).ok();
251 }
252
253 fn consume<R, F>(self, f: F) -> R
254 where
255 F: FnOnce(&[u8]) -> R,
256 {
257 let rx_buf = self.1;
258 trace!(
259 "RECV {} bytes: {:02X?}",
260 rx_buf.packet_len(),
261 rx_buf.packet()
262 );
263 let result = f(rx_buf.packet());
264 self.0.borrow_mut().recycle_rx_buffer(rx_buf).unwrap();
265 result
266 }
267}
268
269impl TxToken for AxNetTxToken<'_> {
270 fn consume<R, F>(self, len: usize, f: F) -> R
271 where
272 F: FnOnce(&mut [u8]) -> R,
273 {
274 let mut dev = self.0.borrow_mut();
275 let mut tx_buf = dev.alloc_tx_buffer(len).unwrap();
276 let ret = f(tx_buf.packet_mut());
277 trace!("SEND {} bytes: {:02X?}", len, tx_buf.packet());
278 dev.transmit(tx_buf).unwrap();
279 ret
280 }
281}
282
283fn snoop_tcp_packet(buf: &[u8], sockets: &mut SocketSet<'_>) -> Result<(), smoltcp::wire::Error> {
284 use smoltcp::wire::{EthernetFrame, IpProtocol, Ipv4Packet, TcpPacket};
285
286 let ether_frame = EthernetFrame::new_checked(buf)?;
287 let ipv4_packet = Ipv4Packet::new_checked(ether_frame.payload())?;
288
289 if ipv4_packet.next_header() == IpProtocol::Tcp {
290 let tcp_packet = TcpPacket::new_checked(ipv4_packet.payload())?;
291 let src_addr = (ipv4_packet.src_addr(), tcp_packet.src_port()).into();
292 let dst_addr = (ipv4_packet.dst_addr(), tcp_packet.dst_port()).into();
293 let is_first = tcp_packet.syn() && !tcp_packet.ack();
294 if is_first {
295 LISTEN_TABLE.incoming_tcp_packet(src_addr, dst_addr, sockets);
297 }
298 }
299 Ok(())
300}
301
/// Runs one smoltcp poll of `ETH0` against the global `SOCKET_SET`.
pub fn poll_interfaces() {
    SOCKET_SET.poll_interfaces();
}
309
/// Runs the transmit-bandwidth benchmark on `eth0`'s device while holding the
/// device lock (the method is not defined here; presumably it lives in `bench`).
pub fn bench_transmit() {
    ETH0.dev.lock().bench_transmit_bandwidth();
}
314
/// Runs the receive-bandwidth benchmark on `eth0`'s device while holding the
/// device lock (the method is not defined here; presumably it lives in `bench`).
pub fn bench_receive() {
    ETH0.dev.lock().bench_receive_bandwidth();
}
319
320pub(crate) fn init(net_dev: AxNetDevice) {
321 let ether_addr = EthernetAddress(net_dev.mac_address().0);
322 let eth0 = InterfaceWrapper::new("eth0", net_dev, ether_addr);
323
324 let ip = IP.parse().expect("invalid IP address");
325 let gateway = GATEWAY.parse().expect("invalid gateway IP address");
326 eth0.setup_ip_addr(ip, IP_PREFIX);
327 eth0.setup_gateway(gateway);
328
329 ETH0.init_once(eth0);
330 SOCKET_SET.init_once(SocketSetWrapper::new());
331 LISTEN_TABLE.init_once(ListenTable::new());
332
333 info!("created net interface {:?}:", ETH0.name());
334 info!(" ether: {}", ETH0.ethernet_address());
335 info!(" ip: {}/{}", ip, IP_PREFIX);
336 info!(" gateway: {}", gateway);
337}