axnet/smoltcp_impl/
udp.rs

1use core::net::SocketAddr;
2use core::sync::atomic::{AtomicBool, Ordering};
3
4use axerrno::{AxError, AxResult, ax_err, ax_err_type};
5use axio::PollState;
6use axsync::Mutex;
7use spin::RwLock;
8
9use smoltcp::iface::SocketHandle;
10use smoltcp::socket::udp::{self, BindError, SendError};
11use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
12
13use super::addr::{UNSPECIFIED_ENDPOINT, from_core_sockaddr, into_core_sockaddr, is_unspecified};
14use super::{SOCKET_SET, SocketSetWrapper};
15
16/// A UDP socket that provides POSIX-like APIs.
17pub struct UdpSocket {
18    handle: SocketHandle,
19    local_addr: RwLock<Option<IpEndpoint>>,
20    peer_addr: RwLock<Option<IpEndpoint>>,
21    nonblock: AtomicBool,
22}
23
24impl UdpSocket {
25    /// Creates a new UDP socket.
26    #[allow(clippy::new_without_default)]
27    pub fn new() -> Self {
28        let socket = SocketSetWrapper::new_udp_socket();
29        let handle = SOCKET_SET.add(socket);
30        Self {
31            handle,
32            local_addr: RwLock::new(None),
33            peer_addr: RwLock::new(None),
34            nonblock: AtomicBool::new(false),
35        }
36    }
37
38    /// Returns the local address and port, or
39    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
40    pub fn local_addr(&self) -> AxResult<SocketAddr> {
41        match self.local_addr.try_read() {
42            Some(addr) => addr.map(into_core_sockaddr).ok_or(AxError::NotConnected),
43            None => Err(AxError::NotConnected),
44        }
45    }
46
47    /// Returns the remote address and port, or
48    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
49    pub fn peer_addr(&self) -> AxResult<SocketAddr> {
50        self.remote_endpoint().map(into_core_sockaddr)
51    }
52
53    /// Returns whether this socket is in nonblocking mode.
54    #[inline]
55    pub fn is_nonblocking(&self) -> bool {
56        self.nonblock.load(Ordering::Acquire)
57    }
58
59    /// Moves this UDP socket into or out of nonblocking mode.
60    ///
61    /// This will result in `recv`, `recv_from`, `send`, and `send_to`
62    /// operations becoming nonblocking, i.e., immediately returning from their
63    /// calls. If the IO operation is successful, `Ok` is returned and no
64    /// further action is required. If the IO operation could not be completed
65    /// and needs to be retried, an error with kind
66    /// [`Err(WouldBlock)`](AxError::WouldBlock) is returned.
67    #[inline]
68    pub fn set_nonblocking(&self, nonblocking: bool) {
69        self.nonblock.store(nonblocking, Ordering::Release);
70    }
71
72    /// Binds an unbound socket to the given address and port.
73    ///
74    /// It's must be called before [`send_to`](Self::send_to) and
75    /// [`recv_from`](Self::recv_from).
76    pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
77        let mut self_local_addr = self.local_addr.write();
78
79        if local_addr.port() == 0 {
80            local_addr.set_port(get_ephemeral_port()?);
81        }
82        if self_local_addr.is_some() {
83            return ax_err!(InvalidInput, "socket bind() failed: already bound");
84        }
85
86        let local_endpoint = from_core_sockaddr(local_addr);
87        let endpoint = IpListenEndpoint {
88            addr: (!is_unspecified(local_endpoint.addr)).then_some(local_endpoint.addr),
89            port: local_endpoint.port,
90        };
91        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
92            socket.bind(endpoint).or_else(|e| match e {
93                BindError::InvalidState => ax_err!(AlreadyExists, "socket bind() failed"),
94                BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"),
95            })
96        })?;
97
98        *self_local_addr = Some(local_endpoint);
99        debug!("UDP socket {}: bound on {}", self.handle, endpoint);
100        Ok(())
101    }
102
103    /// Sends data on the socket to the given address. On success, returns the
104    /// number of bytes written.
105    pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
106        if remote_addr.port() == 0 || remote_addr.ip().is_unspecified() {
107            return ax_err!(InvalidInput, "socket send_to() failed: invalid address");
108        }
109        self.send_impl(buf, from_core_sockaddr(remote_addr))
110    }
111
112    /// Receives a single datagram message on the socket. On success, returns
113    /// the number of bytes read and the origin.
114    pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
115        self.recv_impl(|socket| match socket.recv_slice(buf) {
116            Ok((len, meta)) => Ok((len, into_core_sockaddr(meta.endpoint))),
117            Err(_) => ax_err!(BadState, "socket recv_from() failed"),
118        })
119    }
120
121    /// Receives a single datagram message on the socket, without removing it from
122    /// the queue. On success, returns the number of bytes read and the origin.
123    pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
124        self.recv_impl(|socket| match socket.peek_slice(buf) {
125            Ok((len, meta)) => Ok((len, into_core_sockaddr(meta.endpoint))),
126            Err(_) => ax_err!(BadState, "socket recv_from() failed"),
127        })
128    }
129
130    /// Connects this UDP socket to a remote address, allowing the `send` and
131    /// `recv` to be used to send data and also applies filters to only receive
132    /// data from the specified address.
133    ///
134    /// The local port will be generated automatically if the socket is not bound.
135    /// It's must be called before [`send`](Self::send) and
136    /// [`recv`](Self::recv).
137    pub fn connect(&self, addr: SocketAddr) -> AxResult {
138        let mut self_peer_addr = self.peer_addr.write();
139
140        if self.local_addr.read().is_none() {
141            self.bind(into_core_sockaddr(UNSPECIFIED_ENDPOINT))?;
142        }
143
144        *self_peer_addr = Some(from_core_sockaddr(addr));
145        debug!("UDP socket {}: connected to {}", self.handle, addr);
146        Ok(())
147    }
148
149    /// Sends data on the socket to the remote address to which it is connected.
150    pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
151        let remote_endpoint = self.remote_endpoint()?;
152        self.send_impl(buf, remote_endpoint)
153    }
154
155    /// Receives a single datagram message on the socket from the remote address
156    /// to which it is connected. On success, returns the number of bytes read.
157    pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
158        let remote_endpoint = self.remote_endpoint()?;
159        self.recv_impl(|socket| {
160            let (len, meta) = socket
161                .recv_slice(buf)
162                .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
163            if !is_unspecified(remote_endpoint.addr) && remote_endpoint.addr != meta.endpoint.addr {
164                return Err(AxError::WouldBlock);
165            }
166            if remote_endpoint.port != 0 && remote_endpoint.port != meta.endpoint.port {
167                return Err(AxError::WouldBlock);
168            }
169            Ok(len)
170        })
171    }
172
173    /// Close the socket.
174    pub fn shutdown(&self) -> AxResult {
175        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
176            debug!("UDP socket {}: shutting down", self.handle);
177            socket.close();
178        });
179        SOCKET_SET.poll_interfaces();
180        Ok(())
181    }
182
183    /// Whether the socket is readable or writable.
184    pub fn poll(&self) -> AxResult<PollState> {
185        if self.local_addr.read().is_none() {
186            return Ok(PollState {
187                readable: false,
188                writable: false,
189            });
190        }
191        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
192            Ok(PollState {
193                readable: socket.can_recv(),
194                writable: socket.can_send(),
195            })
196        })
197    }
198}
199
200/// Private methods
201impl UdpSocket {
202    fn remote_endpoint(&self) -> AxResult<IpEndpoint> {
203        match self.peer_addr.try_read() {
204            Some(addr) => addr.ok_or(AxError::NotConnected),
205            None => Err(AxError::NotConnected),
206        }
207    }
208
209    fn send_impl(&self, buf: &[u8], remote_endpoint: IpEndpoint) -> AxResult<usize> {
210        if self.local_addr.read().is_none() {
211            return ax_err!(NotConnected, "socket send() failed");
212        }
213
214        self.block_on(|| {
215            SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
216                if socket.can_send() {
217                    socket
218                        .send_slice(buf, remote_endpoint)
219                        .map_err(|e| match e {
220                            SendError::BufferFull => AxError::WouldBlock,
221                            SendError::Unaddressable => {
222                                ax_err_type!(ConnectionRefused, "socket send() failed")
223                            }
224                        })?;
225                    Ok(buf.len())
226                } else {
227                    // tx buffer is full
228                    Err(AxError::WouldBlock)
229                }
230            })
231        })
232    }
233
234    fn recv_impl<F, T>(&self, mut op: F) -> AxResult<T>
235    where
236        F: FnMut(&mut udp::Socket) -> AxResult<T>,
237    {
238        if self.local_addr.read().is_none() {
239            return ax_err!(NotConnected, "socket send() failed");
240        }
241
242        self.block_on(|| {
243            SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
244                if socket.can_recv() {
245                    // data available
246                    op(socket)
247                } else {
248                    // no more data
249                    Err(AxError::WouldBlock)
250                }
251            })
252        })
253    }
254
255    fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
256    where
257        F: FnMut() -> AxResult<T>,
258    {
259        if self.is_nonblocking() {
260            f()
261        } else {
262            loop {
263                SOCKET_SET.poll_interfaces();
264                match f() {
265                    Ok(t) => return Ok(t),
266                    Err(AxError::WouldBlock) => axtask::yield_now(),
267                    Err(e) => return Err(e),
268                }
269            }
270        }
271    }
272}
273
274impl Drop for UdpSocket {
275    fn drop(&mut self) {
276        self.shutdown().ok();
277        SOCKET_SET.remove(self.handle);
278    }
279}
280
281fn get_ephemeral_port() -> AxResult<u16> {
282    const PORT_START: u16 = 0xc000;
283    const PORT_END: u16 = 0xffff;
284    static CURR: Mutex<u16> = Mutex::new(PORT_START);
285    let mut curr = CURR.lock();
286
287    let port = *curr;
288    if *curr == PORT_END {
289        *curr = PORT_START;
290    } else {
291        *curr += 1;
292    }
293    Ok(port)
294}