1use alloc::{sync::Arc, vec, vec::Vec};
2use core::ffi::{c_char, c_int, c_void};
3use core::mem::size_of;
4use core::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
5
6use axerrno::{LinuxError, LinuxResult};
7use axio::PollState;
8use axnet::{TcpSocket, UdpSocket};
9use axsync::Mutex;
10
11use super::fd_ops::FileLike;
12use crate::ctypes;
13use crate::utils::char_ptr_to_str;
14
15pub enum Socket {
16 Udp(Mutex<UdpSocket>),
17 Tcp(Mutex<TcpSocket>),
18}
19
20impl Socket {
21 fn add_to_fd_table(self) -> LinuxResult<c_int> {
22 super::fd_ops::add_file_like(Arc::new(self))
23 }
24
25 fn from_fd(fd: c_int) -> LinuxResult<Arc<Self>> {
26 let f = super::fd_ops::get_file_like(fd)?;
27 f.into_any()
28 .downcast::<Self>()
29 .map_err(|_| LinuxError::EINVAL)
30 }
31
32 fn send(&self, buf: &[u8]) -> LinuxResult<usize> {
33 match self {
34 Socket::Udp(udpsocket) => Ok(udpsocket.lock().send(buf)?),
35 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().send(buf)?),
36 }
37 }
38
39 fn recv(&self, buf: &mut [u8]) -> LinuxResult<usize> {
40 match self {
41 Socket::Udp(udpsocket) => Ok(udpsocket.lock().recv_from(buf).map(|e| e.0)?),
42 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().recv(buf)?),
43 }
44 }
45
46 pub fn poll(&self) -> LinuxResult<PollState> {
47 match self {
48 Socket::Udp(udpsocket) => Ok(udpsocket.lock().poll()?),
49 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().poll()?),
50 }
51 }
52
53 fn local_addr(&self) -> LinuxResult<SocketAddr> {
54 match self {
55 Socket::Udp(udpsocket) => Ok(udpsocket.lock().local_addr()?),
56 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().local_addr()?),
57 }
58 }
59
60 fn peer_addr(&self) -> LinuxResult<SocketAddr> {
61 match self {
62 Socket::Udp(udpsocket) => Ok(udpsocket.lock().peer_addr()?),
63 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().peer_addr()?),
64 }
65 }
66
67 fn bind(&self, addr: SocketAddr) -> LinuxResult {
68 match self {
69 Socket::Udp(udpsocket) => Ok(udpsocket.lock().bind(addr)?),
70 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().bind(addr)?),
71 }
72 }
73
74 fn connect(&self, addr: SocketAddr) -> LinuxResult {
75 match self {
76 Socket::Udp(udpsocket) => Ok(udpsocket.lock().connect(addr)?),
77 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().connect(addr)?),
78 }
79 }
80
81 fn sendto(&self, buf: &[u8], addr: SocketAddr) -> LinuxResult<usize> {
82 match self {
83 Socket::Udp(udpsocket) => Ok(udpsocket.lock().send_to(buf, addr)?),
85 Socket::Tcp(_) => Err(LinuxError::EISCONN),
86 }
87 }
88
89 fn recvfrom(&self, buf: &mut [u8]) -> LinuxResult<(usize, Option<SocketAddr>)> {
90 match self {
91 Socket::Udp(udpsocket) => Ok(udpsocket
93 .lock()
94 .recv_from(buf)
95 .map(|res| (res.0, Some(res.1)))?),
96 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().recv(buf).map(|res| (res, None))?),
97 }
98 }
99
100 fn listen(&self) -> LinuxResult {
101 match self {
102 Socket::Udp(_) => Err(LinuxError::EOPNOTSUPP),
103 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().listen()?),
104 }
105 }
106
107 fn accept(&self) -> LinuxResult<TcpSocket> {
108 match self {
109 Socket::Udp(_) => Err(LinuxError::EOPNOTSUPP),
110 Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().accept()?),
111 }
112 }
113
114 fn shutdown(&self) -> LinuxResult {
115 match self {
116 Socket::Udp(udpsocket) => {
117 let udpsocket = udpsocket.lock();
118 udpsocket.peer_addr()?;
119 udpsocket.shutdown()?;
120 Ok(())
121 }
122
123 Socket::Tcp(tcpsocket) => {
124 let tcpsocket = tcpsocket.lock();
125 tcpsocket.peer_addr()?;
126 tcpsocket.shutdown()?;
127 Ok(())
128 }
129 }
130 }
131}
132
133impl FileLike for Socket {
134 fn read(&self, buf: &mut [u8]) -> LinuxResult<usize> {
135 self.recv(buf)
136 }
137
138 fn write(&self, buf: &[u8]) -> LinuxResult<usize> {
139 self.send(buf)
140 }
141
142 fn stat(&self) -> LinuxResult<ctypes::stat> {
143 let st_mode = 0o140000 | 0o777u32; Ok(ctypes::stat {
146 st_ino: 1,
147 st_nlink: 1,
148 st_mode,
149 st_uid: 1000,
150 st_gid: 1000,
151 st_blksize: 4096,
152 ..Default::default()
153 })
154 }
155
156 fn into_any(self: Arc<Self>) -> Arc<dyn core::any::Any + Send + Sync> {
157 self
158 }
159
160 fn poll(&self) -> LinuxResult<PollState> {
161 self.poll()
162 }
163
164 fn set_nonblocking(&self, nonblock: bool) -> LinuxResult {
165 match self {
166 Socket::Udp(udpsocket) => udpsocket.lock().set_nonblocking(nonblock),
167 Socket::Tcp(tcpsocket) => tcpsocket.lock().set_nonblocking(nonblock),
168 }
169 Ok(())
170 }
171}
172
173impl From<SocketAddrV4> for ctypes::sockaddr_in {
174 fn from(addr: SocketAddrV4) -> ctypes::sockaddr_in {
175 ctypes::sockaddr_in {
176 sin_family: ctypes::AF_INET as u16,
177 sin_port: addr.port().to_be(),
178 sin_addr: ctypes::in_addr {
179 s_addr: u32::from_ne_bytes(addr.ip().octets()),
182 },
183 sin_zero: [0; 8],
184 }
185 }
186}
187
188impl From<ctypes::sockaddr_in> for SocketAddrV4 {
189 fn from(addr: ctypes::sockaddr_in) -> SocketAddrV4 {
190 SocketAddrV4::new(
191 Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()),
192 u16::from_be(addr.sin_port),
193 )
194 }
195}
196
197fn into_sockaddr(addr: SocketAddr) -> (ctypes::sockaddr, ctypes::socklen_t) {
198 debug!(" Sockaddr: {}", addr);
199 match addr {
200 SocketAddr::V4(addr) => (
201 unsafe { *(&ctypes::sockaddr_in::from(addr) as *const _ as *const ctypes::sockaddr) },
202 size_of::<ctypes::sockaddr>() as _,
203 ),
204 SocketAddr::V6(_) => panic!("IPv6 is not supported"),
205 }
206}
207
208fn from_sockaddr(
209 addr: *const ctypes::sockaddr,
210 addrlen: ctypes::socklen_t,
211) -> LinuxResult<SocketAddr> {
212 if addr.is_null() {
213 return Err(LinuxError::EFAULT);
214 }
215 if addrlen != size_of::<ctypes::sockaddr>() as _ {
216 return Err(LinuxError::EINVAL);
217 }
218
219 let mid = unsafe { *(addr as *const ctypes::sockaddr_in) };
220 if mid.sin_family != ctypes::AF_INET as u16 {
221 return Err(LinuxError::EINVAL);
222 }
223
224 let res = SocketAddr::V4(mid.into());
225 debug!(" load sockaddr:{:#x} => {:?}", addr as usize, res);
226 Ok(res)
227}
228
229pub fn sys_socket(domain: c_int, socktype: c_int, protocol: c_int) -> c_int {
233 debug!("sys_socket <= {} {} {}", domain, socktype, protocol);
234 let (domain, socktype, protocol) = (domain as u32, socktype as u32, protocol as u32);
235 syscall_body!(sys_socket, {
236 match (domain, socktype, protocol) {
237 (ctypes::AF_INET, ctypes::SOCK_STREAM, ctypes::IPPROTO_TCP)
238 | (ctypes::AF_INET, ctypes::SOCK_STREAM, 0) => {
239 Socket::Tcp(Mutex::new(TcpSocket::new())).add_to_fd_table()
240 }
241 (ctypes::AF_INET, ctypes::SOCK_DGRAM, ctypes::IPPROTO_UDP)
242 | (ctypes::AF_INET, ctypes::SOCK_DGRAM, 0) => {
243 Socket::Udp(Mutex::new(UdpSocket::new())).add_to_fd_table()
244 }
245 _ => Err(LinuxError::EINVAL),
246 }
247 })
248}
249
250pub fn sys_bind(
254 socket_fd: c_int,
255 socket_addr: *const ctypes::sockaddr,
256 addrlen: ctypes::socklen_t,
257) -> c_int {
258 debug!(
259 "sys_bind <= {} {:#x} {}",
260 socket_fd, socket_addr as usize, addrlen
261 );
262 syscall_body!(sys_bind, {
263 let addr = from_sockaddr(socket_addr, addrlen)?;
264 Socket::from_fd(socket_fd)?.bind(addr)?;
265 Ok(0)
266 })
267}
268
269pub fn sys_connect(
273 socket_fd: c_int,
274 socket_addr: *const ctypes::sockaddr,
275 addrlen: ctypes::socklen_t,
276) -> c_int {
277 debug!(
278 "sys_connect <= {} {:#x} {}",
279 socket_fd, socket_addr as usize, addrlen
280 );
281 syscall_body!(sys_connect, {
282 let addr = from_sockaddr(socket_addr, addrlen)?;
283 Socket::from_fd(socket_fd)?.connect(addr)?;
284 Ok(0)
285 })
286}
287
288pub fn sys_sendto(
292 socket_fd: c_int,
293 buf_ptr: *const c_void,
294 len: ctypes::size_t,
295 flag: c_int, socket_addr: *const ctypes::sockaddr,
297 addrlen: ctypes::socklen_t,
298) -> ctypes::ssize_t {
299 debug!(
300 "sys_sendto <= {} {:#x} {} {} {:#x} {}",
301 socket_fd, buf_ptr as usize, len, flag, socket_addr as usize, addrlen
302 );
303 syscall_body!(sys_sendto, {
304 if buf_ptr.is_null() {
305 return Err(LinuxError::EFAULT);
306 }
307 let addr = from_sockaddr(socket_addr, addrlen)?;
308 let buf = unsafe { core::slice::from_raw_parts(buf_ptr as *const u8, len) };
309 Socket::from_fd(socket_fd)?.sendto(buf, addr)
310 })
311}
312
313pub fn sys_send(
317 socket_fd: c_int,
318 buf_ptr: *const c_void,
319 len: ctypes::size_t,
320 flag: c_int, ) -> ctypes::ssize_t {
322 debug!(
323 "sys_sendto <= {} {:#x} {} {}",
324 socket_fd, buf_ptr as usize, len, flag
325 );
326 syscall_body!(sys_send, {
327 if buf_ptr.is_null() {
328 return Err(LinuxError::EFAULT);
329 }
330 let buf = unsafe { core::slice::from_raw_parts(buf_ptr as *const u8, len) };
331 Socket::from_fd(socket_fd)?.send(buf)
332 })
333}
334
335pub unsafe fn sys_recvfrom(
339 socket_fd: c_int,
340 buf_ptr: *mut c_void,
341 len: ctypes::size_t,
342 flag: c_int, socket_addr: *mut ctypes::sockaddr,
344 addrlen: *mut ctypes::socklen_t,
345) -> ctypes::ssize_t {
346 debug!(
347 "sys_recvfrom <= {} {:#x} {} {} {:#x} {:#x}",
348 socket_fd, buf_ptr as usize, len, flag, socket_addr as usize, addrlen as usize
349 );
350 syscall_body!(sys_recvfrom, {
351 if buf_ptr.is_null() || socket_addr.is_null() || addrlen.is_null() {
352 return Err(LinuxError::EFAULT);
353 }
354 let socket = Socket::from_fd(socket_fd)?;
355 let buf = unsafe { core::slice::from_raw_parts_mut(buf_ptr as *mut u8, len) };
356
357 let res = socket.recvfrom(buf)?;
358 if let Some(addr) = res.1 {
359 unsafe {
360 (*socket_addr, *addrlen) = into_sockaddr(addr);
361 }
362 }
363 Ok(res.0)
364 })
365}
366
367pub fn sys_recv(
371 socket_fd: c_int,
372 buf_ptr: *mut c_void,
373 len: ctypes::size_t,
374 flag: c_int, ) -> ctypes::ssize_t {
376 debug!(
377 "sys_recv <= {} {:#x} {} {}",
378 socket_fd, buf_ptr as usize, len, flag
379 );
380 syscall_body!(sys_recv, {
381 if buf_ptr.is_null() {
382 return Err(LinuxError::EFAULT);
383 }
384 let buf = unsafe { core::slice::from_raw_parts_mut(buf_ptr as *mut u8, len) };
385 Socket::from_fd(socket_fd)?.recv(buf)
386 })
387}
388
389pub fn sys_listen(
393 socket_fd: c_int,
394 backlog: c_int, ) -> c_int {
396 debug!("sys_listen <= {} {}", socket_fd, backlog);
397 syscall_body!(sys_listen, {
398 Socket::from_fd(socket_fd)?.listen()?;
399 Ok(0)
400 })
401}
402
403pub unsafe fn sys_accept(
407 socket_fd: c_int,
408 socket_addr: *mut ctypes::sockaddr,
409 socket_len: *mut ctypes::socklen_t,
410) -> c_int {
411 debug!(
412 "sys_accept <= {} {:#x} {:#x}",
413 socket_fd, socket_addr as usize, socket_len as usize
414 );
415 syscall_body!(sys_accept, {
416 if socket_addr.is_null() || socket_len.is_null() {
417 return Err(LinuxError::EFAULT);
418 }
419 let socket = Socket::from_fd(socket_fd)?;
420 let new_socket = socket.accept()?;
421 let addr = new_socket.peer_addr()?;
422 let new_fd = Socket::add_to_fd_table(Socket::Tcp(Mutex::new(new_socket)))?;
423 unsafe {
424 (*socket_addr, *socket_len) = into_sockaddr(addr);
425 }
426 Ok(new_fd)
427 })
428}
429
430pub fn sys_shutdown(
434 socket_fd: c_int,
435 flag: c_int, ) -> c_int {
437 debug!("sys_shutdown <= {} {}", socket_fd, flag);
438 syscall_body!(sys_shutdown, {
439 Socket::from_fd(socket_fd)?.shutdown()?;
440 Ok(0)
441 })
442}
443
/// Resolves `nodename`/`servname` into a linked list of IPv4 `addrinfo`
/// entries stored in `*res`.
///
/// Returns the number of entries produced (0 if nothing could be resolved).
/// `hints` is ignored. A NULL `nodename` resolves to `127.0.0.1`, and a
/// missing or unparsable `servname` yields port 0. Non-IPv4 results are
/// skipped, because each entry stores a `sockaddr_in`. The list must be
/// released with `sys_freeaddrinfo`.
///
/// # Safety
///
/// `nodename`/`servname` must be NULL or valid C strings, and `res` must be
/// NULL or valid for a pointer write.
pub unsafe fn sys_getaddrinfo(
    nodename: *const c_char,
    servname: *const c_char,
    _hints: *const ctypes::addrinfo,
    res: *mut *mut ctypes::addrinfo,
) -> c_int {
    let name = char_ptr_to_str(nodename);
    let port = char_ptr_to_str(servname);
    debug!("sys_getaddrinfo <= {:?} {:?}", name, port);
    syscall_body!(sys_getaddrinfo, {
        if nodename.is_null() && servname.is_null() {
            return Ok(0);
        }
        if res.is_null() {
            return Err(LinuxError::EFAULT);
        }

        let port = port.map_or(0, |p| p.parse::<u16>().unwrap_or(0));
        let ip_addrs = if let Ok(domain) = name {
            if let Ok(a) = domain.parse::<IpAddr>() {
                vec![a]
            } else {
                axnet::dns_query(domain)?
            }
        } else {
            vec![Ipv4Addr::LOCALHOST.into()]
        };

        // Skip non-IPv4 results instead of panicking. Both the name ("::1")
        // and the resolver's answers are outside the kernel's control.
        let v4_addrs: Vec<Ipv4Addr> = ip_addrs
            .into_iter()
            .filter_map(|ip| match ip {
                IpAddr::V4(v4) => Some(v4),
                IpAddr::V6(_) => None,
            })
            .take(ctypes::MAXADDRS as usize)
            .collect();
        let len = v4_addrs.len();
        if len == 0 {
            return Ok(0);
        }

        // The capacity must be exactly `len`. `sys_freeaddrinfo` rebuilds the
        // vector with `Vec::from_raw_parts(ptr, len, len)`.
        let mut out: Vec<ctypes::aibuf> = Vec::with_capacity(len);
        for (i, &ip) in v4_addrs.iter().enumerate() {
            out.push(ctypes::aibuf {
                ai: ctypes::addrinfo {
                    ai_family: ctypes::AF_INET as _,
                    ai_socktype: ctypes::SOCK_STREAM as _,
                    ai_protocol: ctypes::IPPROTO_TCP as _,
                    ai_addrlen: size_of::<ctypes::sockaddr_in>() as _,
                    ai_addr: core::ptr::null_mut(),
                    ai_canonname: core::ptr::null_mut(),
                    ai_next: core::ptr::null_mut(),
                    ai_flags: 0,
                },
                sa: ctypes::aibuf_sa {
                    sin: SocketAddrV4::new(ip, port).into(),
                },
                slot: i as i16,
                lock: [0],
                ref_: 0,
            });
        }
        // The head entry records how many entries the allocation holds.
        out[0].ref_ = len as i16;

        // Wire up the self-referential pointers only after every push, through
        // one base pointer, so no later access to `out` invalidates them.
        let base = out.as_mut_ptr();
        for i in 0..len {
            // SAFETY: `i < len == out.len()`, so `base.add(i)` (and
            // `base.add(i + 1)` when `i + 1 < len`) point to initialized
            // elements. The buffer is neither moved nor reallocated from here on.
            unsafe {
                let entry = base.add(i);
                (*entry).ai.ai_addr = core::ptr::addr_of_mut!((*entry).sa.sin).cast();
                if i + 1 < len {
                    (*entry).ai.ai_next = core::ptr::addr_of_mut!((*base.add(i + 1)).ai);
                }
            }
        }

        // SAFETY: `res` was checked to be non-null above.
        unsafe { *res = core::ptr::addr_of_mut!((*base).ai) };
        // Ownership passes to the caller; the buffer is reclaimed by `sys_freeaddrinfo`.
        core::mem::forget(out);
        Ok(len)
    })
}
521
522pub unsafe fn sys_freeaddrinfo(res: *mut ctypes::addrinfo) {
524 if res.is_null() {
525 return;
526 }
527 let aibuf_ptr = res as *mut ctypes::aibuf;
528 let len = unsafe { *aibuf_ptr }.ref_ as usize;
529 assert!(unsafe { *aibuf_ptr }.slot == 0);
530 assert!(len > 0);
531 let vec = unsafe { Vec::from_raw_parts(aibuf_ptr, len, len) }; drop(vec);
533}
534
535pub unsafe fn sys_getsockname(
537 sock_fd: c_int,
538 addr: *mut ctypes::sockaddr,
539 addrlen: *mut ctypes::socklen_t,
540) -> c_int {
541 debug!(
542 "sys_getsockname <= {} {:#x} {:#x}",
543 sock_fd, addr as usize, addrlen as usize
544 );
545 syscall_body!(sys_getsockname, {
546 if addr.is_null() || addrlen.is_null() {
547 return Err(LinuxError::EFAULT);
548 }
549 if unsafe { *addrlen } < size_of::<ctypes::sockaddr>() as u32 {
550 return Err(LinuxError::EINVAL);
551 }
552 unsafe {
553 (*addr, *addrlen) = into_sockaddr(Socket::from_fd(sock_fd)?.local_addr()?);
554 }
555 Ok(0)
556 })
557}
558
559pub unsafe fn sys_getpeername(
561 sock_fd: c_int,
562 addr: *mut ctypes::sockaddr,
563 addrlen: *mut ctypes::socklen_t,
564) -> c_int {
565 debug!(
566 "sys_getpeername <= {} {:#x} {:#x}",
567 sock_fd, addr as usize, addrlen as usize
568 );
569 syscall_body!(sys_getpeername, {
570 if addr.is_null() || addrlen.is_null() {
571 return Err(LinuxError::EFAULT);
572 }
573 if unsafe { *addrlen } < size_of::<ctypes::sockaddr>() as u32 {
574 return Err(LinuxError::EINVAL);
575 }
576 unsafe {
577 (*addr, *addrlen) = into_sockaddr(Socket::from_fd(sock_fd)?.peer_addr()?);
578 }
579 Ok(0)
580 })
581}