axnet/smoltcp_impl/
tcp.rs

1use core::cell::UnsafeCell;
2use core::net::SocketAddr;
3use core::sync::atomic::{AtomicBool, AtomicU8, Ordering};
4
5use axerrno::{AxError, AxResult, ax_err, ax_err_type};
6use axio::PollState;
7use axsync::Mutex;
8
9use smoltcp::iface::SocketHandle;
10use smoltcp::socket::tcp::{self, ConnectError, State};
11use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
12
13use super::addr::{UNSPECIFIED_ENDPOINT, from_core_sockaddr, into_core_sockaddr, is_unspecified};
14use super::{ETH0, LISTEN_TABLE, SOCKET_SET, SocketSetWrapper};
15
// Lifecycle of a `TcpSocket`, tracked in its `state` field:
//
//   CLOSED --connect--> BUSY --> CONNECTING --> CONNECTED --shutdown--> BUSY --> CLOSED
//   CLOSED --listen---> BUSY --> LISTENING ----------------shutdown--> BUSY --> CLOSED
//   CLOSED --bind-----> BUSY --> CLOSED
//
// `BUSY` is held by exactly one thread while a transition is in progress.

/// Not connected and not listening (the socket may still be bound).
const STATE_CLOSED: u8 = 0;
/// A state transition is in progress on some thread.
const STATE_BUSY: u8 = 1;
/// An active open was started and the handshake has not finished yet.
const STATE_CONNECTING: u8 = 2;
/// The connection is established.
const STATE_CONNECTED: u8 = 3;
/// The socket accepts incoming connections.
const STATE_LISTENING: u8 = 4;
27
28/// A TCP socket that provides POSIX-like APIs.
29///
30/// - [`connect`] is for TCP clients.
31/// - [`bind`], [`listen`], and [`accept`] are for TCP servers.
32/// - Other methods are for both TCP clients and servers.
33///
34/// [`connect`]: TcpSocket::connect
35/// [`bind`]: TcpSocket::bind
36/// [`listen`]: TcpSocket::listen
37/// [`accept`]: TcpSocket::accept
38pub struct TcpSocket {
39    state: AtomicU8,
40    handle: UnsafeCell<Option<SocketHandle>>,
41    local_addr: UnsafeCell<IpEndpoint>,
42    peer_addr: UnsafeCell<IpEndpoint>,
43    nonblock: AtomicBool,
44}
45
46unsafe impl Sync for TcpSocket {}
47
48impl TcpSocket {
49    /// Creates a new TCP socket.
50    pub const fn new() -> Self {
51        Self {
52            state: AtomicU8::new(STATE_CLOSED),
53            handle: UnsafeCell::new(None),
54            local_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
55            peer_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
56            nonblock: AtomicBool::new(false),
57        }
58    }
59
60    /// Creates a new TCP socket that is already connected.
61    const fn new_connected(
62        handle: SocketHandle,
63        local_addr: IpEndpoint,
64        peer_addr: IpEndpoint,
65    ) -> Self {
66        Self {
67            state: AtomicU8::new(STATE_CONNECTED),
68            handle: UnsafeCell::new(Some(handle)),
69            local_addr: UnsafeCell::new(local_addr),
70            peer_addr: UnsafeCell::new(peer_addr),
71            nonblock: AtomicBool::new(false),
72        }
73    }
74
75    /// Returns the local address and port, or
76    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
77    #[inline]
78    pub fn local_addr(&self) -> AxResult<SocketAddr> {
79        match self.get_state() {
80            STATE_CONNECTED | STATE_LISTENING => {
81                Ok(into_core_sockaddr(unsafe { self.local_addr.get().read() }))
82            }
83            _ => Err(AxError::NotConnected),
84        }
85    }
86
87    /// Returns the remote address and port, or
88    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
89    #[inline]
90    pub fn peer_addr(&self) -> AxResult<SocketAddr> {
91        match self.get_state() {
92            STATE_CONNECTED | STATE_LISTENING => {
93                Ok(into_core_sockaddr(unsafe { self.peer_addr.get().read() }))
94            }
95            _ => Err(AxError::NotConnected),
96        }
97    }
98
99    /// Returns whether this socket is in nonblocking mode.
100    #[inline]
101    pub fn is_nonblocking(&self) -> bool {
102        self.nonblock.load(Ordering::Acquire)
103    }
104
105    /// Moves this TCP stream into or out of nonblocking mode.
106    ///
107    /// This will result in `read`, `write`, `recv` and `send` operations
108    /// becoming nonblocking, i.e., immediately returning from their calls.
109    /// If the IO operation is successful, `Ok` is returned and no further
110    /// action is required. If the IO operation could not be completed and needs
111    /// to be retried, an error with kind  [`Err(WouldBlock)`](AxError::WouldBlock) is
112    /// returned.
113    #[inline]
114    pub fn set_nonblocking(&self, nonblocking: bool) {
115        self.nonblock.store(nonblocking, Ordering::Release);
116    }
117
118    /// Connects to the given address and port.
119    ///
120    /// The local port is generated automatically.
121    pub fn connect(&self, remote_addr: SocketAddr) -> AxResult {
122        self.update_state(STATE_CLOSED, STATE_CONNECTING, || {
123            // SAFETY: no other threads can read or write these fields.
124            let handle = unsafe { self.handle.get().read() }
125                .unwrap_or_else(|| SOCKET_SET.add(SocketSetWrapper::new_tcp_socket()));
126
127            // TODO: check remote addr unreachable
128            let remote_endpoint = from_core_sockaddr(remote_addr);
129            let bound_endpoint = self.bound_endpoint()?;
130            let iface = &ETH0.iface;
131            let (local_endpoint, remote_endpoint) = SOCKET_SET
132                .with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
133                    socket
134                        .connect(iface.lock().context(), remote_endpoint, bound_endpoint)
135                        .or_else(|e| match e {
136                            ConnectError::InvalidState => {
137                                ax_err!(BadState, "socket connect() failed")
138                            }
139                            ConnectError::Unaddressable => {
140                                ax_err!(ConnectionRefused, "socket connect() failed")
141                            }
142                        })?;
143                    Ok((
144                        socket.local_endpoint().unwrap(),
145                        socket.remote_endpoint().unwrap(),
146                    ))
147                })?;
148            unsafe {
149                // SAFETY: no other threads can read or write these fields as we
150                // have changed the state to `BUSY`.
151                self.local_addr.get().write(local_endpoint);
152                self.peer_addr.get().write(remote_endpoint);
153                self.handle.get().write(Some(handle));
154            }
155            Ok(())
156        })
157        .unwrap_or_else(|_| ax_err!(AlreadyExists, "socket connect() failed: already connected"))?; // EISCONN
158
159        // Here our state must be `CONNECTING`, and only one thread can run here.
160        if self.is_nonblocking() {
161            Err(AxError::WouldBlock)
162        } else {
163            self.block_on(|| {
164                let PollState { writable, .. } = self.poll_connect()?;
165                if !writable {
166                    Err(AxError::WouldBlock)
167                } else if self.get_state() == STATE_CONNECTED {
168                    Ok(())
169                } else {
170                    ax_err!(ConnectionRefused, "socket connect() failed")
171                }
172            })
173        }
174    }
175
176    /// Binds an unbound socket to the given address and port.
177    ///
178    /// If the given port is 0, it generates one automatically.
179    ///
180    /// It's must be called before [`listen`](Self::listen) and
181    /// [`accept`](Self::accept).
182    pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
183        self.update_state(STATE_CLOSED, STATE_CLOSED, || {
184            // TODO: check addr is available
185            if local_addr.port() == 0 {
186                local_addr.set_port(get_ephemeral_port()?);
187            }
188            // SAFETY: no other threads can read or write `self.local_addr` as we
189            // have changed the state to `BUSY`.
190            unsafe {
191                let old = self.local_addr.get().read();
192                if old != UNSPECIFIED_ENDPOINT {
193                    return ax_err!(InvalidInput, "socket bind() failed: already bound");
194                }
195                self.local_addr.get().write(from_core_sockaddr(local_addr));
196            }
197            Ok(())
198        })
199        .unwrap_or_else(|_| ax_err!(InvalidInput, "socket bind() failed: already bound"))
200    }
201
202    /// Starts listening on the bound address and port.
203    ///
204    /// It's must be called after [`bind`](Self::bind) and before
205    /// [`accept`](Self::accept).
206    pub fn listen(&self) -> AxResult {
207        self.update_state(STATE_CLOSED, STATE_LISTENING, || {
208            let bound_endpoint = self.bound_endpoint()?;
209            unsafe {
210                (*self.local_addr.get()).port = bound_endpoint.port;
211            }
212            LISTEN_TABLE.listen(bound_endpoint)?;
213            debug!("TCP socket listening on {}", bound_endpoint);
214            Ok(())
215        })
216        .unwrap_or(Ok(())) // ignore simultaneous `listen`s.
217    }
218
219    /// Accepts a new connection.
220    ///
221    /// This function will block the calling thread until a new TCP connection
222    /// is established. When established, a new [`TcpSocket`] is returned.
223    ///
224    /// It's must be called after [`bind`](Self::bind) and [`listen`](Self::listen).
225    pub fn accept(&self) -> AxResult<TcpSocket> {
226        if !self.is_listening() {
227            return ax_err!(InvalidInput, "socket accept() failed: not listen");
228        }
229
230        // SAFETY: `self.local_addr` should be initialized after `bind()`.
231        let local_port = unsafe { self.local_addr.get().read().port };
232        self.block_on(|| {
233            let (handle, (local_addr, peer_addr)) = LISTEN_TABLE.accept(local_port)?;
234            debug!("TCP socket accepted a new connection {}", peer_addr);
235            Ok(TcpSocket::new_connected(handle, local_addr, peer_addr))
236        })
237    }
238
239    /// Close the connection.
240    pub fn shutdown(&self) -> AxResult {
241        // stream
242        self.update_state(STATE_CONNECTED, STATE_CLOSED, || {
243            // SAFETY: `self.handle` should be initialized in a connected socket, and
244            // no other threads can read or write it.
245            let handle = unsafe { self.handle.get().read().unwrap() };
246            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
247                debug!("TCP socket {}: shutting down", handle);
248                socket.close();
249            });
250            unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; // clear bound address
251            SOCKET_SET.poll_interfaces();
252            Ok(())
253        })
254        .unwrap_or(Ok(()))?;
255
256        // listener
257        self.update_state(STATE_LISTENING, STATE_CLOSED, || {
258            // SAFETY: `self.local_addr` should be initialized in a listening socket,
259            // and no other threads can read or write it.
260            let local_port = unsafe { self.local_addr.get().read().port };
261            unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; // clear bound address
262            LISTEN_TABLE.unlisten(local_port);
263            SOCKET_SET.poll_interfaces();
264            Ok(())
265        })
266        .unwrap_or(Ok(()))?;
267
268        // ignore for other states
269        Ok(())
270    }
271
272    /// Receives data from the socket, stores it in the given buffer.
273    pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
274        if self.is_connecting() {
275            return Err(AxError::WouldBlock);
276        } else if !self.is_connected() {
277            return ax_err!(NotConnected, "socket recv() failed");
278        }
279
280        // SAFETY: `self.handle` should be initialized in a connected socket.
281        let handle = unsafe { self.handle.get().read().unwrap() };
282        self.block_on(|| {
283            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
284                if !socket.is_active() {
285                    // not open
286                    ax_err!(ConnectionRefused, "socket recv() failed")
287                } else if !socket.may_recv() {
288                    // connection closed
289                    Ok(0)
290                } else if socket.recv_queue() > 0 {
291                    // data available
292                    // TODO: use socket.recv(|buf| {...})
293                    let len = socket
294                        .recv_slice(buf)
295                        .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
296                    Ok(len)
297                } else {
298                    // no more data
299                    Err(AxError::WouldBlock)
300                }
301            })
302        })
303    }
304
305    /// Transmits data in the given buffer.
306    pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
307        if self.is_connecting() {
308            return Err(AxError::WouldBlock);
309        } else if !self.is_connected() {
310            return ax_err!(NotConnected, "socket send() failed");
311        }
312
313        // SAFETY: `self.handle` should be initialized in a connected socket.
314        let handle = unsafe { self.handle.get().read().unwrap() };
315        self.block_on(|| {
316            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
317                if !socket.is_active() || !socket.may_send() {
318                    // closed by remote
319                    ax_err!(ConnectionReset, "socket send() failed")
320                } else if socket.can_send() {
321                    // connected, and the tx buffer is not full
322                    // TODO: use socket.send(|buf| {...})
323                    let len = socket
324                        .send_slice(buf)
325                        .map_err(|_| ax_err_type!(BadState, "socket send() failed"))?;
326                    Ok(len)
327                } else {
328                    // tx buffer is full
329                    Err(AxError::WouldBlock)
330                }
331            })
332        })
333    }
334
335    /// Whether the socket is readable or writable.
336    pub fn poll(&self) -> AxResult<PollState> {
337        match self.get_state() {
338            STATE_CONNECTING => self.poll_connect(),
339            STATE_CONNECTED => self.poll_stream(),
340            STATE_LISTENING => self.poll_listener(),
341            _ => Ok(PollState {
342                readable: false,
343                writable: false,
344            }),
345        }
346    }
347}
348
349/// Private methods
350impl TcpSocket {
351    #[inline]
352    fn get_state(&self) -> u8 {
353        self.state.load(Ordering::Acquire)
354    }
355
356    #[inline]
357    fn set_state(&self, state: u8) {
358        self.state.store(state, Ordering::Release);
359    }
360
361    /// Update the state of the socket atomically.
362    ///
363    /// If the current state is `expect`, it first changes the state to `STATE_BUSY`,
364    /// then calls the given function. If the function returns `Ok`, it changes the
365    /// state to `new`, otherwise it changes the state back to `expect`.
366    ///
367    /// It returns `Ok` if the current state is `expect`, otherwise it returns
368    /// the current state in `Err`.
369    fn update_state<F, T>(&self, expect: u8, new: u8, f: F) -> Result<AxResult<T>, u8>
370    where
371        F: FnOnce() -> AxResult<T>,
372    {
373        match self
374            .state
375            .compare_exchange(expect, STATE_BUSY, Ordering::Acquire, Ordering::Acquire)
376        {
377            Ok(_) => {
378                let res = f();
379                if res.is_ok() {
380                    self.set_state(new);
381                } else {
382                    self.set_state(expect);
383                }
384                Ok(res)
385            }
386            Err(old) => Err(old),
387        }
388    }
389
390    #[inline]
391    fn is_connecting(&self) -> bool {
392        self.get_state() == STATE_CONNECTING
393    }
394
395    #[inline]
396    fn is_connected(&self) -> bool {
397        self.get_state() == STATE_CONNECTED
398    }
399
400    #[inline]
401    fn is_listening(&self) -> bool {
402        self.get_state() == STATE_LISTENING
403    }
404
405    fn bound_endpoint(&self) -> AxResult<IpListenEndpoint> {
406        // SAFETY: no other threads can read or write `self.local_addr`.
407        let local_addr = unsafe { self.local_addr.get().read() };
408        let port = if local_addr.port != 0 {
409            local_addr.port
410        } else {
411            get_ephemeral_port()?
412        };
413        assert_ne!(port, 0);
414        let addr = if !is_unspecified(local_addr.addr) {
415            Some(local_addr.addr)
416        } else {
417            None
418        };
419        Ok(IpListenEndpoint { addr, port })
420    }
421
422    fn poll_connect(&self) -> AxResult<PollState> {
423        // SAFETY: `self.handle` should be initialized above.
424        let handle = unsafe { self.handle.get().read().unwrap() };
425        let writable =
426            SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| match socket.state() {
427                State::SynSent => false, // wait for connection
428                State::Established => {
429                    self.set_state(STATE_CONNECTED); // connected
430                    debug!(
431                        "TCP socket {}: connected to {}",
432                        handle,
433                        socket.remote_endpoint().unwrap(),
434                    );
435                    true
436                }
437                _ => {
438                    unsafe {
439                        self.local_addr.get().write(UNSPECIFIED_ENDPOINT);
440                        self.peer_addr.get().write(UNSPECIFIED_ENDPOINT);
441                    }
442                    self.set_state(STATE_CLOSED); // connection failed
443                    true
444                }
445            });
446        Ok(PollState {
447            readable: false,
448            writable,
449        })
450    }
451
452    fn poll_stream(&self) -> AxResult<PollState> {
453        // SAFETY: `self.handle` should be initialized in a connected socket.
454        let handle = unsafe { self.handle.get().read().unwrap() };
455        SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
456            Ok(PollState {
457                readable: !socket.may_recv() || socket.can_recv(),
458                writable: !socket.may_send() || socket.can_send(),
459            })
460        })
461    }
462
463    fn poll_listener(&self) -> AxResult<PollState> {
464        // SAFETY: `self.local_addr` should be initialized in a listening socket.
465        let local_addr = unsafe { self.local_addr.get().read() };
466        Ok(PollState {
467            readable: LISTEN_TABLE.can_accept(local_addr.port)?,
468            writable: false,
469        })
470    }
471
472    /// Block the current thread until the given function completes or fails.
473    ///
474    /// If the socket is non-blocking, it calls the function once and returns
475    /// immediately. Otherwise, it may call the function multiple times if it
476    /// returns [`Err(WouldBlock)`](AxError::WouldBlock).
477    fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
478    where
479        F: FnMut() -> AxResult<T>,
480    {
481        if self.is_nonblocking() {
482            f()
483        } else {
484            loop {
485                SOCKET_SET.poll_interfaces();
486                match f() {
487                    Ok(t) => return Ok(t),
488                    Err(AxError::WouldBlock) => axtask::yield_now(),
489                    Err(e) => return Err(e),
490                }
491            }
492        }
493    }
494}
495
496impl Drop for TcpSocket {
497    fn drop(&mut self) {
498        self.shutdown().ok();
499        // Safe because we have mut reference to `self`.
500        if let Some(handle) = unsafe { self.handle.get().read() } {
501            SOCKET_SET.remove(handle);
502        }
503    }
504}
505
506fn get_ephemeral_port() -> AxResult<u16> {
507    const PORT_START: u16 = 0xc000;
508    const PORT_END: u16 = 0xffff;
509    static CURR: Mutex<u16> = Mutex::new(PORT_START);
510
511    let mut curr = CURR.lock();
512    let mut tries = 0;
513    // TODO: more robust
514    while tries <= PORT_END - PORT_START {
515        let port = *curr;
516        if *curr == PORT_END {
517            *curr = PORT_START;
518        } else {
519            *curr += 1;
520        }
521        if LISTEN_TABLE.can_listen(port) {
522            return Ok(port);
523        }
524        tries += 1;
525    }
526    ax_err!(AddrInUse, "no avaliable ports!")
527}