1use core::cell::UnsafeCell;
2use core::net::SocketAddr;
3use core::sync::atomic::{AtomicBool, AtomicU8, Ordering};
4
5use axerrno::{AxError, AxResult, ax_err, ax_err_type};
6use axio::PollState;
7use axsync::Mutex;
8
9use smoltcp::iface::SocketHandle;
10use smoltcp::socket::tcp::{self, ConnectError, State};
11use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
12
13use super::addr::UNSPECIFIED_ENDPOINT;
14use super::{ETH0, LISTEN_TABLE, SOCKET_SET, SocketSetWrapper};
15
// Lifecycle states stored in `TcpSocket::state`.

/// Not connected and not listening; the initial state (also reached after `shutdown`).
const STATE_CLOSED: u8 = 0;
/// Transient state held while `update_state` runs a transition closure. While a socket
/// is BUSY, no other `update_state` call can start, which serializes state transitions.
const STATE_BUSY: u8 = 1;
/// `connect` has started the handshake; `poll_connect` moves the socket on from here.
const STATE_CONNECTING: u8 = 2;
/// The connection is established (via `connect` or `accept`).
const STATE_CONNECTED: u8 = 3;
/// The socket is registered in `LISTEN_TABLE` and can `accept` connections.
const STATE_LISTENING: u8 = 4;

/// A TCP socket backed by a smoltcp socket stored in `SOCKET_SET`.
///
/// Concurrency is coordinated through the atomic `state` field (see the `STATE_*`
/// constants) rather than a lock; the address and handle fields live in `UnsafeCell`s
/// and are mutated during state transitions.
pub struct TcpSocket {
    /// Current lifecycle state, one of the `STATE_*` constants.
    state: AtomicU8,
    /// Handle of the underlying smoltcp socket, or `None` if none was allocated yet.
    handle: UnsafeCell<Option<SocketHandle>>,
    /// Local endpoint; `UNSPECIFIED_ENDPOINT` means "not bound".
    local_addr: UnsafeCell<IpEndpoint>,
    /// Remote endpoint; `UNSPECIFIED_ENDPOINT` when there is no peer.
    peer_addr: UnsafeCell<IpEndpoint>,
    /// When set, blocking operations return `WouldBlock` instead of waiting.
    nonblock: AtomicBool,
}

