1use core::net::SocketAddr;
2use core::sync::atomic::{AtomicBool, Ordering};
3
4use axerrno::{AxError, AxResult, ax_err, ax_err_type};
5use axio::PollState;
6use axsync::Mutex;
7use spin::RwLock;
8
9use smoltcp::iface::SocketHandle;
10use smoltcp::socket::udp::{self, BindError, SendError};
11use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
12
13use super::addr::UNSPECIFIED_ENDPOINT;
14use super::{SOCKET_SET, SocketSetWrapper};
15
16pub struct UdpSocket {
18 handle: SocketHandle,
19 local_addr: RwLock<Option<IpEndpoint>>,
20 peer_addr: RwLock<Option<IpEndpoint>>,
21 nonblock: AtomicBool,
22}
23
24impl UdpSocket {
25 #[allow(clippy::new_without_default)]
27 pub fn new() -> Self {
28 let socket = SocketSetWrapper::new_udp_socket();
29 let handle = SOCKET_SET.add(socket);
30 Self {
31 handle,
32 local_addr: RwLock::new(None),
33 peer_addr: RwLock::new(None),
34 nonblock: AtomicBool::new(false),
35 }
36 }
37
38 pub fn local_addr(&self) -> AxResult<SocketAddr> {
41 match self.local_addr.try_read() {
42 Some(addr) => addr.map(Into::into).ok_or(AxError::NotConnected),
43 None => Err(AxError::NotConnected),
44 }
45 }
46
47 pub fn peer_addr(&self) -> AxResult<SocketAddr> {
50 self.remote_endpoint().map(Into::into)
51 }
52
53 #[inline]
55 pub fn is_nonblocking(&self) -> bool {
56 self.nonblock.load(Ordering::Acquire)
57 }
58
59 #[inline]
68 pub fn set_nonblocking(&self, nonblocking: bool) {
69 self.nonblock.store(nonblocking, Ordering::Release);
70 }
71
72 pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
77 let mut self_local_addr = self.local_addr.write();
78
79 if local_addr.port() == 0 {
80 local_addr.set_port(get_ephemeral_port()?);
81 }
82 if self_local_addr.is_some() {
83 return ax_err!(InvalidInput, "socket bind() failed: already bound");
84 }
85
86 let local_endpoint = IpEndpoint::from(local_addr);
87 let endpoint = IpListenEndpoint {
88 addr: (!local_endpoint.addr.is_unspecified()).then_some(local_endpoint.addr),
89 port: local_endpoint.port,
90 };
91 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
92 socket.bind(endpoint).or_else(|e| match e {
93 BindError::InvalidState => ax_err!(AlreadyExists, "socket bind() failed"),
94 BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"),
95 })
96 })?;
97
98 *self_local_addr = Some(local_endpoint);
99 debug!("UDP socket {}: bound on {}", self.handle, endpoint);
100 Ok(())
101 }
102
103 pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
106 if remote_addr.port() == 0 || remote_addr.ip().is_unspecified() {
107 return ax_err!(InvalidInput, "socket send_to() failed: invalid address");
108 }
109 self.send_impl(buf, IpEndpoint::from(remote_addr))
110 }
111
112 pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
115 self.recv_impl(|socket| match socket.recv_slice(buf) {
116 Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))),
117 Err(_) => ax_err!(BadState, "socket recv_from() failed"),
118 })
119 }
120
121 pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
124 self.recv_impl(|socket| match socket.peek_slice(buf) {
125 Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))),
126 Err(_) => ax_err!(BadState, "socket recv_from() failed"),
127 })
128 }
129
130 pub fn connect(&self, addr: SocketAddr) -> AxResult {
138 let mut self_peer_addr = self.peer_addr.write();
139
140 if self.local_addr.read().is_none() {
141 self.bind(SocketAddr::from(UNSPECIFIED_ENDPOINT))?;
142 }
143
144 *self_peer_addr = Some(IpEndpoint::from(addr));
145 debug!("UDP socket {}: connected to {}", self.handle, addr);
146 Ok(())
147 }
148
149 pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
151 let remote_endpoint = self.remote_endpoint()?;
152 self.send_impl(buf, remote_endpoint)
153 }
154
155 pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
158 let remote_endpoint = self.remote_endpoint()?;
159 self.recv_impl(|socket| {
160 let (len, meta) = socket
161 .recv_slice(buf)
162 .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
163 if !remote_endpoint.addr.is_unspecified() && remote_endpoint.addr != meta.endpoint.addr
164 {
165 return Err(AxError::WouldBlock);
166 }
167 if remote_endpoint.port != 0 && remote_endpoint.port != meta.endpoint.port {
168 return Err(AxError::WouldBlock);
169 }
170 Ok(len)
171 })
172 }
173
174 pub fn shutdown(&self) -> AxResult {
176 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
177 debug!("UDP socket {}: shutting down", self.handle);
178 socket.close();
179 });
180 SOCKET_SET.poll_interfaces();
181 Ok(())
182 }
183
184 pub fn poll(&self) -> AxResult<PollState> {
186 if self.local_addr.read().is_none() {
187 return Ok(PollState {
188 readable: false,
189 writable: false,
190 });
191 }
192 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
193 Ok(PollState {
194 readable: socket.can_recv(),
195 writable: socket.can_send(),
196 })
197 })
198 }
199}
200
201impl UdpSocket {
203 fn remote_endpoint(&self) -> AxResult<IpEndpoint> {
204 match self.peer_addr.try_read() {
205 Some(addr) => addr.ok_or(AxError::NotConnected),
206 None => Err(AxError::NotConnected),
207 }
208 }
209
210 fn send_impl(&self, buf: &[u8], remote_endpoint: IpEndpoint) -> AxResult<usize> {
211 if self.local_addr.read().is_none() {
212 return ax_err!(NotConnected, "socket send() failed");
213 }
214
215 self.block_on(|| {
216 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
217 if socket.can_send() {
218 socket
219 .send_slice(buf, remote_endpoint)
220 .map_err(|e| match e {
221 SendError::BufferFull => AxError::WouldBlock,
222 SendError::Unaddressable => {
223 ax_err_type!(ConnectionRefused, "socket send() failed")
224 }
225 })?;
226 Ok(buf.len())
227 } else {
228 Err(AxError::WouldBlock)
230 }
231 })
232 })
233 }
234
235 fn recv_impl<F, T>(&self, mut op: F) -> AxResult<T>
236 where
237 F: FnMut(&mut udp::Socket) -> AxResult<T>,
238 {
239 if self.local_addr.read().is_none() {
240 return ax_err!(NotConnected, "socket send() failed");
241 }
242
243 self.block_on(|| {
244 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
245 if socket.can_recv() {
246 op(socket)
248 } else {
249 Err(AxError::WouldBlock)
251 }
252 })
253 })
254 }
255
256 fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
257 where
258 F: FnMut() -> AxResult<T>,
259 {
260 if self.is_nonblocking() {
261 f()
262 } else {
263 loop {
264 SOCKET_SET.poll_interfaces();
265 match f() {
266 Ok(t) => return Ok(t),
267 Err(AxError::WouldBlock) => axtask::yield_now(),
268 Err(e) => return Err(e),
269 }
270 }
271 }
272 }
273}
274
275impl Drop for UdpSocket {
276 fn drop(&mut self) {
277 self.shutdown().ok();
278 SOCKET_SET.remove(self.handle);
279 }
280}
281
282fn get_ephemeral_port() -> AxResult<u16> {
283 const PORT_START: u16 = 0xc000;
284 const PORT_END: u16 = 0xffff;
285 static CURR: Mutex<u16> = Mutex::new(PORT_START);
286 let mut curr = CURR.lock();
287
288 let port = *curr;
289 if *curr == PORT_END {
290 *curr = PORT_START;
291 } else {
292 *curr += 1;
293 }
294 Ok(port)
295}