axnet/smoltcp_impl/
tcp.rs

1use core::cell::UnsafeCell;
2use core::net::SocketAddr;
3use core::sync::atomic::{AtomicBool, AtomicU8, Ordering};
4
5use axerrno::{AxError, AxResult, ax_err, ax_err_type};
6use axio::PollState;
7use axsync::Mutex;
8
9use smoltcp::iface::SocketHandle;
10use smoltcp::socket::tcp::{self, ConnectError, State};
11use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
12
13use super::addr::UNSPECIFIED_ENDPOINT;
14use super::{ETH0, LISTEN_TABLE, SOCKET_SET, SocketSetWrapper};
15
// Socket state machine (values stored in `TcpSocket::state`):
//
// CLOSED -(connect)-> BUSY -> CONNECTING -> CONNECTED -(shutdown)-> BUSY -> CLOSED
//       |                          |
//       |                           -(connection failed, detected by poll_connect)-> CLOSED
//       |
//       |-(listen)-> BUSY -> LISTENING -(shutdown)-> BUSY -> CLOSED
//       |
//        -(bind)-> BUSY -> CLOSED
//
// `BUSY` marks a transition in progress (see `update_state`): the thread that
// set it has exclusive access to the socket's `UnsafeCell` fields, and if the
// operation fails the previous state is restored. `CONNECTING` is the only
// state that is left without passing through `BUSY` (by `poll_connect`).