// SAFETY (NOTE(review)): the `UnsafeCell` fields are written mostly inside `update_state`
// closures, which run only while `state == STATE_BUSY`, so writers are serialized.
// `poll_connect` also writes the address fields while CONNECTING, and several getters
// read `handle` without checking the state — soundness relies on callers not racing
// those paths. TODO confirm this is acceptable for all users of `TcpSocket`.
unsafe impl Sync for TcpSocket {}
47
48impl TcpSocket {
49 pub const fn new() -> Self {
51 Self {
52 state: AtomicU8::new(STATE_CLOSED),
53 handle: UnsafeCell::new(None),
54 local_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
55 peer_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
56 nonblock: AtomicBool::new(false),
57 }
58 }
59
60 const fn new_connected(
62 handle: SocketHandle,
63 local_addr: IpEndpoint,
64 peer_addr: IpEndpoint,
65 ) -> Self {
66 Self {
67 state: AtomicU8::new(STATE_CONNECTED),
68 handle: UnsafeCell::new(Some(handle)),
69 local_addr: UnsafeCell::new(local_addr),
70 peer_addr: UnsafeCell::new(peer_addr),
71 nonblock: AtomicBool::new(false),
72 }
73 }
74
75 pub fn local_addr(&self) -> AxResult<SocketAddr> {
78 match self.get_state() {
79 STATE_CONNECTED | STATE_LISTENING => {
80 Ok(SocketAddr::from(unsafe { self.local_addr.get().read() }))
81 }
82 _ => Err(AxError::NotConnected),
83 }
84 }
85
86 pub fn peer_addr(&self) -> AxResult<SocketAddr> {
89 match self.get_state() {
90 STATE_CONNECTED | STATE_LISTENING => {
91 Ok(SocketAddr::from(unsafe { self.peer_addr.get().read() }))
92 }
93 _ => Err(AxError::NotConnected),
94 }
95 }
96
97 #[inline]
99 pub fn is_nonblocking(&self) -> bool {
100 self.nonblock.load(Ordering::Acquire)
101 }
102
103 #[inline]
112 pub fn set_nonblocking(&self, nonblocking: bool) {
113 self.nonblock.store(nonblocking, Ordering::Release);
114 }
115
116 pub fn connect(&self, remote_addr: SocketAddr) -> AxResult {
120 self.update_state(STATE_CLOSED, STATE_CONNECTING, || {
121 let handle = unsafe { self.handle.get().read() }
123 .unwrap_or_else(|| SOCKET_SET.add(SocketSetWrapper::new_tcp_socket()));
124
125 let bound_endpoint = self.bound_endpoint()?;
127 let iface = Ð0.iface;
128 let (local_endpoint, remote_endpoint) = SOCKET_SET
129 .with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
130 socket
131 .connect(iface.lock().context(), remote_addr, bound_endpoint)
132 .or_else(|e| match e {
133 ConnectError::InvalidState => {
134 ax_err!(BadState, "socket connect() failed")
135 }
136 ConnectError::Unaddressable => {
137 ax_err!(ConnectionRefused, "socket connect() failed")
138 }
139 })?;
140 Ok((
141 socket.local_endpoint().unwrap(),
142 socket.remote_endpoint().unwrap(),
143 ))
144 })?;
145 unsafe {
146 self.local_addr.get().write(local_endpoint);
149 self.peer_addr.get().write(remote_endpoint);
150 self.handle.get().write(Some(handle));
151 }
152 Ok(())
153 })
154 .unwrap_or_else(|_| ax_err!(AlreadyExists, "socket connect() failed: already connected"))?; if self.is_nonblocking() {
158 Err(AxError::WouldBlock)
159 } else {
160 self.block_on(|| {
161 let PollState { writable, .. } = self.poll_connect()?;
162 if !writable {
163 Err(AxError::WouldBlock)
164 } else if self.get_state() == STATE_CONNECTED {
165 Ok(())
166 } else {
167 ax_err!(ConnectionRefused, "socket connect() failed")
168 }
169 })
170 }
171 }
172
173 pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
180 self.update_state(STATE_CLOSED, STATE_CLOSED, || {
181 if local_addr.port() == 0 {
183 local_addr.set_port(get_ephemeral_port()?);
184 }
185 unsafe {
188 let old = self.local_addr.get().read();
189 if old != UNSPECIFIED_ENDPOINT {
190 return ax_err!(InvalidInput, "socket bind() failed: already bound");
191 }
192 self.local_addr.get().write(IpEndpoint::from(local_addr));
193 }
194 Ok(())
195 })
196 .unwrap_or_else(|_| ax_err!(InvalidInput, "socket bind() failed: already bound"))
197 }
198
199 pub fn listen(&self) -> AxResult {
204 self.update_state(STATE_CLOSED, STATE_LISTENING, || {
205 let bound_endpoint = self.bound_endpoint()?;
206 unsafe {
207 (*self.local_addr.get()).port = bound_endpoint.port;
208 }
209 LISTEN_TABLE.listen(bound_endpoint)?;
210 debug!("TCP socket listening on {}", bound_endpoint);
211 Ok(())
212 })
213 .unwrap_or(Ok(())) }
215
216 pub fn accept(&self) -> AxResult<TcpSocket> {
223 if !self.is_listening() {
224 return ax_err!(InvalidInput, "socket accept() failed: not listen");
225 }
226
227 let local_port = unsafe { self.local_addr.get().read().port };
229 self.block_on(|| {
230 let (handle, (local_addr, peer_addr)) = LISTEN_TABLE.accept(local_port)?;
231 debug!("TCP socket accepted a new connection {}", peer_addr);
232 Ok(TcpSocket::new_connected(handle, local_addr, peer_addr))
233 })
234 }
235
236 pub fn shutdown(&self) -> AxResult {
238 self.update_state(STATE_CONNECTED, STATE_CLOSED, || {
240 let handle = unsafe { self.handle.get().read().unwrap() };
243 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
244 debug!("TCP socket {}: shutting down", handle);
245 socket.close();
246 });
247 unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; SOCKET_SET.poll_interfaces();
249 Ok(())
250 })
251 .unwrap_or(Ok(()))?;
252
253 self.update_state(STATE_LISTENING, STATE_CLOSED, || {
255 let local_port = unsafe { self.local_addr.get().read().port };
258 unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; LISTEN_TABLE.unlisten(local_port);
260 SOCKET_SET.poll_interfaces();
261 Ok(())
262 })
263 .unwrap_or(Ok(()))?;
264
265 Ok(())
267 }
268
269 pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
271 if self.is_connecting() {
272 return Err(AxError::WouldBlock);
273 } else if !self.is_connected() {
274 return ax_err!(NotConnected, "socket recv() failed");
275 }
276
277 let handle = unsafe { self.handle.get().read().unwrap() };
279 self.block_on(|| {
280 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
281 if !socket.is_active() {
282 ax_err!(ConnectionRefused, "socket recv() failed")
284 } else if !socket.may_recv() {
285 Ok(0)
287 } else if socket.recv_queue() > 0 {
288 let len = socket
291 .recv_slice(buf)
292 .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
293 Ok(len)
294 } else {
295 Err(AxError::WouldBlock)
297 }
298 })
299 })
300 }
301
302 pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
304 if self.is_connecting() {
305 return Err(AxError::WouldBlock);
306 } else if !self.is_connected() {
307 return ax_err!(NotConnected, "socket send() failed");
308 }
309
310 let handle = unsafe { self.handle.get().read().unwrap() };
312 self.block_on(|| {
313 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
314 if !socket.is_active() || !socket.may_send() {
315 ax_err!(ConnectionReset, "socket send() failed")
317 } else if socket.can_send() {
318 let len = socket
321 .send_slice(buf)
322 .map_err(|_| ax_err_type!(BadState, "socket send() failed"))?;
323 Ok(len)
324 } else {
325 Err(AxError::WouldBlock)
327 }
328 })
329 })
330 }
331
332 pub fn poll(&self) -> AxResult<PollState> {
334 match self.get_state() {
335 STATE_CONNECTING => self.poll_connect(),
336 STATE_CONNECTED => self.poll_stream(),
337 STATE_LISTENING => self.poll_listener(),
338 _ => Ok(PollState {
339 readable: false,
340 writable: false,
341 }),
342 }
343 }
344
345 pub fn nodelay(&self) -> AxResult<bool> {
347 if let Some(h) = unsafe { self.handle.get().read() } {
348 Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.nagle_enabled()))
349 } else {
350 ax_err!(NotConnected, "socket is not connected")
351 }
352 }
353
354 pub fn set_nodelay(&self, enabled: bool) -> AxResult<()> {
356 if let Some(h) = unsafe { self.handle.get().read() } {
357 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(h, |socket| {
358 socket.set_nagle_enabled(enabled);
359 });
360 Ok(())
361 } else {
362 ax_err!(NotConnected, "socket is not connected")
363 }
364 }
365
366 pub fn recv_capacity(&self) -> AxResult<usize> {
368 if let Some(h) = unsafe { self.handle.get().read() } {
369 Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.recv_capacity()))
370 } else {
371 ax_err!(NotConnected, "socket is not connected")
372 }
373 }
374
375 pub fn send_capacity(&self) -> AxResult<usize> {
377 if let Some(h) = unsafe { self.handle.get().read() } {
378 Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.send_capacity()))
379 } else {
380 ax_err!(NotConnected, "socket is not connected")
381 }
382 }
383}
384
385impl TcpSocket {
387 #[inline]
388 fn get_state(&self) -> u8 {
389 self.state.load(Ordering::Acquire)
390 }
391
392 #[inline]
393 fn set_state(&self, state: u8) {
394 self.state.store(state, Ordering::Release);
395 }
396
397 fn update_state<F, T>(&self, expect: u8, new: u8, f: F) -> Result<AxResult<T>, u8>
406 where
407 F: FnOnce() -> AxResult<T>,
408 {
409 match self
410 .state
411 .compare_exchange(expect, STATE_BUSY, Ordering::Acquire, Ordering::Acquire)
412 {
413 Ok(_) => {
414 let res = f();
415 if res.is_ok() {
416 self.set_state(new);
417 } else {
418 self.set_state(expect);
419 }
420 Ok(res)
421 }
422 Err(old) => Err(old),
423 }
424 }
425
426 #[inline]
427 fn is_connecting(&self) -> bool {
428 self.get_state() == STATE_CONNECTING
429 }
430
431 #[inline]
432 fn is_connected(&self) -> bool {
433 self.get_state() == STATE_CONNECTED
434 }
435
436 #[inline]
437 fn is_listening(&self) -> bool {
438 self.get_state() == STATE_LISTENING
439 }
440
441 fn bound_endpoint(&self) -> AxResult<IpListenEndpoint> {
442 let local_addr = unsafe { self.local_addr.get().read() };
444 let port = if local_addr.port != 0 {
445 local_addr.port
446 } else {
447 get_ephemeral_port()?
448 };
449 assert_ne!(port, 0);
450 let addr = if !local_addr.addr.is_unspecified() {
451 Some(local_addr.addr)
452 } else {
453 None
454 };
455 Ok(IpListenEndpoint { addr, port })
456 }
457
458 fn poll_connect(&self) -> AxResult<PollState> {
459 let handle = unsafe { self.handle.get().read().unwrap() };
461 let writable =
462 SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| match socket.state() {
463 State::SynSent => false, State::Established => {
465 self.set_state(STATE_CONNECTED); debug!(
467 "TCP socket {}: connected to {}",
468 handle,
469 socket.remote_endpoint().unwrap(),
470 );
471 true
472 }
473 _ => {
474 unsafe {
475 self.local_addr.get().write(UNSPECIFIED_ENDPOINT);
476 self.peer_addr.get().write(UNSPECIFIED_ENDPOINT);
477 }
478 self.set_state(STATE_CLOSED); true
480 }
481 });
482 Ok(PollState {
483 readable: false,
484 writable,
485 })
486 }
487
488 fn poll_stream(&self) -> AxResult<PollState> {
489 let handle = unsafe { self.handle.get().read().unwrap() };
491 SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
492 Ok(PollState {
493 readable: !socket.may_recv() || socket.can_recv(),
494 writable: !socket.may_send() || socket.can_send(),
495 })
496 })
497 }
498
499 fn poll_listener(&self) -> AxResult<PollState> {
500 let local_addr = unsafe { self.local_addr.get().read() };
502 Ok(PollState {
503 readable: LISTEN_TABLE.can_accept(local_addr.port)?,
504 writable: false,
505 })
506 }
507
508 fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
514 where
515 F: FnMut() -> AxResult<T>,
516 {
517 if self.is_nonblocking() {
518 f()
519 } else {
520 loop {
521 SOCKET_SET.poll_interfaces();
522 match f() {
523 Ok(t) => return Ok(t),
524 Err(AxError::WouldBlock) => axtask::yield_now(),
525 Err(e) => return Err(e),
526 }
527 }
528 }
529 }
530}
531
532impl Drop for TcpSocket {
533 fn drop(&mut self) {
534 self.shutdown().ok();
535 if let Some(handle) = unsafe { self.handle.get().read() } {
537 SOCKET_SET.remove(handle);
538 }
539 }
540}
541
542fn get_ephemeral_port() -> AxResult<u16> {
543 const PORT_START: u16 = 0xc000;
544 const PORT_END: u16 = 0xffff;
545 static CURR: Mutex<u16> = Mutex::new(PORT_START);
546
547 let mut curr = CURR.lock();
548 let mut tries = 0;
549 while tries <= PORT_END - PORT_START {
551 let port = *curr;
552 if *curr == PORT_END {
553 *curr = PORT_START;
554 } else {
555 *curr += 1;
556 }
557 if LISTEN_TABLE.can_listen(port) {
558 return Ok(port);
559 }
560 tries += 1;
561 }
562 ax_err!(AddrInUse, "no avaliable ports!")
563}