1use core::cell::UnsafeCell;
2use core::net::SocketAddr;
3use core::sync::atomic::{AtomicBool, AtomicU8, Ordering};
4
5use axerrno::{AxError, AxResult, ax_err, ax_err_type};
6use axio::PollState;
7use axsync::Mutex;
8
9use smoltcp::iface::SocketHandle;
10use smoltcp::socket::tcp::{self, ConnectError, State};
11use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
12
13use super::addr::{UNSPECIFIED_ENDPOINT, from_core_sockaddr, into_core_sockaddr, is_unspecified};
14use super::{ETH0, LISTEN_TABLE, SOCKET_SET, SocketSetWrapper};
15
// Socket lifecycle states stored in `TcpSocket::state`.
//
// Transitions performed by the code in this file:
//
// CLOSED -(connect)-> BUSY -> CONNECTING -(poll_connect)-> CONNECTED -(shutdown)-> BUSY -> CLOSED
// CLOSED -(listen)->  BUSY -> LISTENING  -(shutdown)-> BUSY -> CLOSED
// CLOSED -(bind)->    BUSY -> CLOSED
//
// `BUSY` is a transient marker installed by `update_state` while the
// transition closure runs; a failed closure restores the previous state.

// No connection and not listening; the initial state of `TcpSocket::new()`.
const STATE_CLOSED: u8 = 0;
// A state transition is in progress (see `update_state`).
const STATE_BUSY: u8 = 1;
// `connect()` has been issued and the handshake has not been observed to finish.
const STATE_CONNECTING: u8 = 2;
// The connection is established (set by `poll_connect` or `new_connected`).
const STATE_CONNECTED: u8 = 3;
// The socket has been registered in `LISTEN_TABLE` by `listen()`.
const STATE_LISTENING: u8 = 4;
27
/// A TCP socket built on top of the global smoltcp socket set.
///
/// - `connect` is used by clients.
/// - `bind`, `listen` and `accept` are used by servers.
/// - `send`, `recv`, `poll` and `shutdown` apply to connected sockets
///   (`poll` and `shutdown` also handle listening sockets).
pub struct TcpSocket {
    /// Current lifecycle state; always one of the `STATE_*` constants.
    state: AtomicU8,
    /// Handle of the underlying socket in `SOCKET_SET`, once one has been
    /// allocated by `connect` or supplied by `accept`.
    handle: UnsafeCell<Option<SocketHandle>>,
    /// Local endpoint; `UNSPECIFIED_ENDPOINT` while unbound.
    local_addr: UnsafeCell<IpEndpoint>,
    /// Remote endpoint; `UNSPECIFIED_ENDPOINT` until a connection is set up.
    peer_addr: UnsafeCell<IpEndpoint>,
    /// Whether blocking operations return `WouldBlock` instead of waiting.
    nonblock: AtomicBool,
}
45
// SAFETY: the `UnsafeCell` fields are mutated inside `update_state` closures
// (while `state` is `STATE_BUSY`) and by `poll_connect`.
// NOTE(review): readers such as `local_addr()` and `peer_addr()` access the
// cells without entering `STATE_BUSY`; confirm that callers never race them
// against a concurrent `shutdown()` or `poll_connect()`.
unsafe impl Sync for TcpSocket {}
47
48impl TcpSocket {
49 pub const fn new() -> Self {
51 Self {
52 state: AtomicU8::new(STATE_CLOSED),
53 handle: UnsafeCell::new(None),
54 local_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
55 peer_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
56 nonblock: AtomicBool::new(false),
57 }
58 }
59
60 const fn new_connected(
62 handle: SocketHandle,
63 local_addr: IpEndpoint,
64 peer_addr: IpEndpoint,
65 ) -> Self {
66 Self {
67 state: AtomicU8::new(STATE_CONNECTED),
68 handle: UnsafeCell::new(Some(handle)),
69 local_addr: UnsafeCell::new(local_addr),
70 peer_addr: UnsafeCell::new(peer_addr),
71 nonblock: AtomicBool::new(false),
72 }
73 }
74
75 #[inline]
78 pub fn local_addr(&self) -> AxResult<SocketAddr> {
79 match self.get_state() {
80 STATE_CONNECTED | STATE_LISTENING => {
81 Ok(into_core_sockaddr(unsafe { self.local_addr.get().read() }))
82 }
83 _ => Err(AxError::NotConnected),
84 }
85 }
86
87 #[inline]
90 pub fn peer_addr(&self) -> AxResult<SocketAddr> {
91 match self.get_state() {
92 STATE_CONNECTED | STATE_LISTENING => {
93 Ok(into_core_sockaddr(unsafe { self.peer_addr.get().read() }))
94 }
95 _ => Err(AxError::NotConnected),
96 }
97 }
98
99 #[inline]
101 pub fn is_nonblocking(&self) -> bool {
102 self.nonblock.load(Ordering::Acquire)
103 }
104
105 #[inline]
114 pub fn set_nonblocking(&self, nonblocking: bool) {
115 self.nonblock.store(nonblocking, Ordering::Release);
116 }
117
118 pub fn connect(&self, remote_addr: SocketAddr) -> AxResult {
122 self.update_state(STATE_CLOSED, STATE_CONNECTING, || {
123 let handle = unsafe { self.handle.get().read() }
125 .unwrap_or_else(|| SOCKET_SET.add(SocketSetWrapper::new_tcp_socket()));
126
127 let remote_endpoint = from_core_sockaddr(remote_addr);
129 let bound_endpoint = self.bound_endpoint()?;
130 let iface = Ð0.iface;
131 let (local_endpoint, remote_endpoint) = SOCKET_SET
132 .with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
133 socket
134 .connect(iface.lock().context(), remote_endpoint, bound_endpoint)
135 .or_else(|e| match e {
136 ConnectError::InvalidState => {
137 ax_err!(BadState, "socket connect() failed")
138 }
139 ConnectError::Unaddressable => {
140 ax_err!(ConnectionRefused, "socket connect() failed")
141 }
142 })?;
143 Ok((
144 socket.local_endpoint().unwrap(),
145 socket.remote_endpoint().unwrap(),
146 ))
147 })?;
148 unsafe {
149 self.local_addr.get().write(local_endpoint);
152 self.peer_addr.get().write(remote_endpoint);
153 self.handle.get().write(Some(handle));
154 }
155 Ok(())
156 })
157 .unwrap_or_else(|_| ax_err!(AlreadyExists, "socket connect() failed: already connected"))?; if self.is_nonblocking() {
161 Err(AxError::WouldBlock)
162 } else {
163 self.block_on(|| {
164 let PollState { writable, .. } = self.poll_connect()?;
165 if !writable {
166 Err(AxError::WouldBlock)
167 } else if self.get_state() == STATE_CONNECTED {
168 Ok(())
169 } else {
170 ax_err!(ConnectionRefused, "socket connect() failed")
171 }
172 })
173 }
174 }
175
176 pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
183 self.update_state(STATE_CLOSED, STATE_CLOSED, || {
184 if local_addr.port() == 0 {
186 local_addr.set_port(get_ephemeral_port()?);
187 }
188 unsafe {
191 let old = self.local_addr.get().read();
192 if old != UNSPECIFIED_ENDPOINT {
193 return ax_err!(InvalidInput, "socket bind() failed: already bound");
194 }
195 self.local_addr.get().write(from_core_sockaddr(local_addr));
196 }
197 Ok(())
198 })
199 .unwrap_or_else(|_| ax_err!(InvalidInput, "socket bind() failed: already bound"))
200 }
201
202 pub fn listen(&self) -> AxResult {
207 self.update_state(STATE_CLOSED, STATE_LISTENING, || {
208 let bound_endpoint = self.bound_endpoint()?;
209 unsafe {
210 (*self.local_addr.get()).port = bound_endpoint.port;
211 }
212 LISTEN_TABLE.listen(bound_endpoint)?;
213 debug!("TCP socket listening on {}", bound_endpoint);
214 Ok(())
215 })
216 .unwrap_or(Ok(())) }
218
219 pub fn accept(&self) -> AxResult<TcpSocket> {
226 if !self.is_listening() {
227 return ax_err!(InvalidInput, "socket accept() failed: not listen");
228 }
229
230 let local_port = unsafe { self.local_addr.get().read().port };
232 self.block_on(|| {
233 let (handle, (local_addr, peer_addr)) = LISTEN_TABLE.accept(local_port)?;
234 debug!("TCP socket accepted a new connection {}", peer_addr);
235 Ok(TcpSocket::new_connected(handle, local_addr, peer_addr))
236 })
237 }
238
239 pub fn shutdown(&self) -> AxResult {
241 self.update_state(STATE_CONNECTED, STATE_CLOSED, || {
243 let handle = unsafe { self.handle.get().read().unwrap() };
246 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
247 debug!("TCP socket {}: shutting down", handle);
248 socket.close();
249 });
250 unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; SOCKET_SET.poll_interfaces();
252 Ok(())
253 })
254 .unwrap_or(Ok(()))?;
255
256 self.update_state(STATE_LISTENING, STATE_CLOSED, || {
258 let local_port = unsafe { self.local_addr.get().read().port };
261 unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; LISTEN_TABLE.unlisten(local_port);
263 SOCKET_SET.poll_interfaces();
264 Ok(())
265 })
266 .unwrap_or(Ok(()))?;
267
268 Ok(())
270 }
271
272 pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
274 if self.is_connecting() {
275 return Err(AxError::WouldBlock);
276 } else if !self.is_connected() {
277 return ax_err!(NotConnected, "socket recv() failed");
278 }
279
280 let handle = unsafe { self.handle.get().read().unwrap() };
282 self.block_on(|| {
283 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
284 if !socket.is_active() {
285 ax_err!(ConnectionRefused, "socket recv() failed")
287 } else if !socket.may_recv() {
288 Ok(0)
290 } else if socket.recv_queue() > 0 {
291 let len = socket
294 .recv_slice(buf)
295 .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
296 Ok(len)
297 } else {
298 Err(AxError::WouldBlock)
300 }
301 })
302 })
303 }
304
305 pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
307 if self.is_connecting() {
308 return Err(AxError::WouldBlock);
309 } else if !self.is_connected() {
310 return ax_err!(NotConnected, "socket send() failed");
311 }
312
313 let handle = unsafe { self.handle.get().read().unwrap() };
315 self.block_on(|| {
316 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
317 if !socket.is_active() || !socket.may_send() {
318 ax_err!(ConnectionReset, "socket send() failed")
320 } else if socket.can_send() {
321 let len = socket
324 .send_slice(buf)
325 .map_err(|_| ax_err_type!(BadState, "socket send() failed"))?;
326 Ok(len)
327 } else {
328 Err(AxError::WouldBlock)
330 }
331 })
332 })
333 }
334
335 pub fn poll(&self) -> AxResult<PollState> {
337 match self.get_state() {
338 STATE_CONNECTING => self.poll_connect(),
339 STATE_CONNECTED => self.poll_stream(),
340 STATE_LISTENING => self.poll_listener(),
341 _ => Ok(PollState {
342 readable: false,
343 writable: false,
344 }),
345 }
346 }
347}
348
349impl TcpSocket {
351 #[inline]
352 fn get_state(&self) -> u8 {
353 self.state.load(Ordering::Acquire)
354 }
355
356 #[inline]
357 fn set_state(&self, state: u8) {
358 self.state.store(state, Ordering::Release);
359 }
360
361 fn update_state<F, T>(&self, expect: u8, new: u8, f: F) -> Result<AxResult<T>, u8>
370 where
371 F: FnOnce() -> AxResult<T>,
372 {
373 match self
374 .state
375 .compare_exchange(expect, STATE_BUSY, Ordering::Acquire, Ordering::Acquire)
376 {
377 Ok(_) => {
378 let res = f();
379 if res.is_ok() {
380 self.set_state(new);
381 } else {
382 self.set_state(expect);
383 }
384 Ok(res)
385 }
386 Err(old) => Err(old),
387 }
388 }
389
390 #[inline]
391 fn is_connecting(&self) -> bool {
392 self.get_state() == STATE_CONNECTING
393 }
394
395 #[inline]
396 fn is_connected(&self) -> bool {
397 self.get_state() == STATE_CONNECTED
398 }
399
400 #[inline]
401 fn is_listening(&self) -> bool {
402 self.get_state() == STATE_LISTENING
403 }
404
405 fn bound_endpoint(&self) -> AxResult<IpListenEndpoint> {
406 let local_addr = unsafe { self.local_addr.get().read() };
408 let port = if local_addr.port != 0 {
409 local_addr.port
410 } else {
411 get_ephemeral_port()?
412 };
413 assert_ne!(port, 0);
414 let addr = if !is_unspecified(local_addr.addr) {
415 Some(local_addr.addr)
416 } else {
417 None
418 };
419 Ok(IpListenEndpoint { addr, port })
420 }
421
422 fn poll_connect(&self) -> AxResult<PollState> {
423 let handle = unsafe { self.handle.get().read().unwrap() };
425 let writable =
426 SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| match socket.state() {
427 State::SynSent => false, State::Established => {
429 self.set_state(STATE_CONNECTED); debug!(
431 "TCP socket {}: connected to {}",
432 handle,
433 socket.remote_endpoint().unwrap(),
434 );
435 true
436 }
437 _ => {
438 unsafe {
439 self.local_addr.get().write(UNSPECIFIED_ENDPOINT);
440 self.peer_addr.get().write(UNSPECIFIED_ENDPOINT);
441 }
442 self.set_state(STATE_CLOSED); true
444 }
445 });
446 Ok(PollState {
447 readable: false,
448 writable,
449 })
450 }
451
452 fn poll_stream(&self) -> AxResult<PollState> {
453 let handle = unsafe { self.handle.get().read().unwrap() };
455 SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
456 Ok(PollState {
457 readable: !socket.may_recv() || socket.can_recv(),
458 writable: !socket.may_send() || socket.can_send(),
459 })
460 })
461 }
462
463 fn poll_listener(&self) -> AxResult<PollState> {
464 let local_addr = unsafe { self.local_addr.get().read() };
466 Ok(PollState {
467 readable: LISTEN_TABLE.can_accept(local_addr.port)?,
468 writable: false,
469 })
470 }
471
472 fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
478 where
479 F: FnMut() -> AxResult<T>,
480 {
481 if self.is_nonblocking() {
482 f()
483 } else {
484 loop {
485 SOCKET_SET.poll_interfaces();
486 match f() {
487 Ok(t) => return Ok(t),
488 Err(AxError::WouldBlock) => axtask::yield_now(),
489 Err(e) => return Err(e),
490 }
491 }
492 }
493 }
494}
495
496impl Drop for TcpSocket {
497 fn drop(&mut self) {
498 self.shutdown().ok();
499 if let Some(handle) = unsafe { self.handle.get().read() } {
501 SOCKET_SET.remove(handle);
502 }
503 }
504}
505
506fn get_ephemeral_port() -> AxResult<u16> {
507 const PORT_START: u16 = 0xc000;
508 const PORT_END: u16 = 0xffff;
509 static CURR: Mutex<u16> = Mutex::new(PORT_START);
510
511 let mut curr = CURR.lock();
512 let mut tries = 0;
513 while tries <= PORT_END - PORT_START {
515 let port = *curr;
516 if *curr == PORT_END {
517 *curr = PORT_START;
518 } else {
519 *curr += 1;
520 }
521 if LISTEN_TABLE.can_listen(port) {
522 return Ok(port);
523 }
524 tries += 1;
525 }
526 ax_err!(AddrInUse, "no avaliable ports!")
527}