/// No connection, listener or transition in progress (the socket may be bound).
const STATE_CLOSED: u8 = 0;
/// A state transition is in progress; see the state-machine comment above.
const STATE_BUSY: u8 = 1;
/// `connect` has been issued and the handshake outcome is not yet known.
const STATE_CONNECTING: u8 = 2;
/// The connection is established.
const STATE_CONNECTED: u8 = 3;
/// The socket is registered in `LISTEN_TABLE` and may `accept` connections.
const STATE_LISTENING: u8 = 4;
27
28/// A TCP socket that provides POSIX-like APIs.
29///
30/// - [`connect`] is for TCP clients.
31/// - [`bind`], [`listen`], and [`accept`] are for TCP servers.
32/// - Other methods are for both TCP clients and servers.
33///
34/// [`connect`]: TcpSocket::connect
35/// [`bind`]: TcpSocket::bind
36/// [`listen`]: TcpSocket::listen
37/// [`accept`]: TcpSocket::accept
38pub struct TcpSocket {
39    state: AtomicU8,
40    handle: UnsafeCell<Option<SocketHandle>>,
41    local_addr: UnsafeCell<IpEndpoint>,
42    peer_addr: UnsafeCell<IpEndpoint>,
43    nonblock: AtomicBool,
44}
45
46unsafe impl Sync for TcpSocket {}
47
48impl TcpSocket {
49    /// Creates a new TCP socket.
50    pub const fn new() -> Self {
51        Self {
52            state: AtomicU8::new(STATE_CLOSED),
53            handle: UnsafeCell::new(None),
54            local_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
55            peer_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
56            nonblock: AtomicBool::new(false),
57        }
58    }
59
60    /// Creates a new TCP socket that is already connected.
61    const fn new_connected(
62        handle: SocketHandle,
63        local_addr: IpEndpoint,
64        peer_addr: IpEndpoint,
65    ) -> Self {
66        Self {
67            state: AtomicU8::new(STATE_CONNECTED),
68            handle: UnsafeCell::new(Some(handle)),
69            local_addr: UnsafeCell::new(local_addr),
70            peer_addr: UnsafeCell::new(peer_addr),
71            nonblock: AtomicBool::new(false),
72        }
73    }
74
75    /// Returns the local address and port, or
76    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
77    pub fn local_addr(&self) -> AxResult<SocketAddr> {
78        match self.get_state() {
79            STATE_CONNECTED | STATE_LISTENING => {
80                Ok(SocketAddr::from(unsafe { self.local_addr.get().read() }))
81            }
82            _ => Err(AxError::NotConnected),
83        }
84    }
85
86    /// Returns the remote address and port, or
87    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
88    pub fn peer_addr(&self) -> AxResult<SocketAddr> {
89        match self.get_state() {
90            STATE_CONNECTED | STATE_LISTENING => {
91                Ok(SocketAddr::from(unsafe { self.peer_addr.get().read() }))
92            }
93            _ => Err(AxError::NotConnected),
94        }
95    }
96
97    /// Returns whether this socket is in nonblocking mode.
98    #[inline]
99    pub fn is_nonblocking(&self) -> bool {
100        self.nonblock.load(Ordering::Acquire)
101    }
102
103    /// Moves this TCP stream into or out of nonblocking mode.
104    ///
105    /// This will result in `read`, `write`, `recv` and `send` operations
106    /// becoming nonblocking, i.e., immediately returning from their calls.
107    /// If the IO operation is successful, `Ok` is returned and no further
108    /// action is required. If the IO operation could not be completed and needs
109    /// to be retried, an error with kind  [`Err(WouldBlock)`](AxError::WouldBlock) is
110    /// returned.
111    #[inline]
112    pub fn set_nonblocking(&self, nonblocking: bool) {
113        self.nonblock.store(nonblocking, Ordering::Release);
114    }
115
116    /// Connects to the given address and port.
117    ///
118    /// The local port is generated automatically.
119    pub fn connect(&self, remote_addr: SocketAddr) -> AxResult {
120        self.update_state(STATE_CLOSED, STATE_CONNECTING, || {
121            // SAFETY: no other threads can read or write these fields.
122            let handle = unsafe { self.handle.get().read() }
123                .unwrap_or_else(|| SOCKET_SET.add(SocketSetWrapper::new_tcp_socket()));
124
125            // TODO: check remote addr unreachable
126            let bound_endpoint = self.bound_endpoint()?;
127            let iface = &ETH0.iface;
128            let (local_endpoint, remote_endpoint) = SOCKET_SET
129                .with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
130                    socket
131                        .connect(iface.lock().context(), remote_addr, bound_endpoint)
132                        .or_else(|e| match e {
133                            ConnectError::InvalidState => {
134                                ax_err!(BadState, "socket connect() failed")
135                            }
136                            ConnectError::Unaddressable => {
137                                ax_err!(ConnectionRefused, "socket connect() failed")
138                            }
139                        })?;
140                    Ok((
141                        socket.local_endpoint().unwrap(),
142                        socket.remote_endpoint().unwrap(),
143                    ))
144                })?;
145            unsafe {
146                // SAFETY: no other threads can read or write these fields as we
147                // have changed the state to `BUSY`.
148                self.local_addr.get().write(local_endpoint);
149                self.peer_addr.get().write(remote_endpoint);
150                self.handle.get().write(Some(handle));
151            }
152            Ok(())
153        })
154        .unwrap_or_else(|_| ax_err!(AlreadyExists, "socket connect() failed: already connected"))?; // EISCONN
155
156        // Here our state must be `CONNECTING`, and only one thread can run here.
157        if self.is_nonblocking() {
158            Err(AxError::WouldBlock)
159        } else {
160            self.block_on(|| {
161                let PollState { writable, .. } = self.poll_connect()?;
162                if !writable {
163                    Err(AxError::WouldBlock)
164                } else if self.get_state() == STATE_CONNECTED {
165                    Ok(())
166                } else {
167                    ax_err!(ConnectionRefused, "socket connect() failed")
168                }
169            })
170        }
171    }
172
173    /// Binds an unbound socket to the given address and port.
174    ///
175    /// If the given port is 0, it generates one automatically.
176    ///
177    /// It's must be called before [`listen`](Self::listen) and
178    /// [`accept`](Self::accept).
179    pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
180        self.update_state(STATE_CLOSED, STATE_CLOSED, || {
181            // TODO: check addr is available
182            if local_addr.port() == 0 {
183                local_addr.set_port(get_ephemeral_port()?);
184            }
185            // SAFETY: no other threads can read or write `self.local_addr` as we
186            // have changed the state to `BUSY`.
187            unsafe {
188                let old = self.local_addr.get().read();
189                if old != UNSPECIFIED_ENDPOINT {
190                    return ax_err!(InvalidInput, "socket bind() failed: already bound");
191                }
192                self.local_addr.get().write(IpEndpoint::from(local_addr));
193            }
194            Ok(())
195        })
196        .unwrap_or_else(|_| ax_err!(InvalidInput, "socket bind() failed: already bound"))
197    }
198
199    /// Starts listening on the bound address and port.
200    ///
201    /// It's must be called after [`bind`](Self::bind) and before
202    /// [`accept`](Self::accept).
203    pub fn listen(&self) -> AxResult {
204        self.update_state(STATE_CLOSED, STATE_LISTENING, || {
205            let bound_endpoint = self.bound_endpoint()?;
206            unsafe {
207                (*self.local_addr.get()).port = bound_endpoint.port;
208            }
209            LISTEN_TABLE.listen(bound_endpoint)?;
210            debug!("TCP socket listening on {}", bound_endpoint);
211            Ok(())
212        })
213        .unwrap_or(Ok(())) // ignore simultaneous `listen`s.
214    }
215
216    /// Accepts a new connection.
217    ///
218    /// This function will block the calling thread until a new TCP connection
219    /// is established. When established, a new [`TcpSocket`] is returned.
220    ///
221    /// It's must be called after [`bind`](Self::bind) and [`listen`](Self::listen).
222    pub fn accept(&self) -> AxResult<TcpSocket> {
223        if !self.is_listening() {
224            return ax_err!(InvalidInput, "socket accept() failed: not listen");
225        }
226
227        // SAFETY: `self.local_addr` should be initialized after `bind()`.
228        let local_port = unsafe { self.local_addr.get().read().port };
229        self.block_on(|| {
230            let (handle, (local_addr, peer_addr)) = LISTEN_TABLE.accept(local_port)?;
231            debug!("TCP socket accepted a new connection {}", peer_addr);
232            Ok(TcpSocket::new_connected(handle, local_addr, peer_addr))
233        })
234    }
235
236    /// Close the connection.
237    pub fn shutdown(&self) -> AxResult {
238        // stream
239        self.update_state(STATE_CONNECTED, STATE_CLOSED, || {
240            // SAFETY: `self.handle` should be initialized in a connected socket, and
241            // no other threads can read or write it.
242            let handle = unsafe { self.handle.get().read().unwrap() };
243            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
244                debug!("TCP socket {}: shutting down", handle);
245                socket.close();
246            });
247            unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; // clear bound address
248            SOCKET_SET.poll_interfaces();
249            Ok(())
250        })
251        .unwrap_or(Ok(()))?;
252
253        // listener
254        self.update_state(STATE_LISTENING, STATE_CLOSED, || {
255            // SAFETY: `self.local_addr` should be initialized in a listening socket,
256            // and no other threads can read or write it.
257            let local_port = unsafe { self.local_addr.get().read().port };
258            unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; // clear bound address
259            LISTEN_TABLE.unlisten(local_port);
260            SOCKET_SET.poll_interfaces();
261            Ok(())
262        })
263        .unwrap_or(Ok(()))?;
264
265        // ignore for other states
266        Ok(())
267    }
268
269    /// Receives data from the socket, stores it in the given buffer.
270    pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
271        if self.is_connecting() {
272            return Err(AxError::WouldBlock);
273        } else if !self.is_connected() {
274            return ax_err!(NotConnected, "socket recv() failed");
275        }
276
277        // SAFETY: `self.handle` should be initialized in a connected socket.
278        let handle = unsafe { self.handle.get().read().unwrap() };
279        self.block_on(|| {
280            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
281                if !socket.is_active() {
282                    // not open
283                    ax_err!(ConnectionRefused, "socket recv() failed")
284                } else if !socket.may_recv() {
285                    // connection closed
286                    Ok(0)
287                } else if socket.recv_queue() > 0 {
288                    // data available
289                    // TODO: use socket.recv(|buf| {...})
290                    let len = socket
291                        .recv_slice(buf)
292                        .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
293                    Ok(len)
294                } else {
295                    // no more data
296                    Err(AxError::WouldBlock)
297                }
298            })
299        })
300    }
301
302    /// Transmits data in the given buffer.
303    pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
304        if self.is_connecting() {
305            return Err(AxError::WouldBlock);
306        } else if !self.is_connected() {
307            return ax_err!(NotConnected, "socket send() failed");
308        }
309
310        // SAFETY: `self.handle` should be initialized in a connected socket.
311        let handle = unsafe { self.handle.get().read().unwrap() };
312        self.block_on(|| {
313            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
314                if !socket.is_active() || !socket.may_send() {
315                    // closed by remote
316                    ax_err!(ConnectionReset, "socket send() failed")
317                } else if socket.can_send() {
318                    // connected, and the tx buffer is not full
319                    // TODO: use socket.send(|buf| {...})
320                    let len = socket
321                        .send_slice(buf)
322                        .map_err(|_| ax_err_type!(BadState, "socket send() failed"))?;
323                    Ok(len)
324                } else {
325                    // tx buffer is full
326                    Err(AxError::WouldBlock)
327                }
328            })
329        })
330    }
331
332    /// Whether the socket is readable or writable.
333    pub fn poll(&self) -> AxResult<PollState> {
334        match self.get_state() {
335            STATE_CONNECTING => self.poll_connect(),
336            STATE_CONNECTED => self.poll_stream(),
337            STATE_LISTENING => self.poll_listener(),
338            _ => Ok(PollState {
339                readable: false,
340                writable: false,
341            }),
342        }
343    }
344
345    /// Checks if Nagle's algorithm is enabled for this TCP socket.
346    pub fn nodelay(&self) -> AxResult<bool> {
347        if let Some(h) = unsafe { self.handle.get().read() } {
348            Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.nagle_enabled()))
349        } else {
350            ax_err!(NotConnected, "socket is not connected")
351        }
352    }
353
354    /// Enables or disables Nagle's algorithm for this TCP socket.
355    pub fn set_nodelay(&self, enabled: bool) -> AxResult<()> {
356        if let Some(h) = unsafe { self.handle.get().read() } {
357            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(h, |socket| {
358                socket.set_nagle_enabled(enabled);
359            });
360            Ok(())
361        } else {
362            ax_err!(NotConnected, "socket is not connected")
363        }
364    }
365
366    /// Returns the maximum capacity of the receive buffer in bytes.
367    pub fn recv_capacity(&self) -> AxResult<usize> {
368        if let Some(h) = unsafe { self.handle.get().read() } {
369            Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.recv_capacity()))
370        } else {
371            ax_err!(NotConnected, "socket is not connected")
372        }
373    }
374
375    /// Returns the maximum capacity of the send buffer in bytes.
376    pub fn send_capacity(&self) -> AxResult<usize> {
377        if let Some(h) = unsafe { self.handle.get().read() } {
378            Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.send_capacity()))
379        } else {
380            ax_err!(NotConnected, "socket is not connected")
381        }
382    }
383}
384
385/// Private methods
386impl TcpSocket {
387    #[inline]
388    fn get_state(&self) -> u8 {
389        self.state.load(Ordering::Acquire)
390    }
391
392    #[inline]
393    fn set_state(&self, state: u8) {
394        self.state.store(state, Ordering::Release);
395    }
396
397    /// Update the state of the socket atomically.
398    ///
399    /// If the current state is `expect`, it first changes the state to `STATE_BUSY`,
400    /// then calls the given function. If the function returns `Ok`, it changes the
401    /// state to `new`, otherwise it changes the state back to `expect`.
402    ///
403    /// It returns `Ok` if the current state is `expect`, otherwise it returns
404    /// the current state in `Err`.
405    fn update_state<F, T>(&self, expect: u8, new: u8, f: F) -> Result<AxResult<T>, u8>
406    where
407        F: FnOnce() -> AxResult<T>,
408    {
409        match self
410            .state
411            .compare_exchange(expect, STATE_BUSY, Ordering::Acquire, Ordering::Acquire)
412        {
413            Ok(_) => {
414                let res = f();
415                if res.is_ok() {
416                    self.set_state(new);
417                } else {
418                    self.set_state(expect);
419                }
420                Ok(res)
421            }
422            Err(old) => Err(old),
423        }
424    }
425
426    #[inline]
427    fn is_connecting(&self) -> bool {
428        self.get_state() == STATE_CONNECTING
429    }
430
431    #[inline]
432    fn is_connected(&self) -> bool {
433        self.get_state() == STATE_CONNECTED
434    }
435
436    #[inline]
437    fn is_listening(&self) -> bool {
438        self.get_state() == STATE_LISTENING
439    }
440
441    fn bound_endpoint(&self) -> AxResult<IpListenEndpoint> {
442        // SAFETY: no other threads can read or write `self.local_addr`.
443        let local_addr = unsafe { self.local_addr.get().read() };
444        let port = if local_addr.port != 0 {
445            local_addr.port
446        } else {
447            get_ephemeral_port()?
448        };
449        assert_ne!(port, 0);
450        let addr = if !local_addr.addr.is_unspecified() {
451            Some(local_addr.addr)
452        } else {
453            None
454        };
455        Ok(IpListenEndpoint { addr, port })
456    }
457
458    fn poll_connect(&self) -> AxResult<PollState> {
459        // SAFETY: `self.handle` should be initialized above.
460        let handle = unsafe { self.handle.get().read().unwrap() };
461        let writable =
462            SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| match socket.state() {
463                State::SynSent => false, // wait for connection
464                State::Established => {
465                    self.set_state(STATE_CONNECTED); // connected
466                    debug!(
467                        "TCP socket {}: connected to {}",
468                        handle,
469                        socket.remote_endpoint().unwrap(),
470                    );
471                    true
472                }
473                _ => {
474                    unsafe {
475                        self.local_addr.get().write(UNSPECIFIED_ENDPOINT);
476                        self.peer_addr.get().write(UNSPECIFIED_ENDPOINT);
477                    }
478                    self.set_state(STATE_CLOSED); // connection failed
479                    true
480                }
481            });
482        Ok(PollState {
483            readable: false,
484            writable,
485        })
486    }
487
488    fn poll_stream(&self) -> AxResult<PollState> {
489        // SAFETY: `self.handle` should be initialized in a connected socket.
490        let handle = unsafe { self.handle.get().read().unwrap() };
491        SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
492            Ok(PollState {
493                readable: !socket.may_recv() || socket.can_recv(),
494                writable: !socket.may_send() || socket.can_send(),
495            })
496        })
497    }
498
499    fn poll_listener(&self) -> AxResult<PollState> {
500        // SAFETY: `self.local_addr` should be initialized in a listening socket.
501        let local_addr = unsafe { self.local_addr.get().read() };
502        Ok(PollState {
503            readable: LISTEN_TABLE.can_accept(local_addr.port)?,
504            writable: false,
505        })
506    }
507
508    /// Block the current thread until the given function completes or fails.
509    ///
510    /// If the socket is non-blocking, it calls the function once and returns
511    /// immediately. Otherwise, it may call the function multiple times if it
512    /// returns [`Err(WouldBlock)`](AxError::WouldBlock).
513    fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
514    where
515        F: FnMut() -> AxResult<T>,
516    {
517        if self.is_nonblocking() {
518            f()
519        } else {
520            loop {
521                SOCKET_SET.poll_interfaces();
522                match f() {
523                    Ok(t) => return Ok(t),
524                    Err(AxError::WouldBlock) => axtask::yield_now(),
525                    Err(e) => return Err(e),
526                }
527            }
528        }
529    }
530}
531
532impl Drop for TcpSocket {
533    fn drop(&mut self) {
534        self.shutdown().ok();
535        // Safe because we have mut reference to `self`.
536        if let Some(handle) = unsafe { self.handle.get().read() } {
537            SOCKET_SET.remove(handle);
538        }
539    }
540}
541
542fn get_ephemeral_port() -> AxResult<u16> {
543    const PORT_START: u16 = 0xc000;
544    const PORT_END: u16 = 0xffff;
545    static CURR: Mutex<u16> = Mutex::new(PORT_START);
546
547    let mut curr = CURR.lock();
548    let mut tries = 0;
549    // TODO: more robust
550    while tries <= PORT_END - PORT_START {
551        let port = *curr;
552        if *curr == PORT_END {
553            *curr = PORT_START;
554        } else {
555            *curr += 1;
556        }
557        if LISTEN_TABLE.can_listen(port) {
558            return Ok(port);
559        }
560        tries += 1;
561    }
562    ax_err!(AddrInUse, "no avaliable ports!")
563}