1use core::net::SocketAddr;
2use core::sync::atomic::{AtomicBool, Ordering};
3
4use axerrno::{AxError, AxResult, ax_err, ax_err_type};
5use axio::PollState;
6use axsync::Mutex;
7use spin::RwLock;
8
9use smoltcp::iface::SocketHandle;
10use smoltcp::socket::udp::{self, BindError, SendError};
11use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
12
13use super::addr::{UNSPECIFIED_ENDPOINT, from_core_sockaddr, into_core_sockaddr, is_unspecified};
14use super::{SOCKET_SET, SocketSetWrapper};
15
16pub struct UdpSocket {
18 handle: SocketHandle,
19 local_addr: RwLock<Option<IpEndpoint>>,
20 peer_addr: RwLock<Option<IpEndpoint>>,
21 nonblock: AtomicBool,
22}
23
24impl UdpSocket {
25 #[allow(clippy::new_without_default)]
27 pub fn new() -> Self {
28 let socket = SocketSetWrapper::new_udp_socket();
29 let handle = SOCKET_SET.add(socket);
30 Self {
31 handle,
32 local_addr: RwLock::new(None),
33 peer_addr: RwLock::new(None),
34 nonblock: AtomicBool::new(false),
35 }
36 }
37
38 pub fn local_addr(&self) -> AxResult<SocketAddr> {
41 match self.local_addr.try_read() {
42 Some(addr) => addr.map(into_core_sockaddr).ok_or(AxError::NotConnected),
43 None => Err(AxError::NotConnected),
44 }
45 }
46
47 pub fn peer_addr(&self) -> AxResult<SocketAddr> {
50 self.remote_endpoint().map(into_core_sockaddr)
51 }
52
53 #[inline]
55 pub fn is_nonblocking(&self) -> bool {
56 self.nonblock.load(Ordering::Acquire)
57 }
58
59 #[inline]
68 pub fn set_nonblocking(&self, nonblocking: bool) {
69 self.nonblock.store(nonblocking, Ordering::Release);
70 }
71
72 pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
77 let mut self_local_addr = self.local_addr.write();
78
79 if local_addr.port() == 0 {
80 local_addr.set_port(get_ephemeral_port()?);
81 }
82 if self_local_addr.is_some() {
83 return ax_err!(InvalidInput, "socket bind() failed: already bound");
84 }
85
86 let local_endpoint = from_core_sockaddr(local_addr);
87 let endpoint = IpListenEndpoint {
88 addr: (!is_unspecified(local_endpoint.addr)).then_some(local_endpoint.addr),
89 port: local_endpoint.port,
90 };
91 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
92 socket.bind(endpoint).or_else(|e| match e {
93 BindError::InvalidState => ax_err!(AlreadyExists, "socket bind() failed"),
94 BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"),
95 })
96 })?;
97
98 *self_local_addr = Some(local_endpoint);
99 debug!("UDP socket {}: bound on {}", self.handle, endpoint);
100 Ok(())
101 }
102
103 pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
106 if remote_addr.port() == 0 || remote_addr.ip().is_unspecified() {
107 return ax_err!(InvalidInput, "socket send_to() failed: invalid address");
108 }
109 self.send_impl(buf, from_core_sockaddr(remote_addr))
110 }
111
112 pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
115 self.recv_impl(|socket| match socket.recv_slice(buf) {
116 Ok((len, meta)) => Ok((len, into_core_sockaddr(meta.endpoint))),
117 Err(_) => ax_err!(BadState, "socket recv_from() failed"),
118 })
119 }
120
121 pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
124 self.recv_impl(|socket| match socket.peek_slice(buf) {
125 Ok((len, meta)) => Ok((len, into_core_sockaddr(meta.endpoint))),
126 Err(_) => ax_err!(BadState, "socket recv_from() failed"),
127 })
128 }
129
130 pub fn connect(&self, addr: SocketAddr) -> AxResult {
138 let mut self_peer_addr = self.peer_addr.write();
139
140 if self.local_addr.read().is_none() {
141 self.bind(into_core_sockaddr(UNSPECIFIED_ENDPOINT))?;
142 }
143
144 *self_peer_addr = Some(from_core_sockaddr(addr));
145 debug!("UDP socket {}: connected to {}", self.handle, addr);
146 Ok(())
147 }
148
149 pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
151 let remote_endpoint = self.remote_endpoint()?;
152 self.send_impl(buf, remote_endpoint)
153 }
154
155 pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
158 let remote_endpoint = self.remote_endpoint()?;
159 self.recv_impl(|socket| {
160 let (len, meta) = socket
161 .recv_slice(buf)
162 .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
163 if !is_unspecified(remote_endpoint.addr) && remote_endpoint.addr != meta.endpoint.addr {
164 return Err(AxError::WouldBlock);
165 }
166 if remote_endpoint.port != 0 && remote_endpoint.port != meta.endpoint.port {
167 return Err(AxError::WouldBlock);
168 }
169 Ok(len)
170 })
171 }
172
173 pub fn shutdown(&self) -> AxResult {
175 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
176 debug!("UDP socket {}: shutting down", self.handle);
177 socket.close();
178 });
179 SOCKET_SET.poll_interfaces();
180 Ok(())
181 }
182
183 pub fn poll(&self) -> AxResult<PollState> {
185 if self.local_addr.read().is_none() {
186 return Ok(PollState {
187 readable: false,
188 writable: false,
189 });
190 }
191 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
192 Ok(PollState {
193 readable: socket.can_recv(),
194 writable: socket.can_send(),
195 })
196 })
197 }
198}
199
200impl UdpSocket {
202 fn remote_endpoint(&self) -> AxResult<IpEndpoint> {
203 match self.peer_addr.try_read() {
204 Some(addr) => addr.ok_or(AxError::NotConnected),
205 None => Err(AxError::NotConnected),
206 }
207 }
208
209 fn send_impl(&self, buf: &[u8], remote_endpoint: IpEndpoint) -> AxResult<usize> {
210 if self.local_addr.read().is_none() {
211 return ax_err!(NotConnected, "socket send() failed");
212 }
213
214 self.block_on(|| {
215 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
216 if socket.can_send() {
217 socket
218 .send_slice(buf, remote_endpoint)
219 .map_err(|e| match e {
220 SendError::BufferFull => AxError::WouldBlock,
221 SendError::Unaddressable => {
222 ax_err_type!(ConnectionRefused, "socket send() failed")
223 }
224 })?;
225 Ok(buf.len())
226 } else {
227 Err(AxError::WouldBlock)
229 }
230 })
231 })
232 }
233
234 fn recv_impl<F, T>(&self, mut op: F) -> AxResult<T>
235 where
236 F: FnMut(&mut udp::Socket) -> AxResult<T>,
237 {
238 if self.local_addr.read().is_none() {
239 return ax_err!(NotConnected, "socket send() failed");
240 }
241
242 self.block_on(|| {
243 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
244 if socket.can_recv() {
245 op(socket)
247 } else {
248 Err(AxError::WouldBlock)
250 }
251 })
252 })
253 }
254
255 fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
256 where
257 F: FnMut() -> AxResult<T>,
258 {
259 if self.is_nonblocking() {
260 f()
261 } else {
262 loop {
263 SOCKET_SET.poll_interfaces();
264 match f() {
265 Ok(t) => return Ok(t),
266 Err(AxError::WouldBlock) => axtask::yield_now(),
267 Err(e) => return Err(e),
268 }
269 }
270 }
271 }
272}
273
274impl Drop for UdpSocket {
275 fn drop(&mut self) {
276 self.shutdown().ok();
277 SOCKET_SET.remove(self.handle);
278 }
279}
280
281fn get_ephemeral_port() -> AxResult<u16> {
282 const PORT_START: u16 = 0xc000;
283 const PORT_END: u16 = 0xffff;
284 static CURR: Mutex<u16> = Mutex::new(PORT_START);
285 let mut curr = CURR.lock();
286
287 let port = *curr;
288 if *curr == PORT_END {
289 *curr = PORT_START;
290 } else {
291 *curr += 1;
292 }
293 Ok(port)
294}