// axnet/smoltcp_impl/udp.rs

1use core::net::SocketAddr;
2use core::sync::atomic::{AtomicBool, Ordering};
3
4use axerrno::{AxError, AxResult, ax_err, ax_err_type};
5use axio::PollState;
6use axsync::Mutex;
7use spin::RwLock;
8
9use smoltcp::iface::SocketHandle;
10use smoltcp::socket::udp::{self, BindError, SendError};
11use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
12
13use super::addr::UNSPECIFIED_ENDPOINT;
14use super::{SOCKET_SET, SocketSetWrapper};
15
16/// A UDP socket that provides POSIX-like APIs.
17pub struct UdpSocket {
18    handle: SocketHandle,
19    local_addr: RwLock<Option<IpEndpoint>>,
20    peer_addr: RwLock<Option<IpEndpoint>>,
21    nonblock: AtomicBool,
22}
23
24impl UdpSocket {
25    /// Creates a new UDP socket.
26    #[allow(clippy::new_without_default)]
27    pub fn new() -> Self {
28        let socket = SocketSetWrapper::new_udp_socket();
29        let handle = SOCKET_SET.add(socket);
30        Self {
31            handle,
32            local_addr: RwLock::new(None),
33            peer_addr: RwLock::new(None),
34            nonblock: AtomicBool::new(false),
35        }
36    }
37
38    /// Returns the local address and port, or
39    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
40    pub fn local_addr(&self) -> AxResult<SocketAddr> {
41        match self.local_addr.try_read() {
42            Some(addr) => addr.map(Into::into).ok_or(AxError::NotConnected),
43            None => Err(AxError::NotConnected),
44        }
45    }
46
47    /// Returns the remote address and port, or
48    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
49    pub fn peer_addr(&self) -> AxResult<SocketAddr> {
50        self.remote_endpoint().map(Into::into)
51    }
52
53    /// Returns whether this socket is in nonblocking mode.
54    #[inline]
55    pub fn is_nonblocking(&self) -> bool {
56        self.nonblock.load(Ordering::Acquire)
57    }
58
59    /// Moves this UDP socket into or out of nonblocking mode.
60    ///
61    /// This will result in `recv`, `recv_from`, `send`, and `send_to`
62    /// operations becoming nonblocking, i.e., immediately returning from their
63    /// calls. If the IO operation is successful, `Ok` is returned and no
64    /// further action is required. If the IO operation could not be completed
65    /// and needs to be retried, an error with kind
66    /// [`Err(WouldBlock)`](AxError::WouldBlock) is returned.
67    #[inline]
68    pub fn set_nonblocking(&self, nonblocking: bool) {
69        self.nonblock.store(nonblocking, Ordering::Release);
70    }
71
72    /// Binds an unbound socket to the given address and port.
73    ///
74    /// It's must be called before [`send_to`](Self::send_to) and
75    /// [`recv_from`](Self::recv_from).
76    pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
77        let mut self_local_addr = self.local_addr.write();
78
79        if local_addr.port() == 0 {
80            local_addr.set_port(get_ephemeral_port()?);
81        }
82        if self_local_addr.is_some() {
83            return ax_err!(InvalidInput, "socket bind() failed: already bound");
84        }
85
86        let local_endpoint = IpEndpoint::from(local_addr);
87        let endpoint = IpListenEndpoint {
88            addr: (!local_endpoint.addr.is_unspecified()).then_some(local_endpoint.addr),
89            port: local_endpoint.port,
90        };
91        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
92            socket.bind(endpoint).or_else(|e| match e {
93                BindError::InvalidState => ax_err!(AlreadyExists, "socket bind() failed"),
94                BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"),
95            })
96        })?;
97
98        *self_local_addr = Some(local_endpoint);
99        debug!("UDP socket {}: bound on {}", self.handle, endpoint);
100        Ok(())
101    }
102
103    /// Sends data on the socket to the given address. On success, returns the
104    /// number of bytes written.
105    pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
106        if remote_addr.port() == 0 || remote_addr.ip().is_unspecified() {
107            return ax_err!(InvalidInput, "socket send_to() failed: invalid address");
108        }
109        self.send_impl(buf, IpEndpoint::from(remote_addr))
110    }
111
112    /// Receives a single datagram message on the socket. On success, returns
113    /// the number of bytes read and the origin.
114    pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
115        self.recv_impl(|socket| match socket.recv_slice(buf) {
116            Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))),
117            Err(_) => ax_err!(BadState, "socket recv_from() failed"),
118        })
119    }
120
121    /// Receives a single datagram message on the socket, without removing it from
122    /// the queue. On success, returns the number of bytes read and the origin.
123    pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
124        self.recv_impl(|socket| match socket.peek_slice(buf) {
125            Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))),
126            Err(_) => ax_err!(BadState, "socket recv_from() failed"),
127        })
128    }
129
130    /// Connects this UDP socket to a remote address, allowing the `send` and
131    /// `recv` to be used to send data and also applies filters to only receive
132    /// data from the specified address.
133    ///
134    /// The local port will be generated automatically if the socket is not bound.
135    /// It's must be called before [`send`](Self::send) and
136    /// [`recv`](Self::recv).
137    pub fn connect(&self, addr: SocketAddr) -> AxResult {
138        let mut self_peer_addr = self.peer_addr.write();
139
140        if self.local_addr.read().is_none() {
141            self.bind(SocketAddr::from(UNSPECIFIED_ENDPOINT))?;
142        }
143
144        *self_peer_addr = Some(IpEndpoint::from(addr));
145        debug!("UDP socket {}: connected to {}", self.handle, addr);
146        Ok(())
147    }
148
149    /// Sends data on the socket to the remote address to which it is connected.
150    pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
151        let remote_endpoint = self.remote_endpoint()?;
152        self.send_impl(buf, remote_endpoint)
153    }
154
155    /// Receives a single datagram message on the socket from the remote address
156    /// to which it is connected. On success, returns the number of bytes read.
157    pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
158        let remote_endpoint = self.remote_endpoint()?;
159        self.recv_impl(|socket| {
160            let (len, meta) = socket
161                .recv_slice(buf)
162                .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
163            if !remote_endpoint.addr.is_unspecified() && remote_endpoint.addr != meta.endpoint.addr
164            {
165                return Err(AxError::WouldBlock);
166            }
167            if remote_endpoint.port != 0 && remote_endpoint.port != meta.endpoint.port {
168                return Err(AxError::WouldBlock);
169            }
170            Ok(len)
171        })
172    }
173
174    /// Close the socket.
175    pub fn shutdown(&self) -> AxResult {
176        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
177            debug!("UDP socket {}: shutting down", self.handle);
178            socket.close();
179        });
180        SOCKET_SET.poll_interfaces();
181        Ok(())
182    }
183
184    /// Whether the socket is readable or writable.
185    pub fn poll(&self) -> AxResult<PollState> {
186        if self.local_addr.read().is_none() {
187            return Ok(PollState {
188                readable: false,
189                writable: false,
190            });
191        }
192        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
193            Ok(PollState {
194                readable: socket.can_recv(),
195                writable: socket.can_send(),
196            })
197        })
198    }
199}
200
201/// Private methods
202impl UdpSocket {
203    fn remote_endpoint(&self) -> AxResult<IpEndpoint> {
204        match self.peer_addr.try_read() {
205            Some(addr) => addr.ok_or(AxError::NotConnected),
206            None => Err(AxError::NotConnected),
207        }
208    }
209
210    fn send_impl(&self, buf: &[u8], remote_endpoint: IpEndpoint) -> AxResult<usize> {
211        if self.local_addr.read().is_none() {
212            return ax_err!(NotConnected, "socket send() failed");
213        }
214
215        self.block_on(|| {
216            SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
217                if socket.can_send() {
218                    socket
219                        .send_slice(buf, remote_endpoint)
220                        .map_err(|e| match e {
221                            SendError::BufferFull => AxError::WouldBlock,
222                            SendError::Unaddressable => {
223                                ax_err_type!(ConnectionRefused, "socket send() failed")
224                            }
225                        })?;
226                    Ok(buf.len())
227                } else {
228                    // tx buffer is full
229                    Err(AxError::WouldBlock)
230                }
231            })
232        })
233    }
234
235    fn recv_impl<F, T>(&self, mut op: F) -> AxResult<T>
236    where
237        F: FnMut(&mut udp::Socket) -> AxResult<T>,
238    {
239        if self.local_addr.read().is_none() {
240            return ax_err!(NotConnected, "socket send() failed");
241        }
242
243        self.block_on(|| {
244            SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
245                if socket.can_recv() {
246                    // data available
247                    op(socket)
248                } else {
249                    // no more data
250                    Err(AxError::WouldBlock)
251                }
252            })
253        })
254    }
255
256    fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
257    where
258        F: FnMut() -> AxResult<T>,
259    {
260        if self.is_nonblocking() {
261            f()
262        } else {
263            loop {
264                SOCKET_SET.poll_interfaces();
265                match f() {
266                    Ok(t) => return Ok(t),
267                    Err(AxError::WouldBlock) => axtask::yield_now(),
268                    Err(e) => return Err(e),
269                }
270            }
271        }
272    }
273}
274
275impl Drop for UdpSocket {
276    fn drop(&mut self) {
277        self.shutdown().ok();
278        SOCKET_SET.remove(self.handle);
279    }
280}
281
282fn get_ephemeral_port() -> AxResult<u16> {
283    const PORT_START: u16 = 0xc000;
284    const PORT_END: u16 = 0xffff;
285    static CURR: Mutex<u16> = Mutex::new(PORT_START);
286    let mut curr = CURR.lock();
287
288    let port = *curr;
289    if *curr == PORT_END {
290        *curr = PORT_START;
291    } else {
292        *curr += 1;
293    }
294    Ok(port)
295}