arceos_posix_api/imp/net.rs

1use alloc::{sync::Arc, vec, vec::Vec};
2use core::ffi::{c_char, c_int, c_void};
3use core::mem::size_of;
4use core::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
5
6use axerrno::{LinuxError, LinuxResult};
7use axio::PollState;
8use axnet::{TcpSocket, UdpSocket};
9use axsync::Mutex;
10
11use super::fd_ops::FileLike;
12use crate::ctypes;
13use crate::utils::char_ptr_to_str;
14
15pub enum Socket {
16    Udp(Mutex<UdpSocket>),
17    Tcp(Mutex<TcpSocket>),
18}
19
20impl Socket {
21    fn add_to_fd_table(self) -> LinuxResult<c_int> {
22        super::fd_ops::add_file_like(Arc::new(self))
23    }
24
25    fn from_fd(fd: c_int) -> LinuxResult<Arc<Self>> {
26        let f = super::fd_ops::get_file_like(fd)?;
27        f.into_any()
28            .downcast::<Self>()
29            .map_err(|_| LinuxError::EINVAL)
30    }
31
32    fn send(&self, buf: &[u8]) -> LinuxResult<usize> {
33        match self {
34            Socket::Udp(udpsocket) => Ok(udpsocket.lock().send(buf)?),
35            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().send(buf)?),
36        }
37    }
38
39    fn recv(&self, buf: &mut [u8]) -> LinuxResult<usize> {
40        match self {
41            Socket::Udp(udpsocket) => Ok(udpsocket.lock().recv_from(buf).map(|e| e.0)?),
42            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().recv(buf)?),
43        }
44    }
45
46    pub fn poll(&self) -> LinuxResult<PollState> {
47        match self {
48            Socket::Udp(udpsocket) => Ok(udpsocket.lock().poll()?),
49            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().poll()?),
50        }
51    }
52
53    fn local_addr(&self) -> LinuxResult<SocketAddr> {
54        match self {
55            Socket::Udp(udpsocket) => Ok(udpsocket.lock().local_addr()?),
56            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().local_addr()?),
57        }
58    }
59
60    fn peer_addr(&self) -> LinuxResult<SocketAddr> {
61        match self {
62            Socket::Udp(udpsocket) => Ok(udpsocket.lock().peer_addr()?),
63            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().peer_addr()?),
64        }
65    }
66
67    fn bind(&self, addr: SocketAddr) -> LinuxResult {
68        match self {
69            Socket::Udp(udpsocket) => Ok(udpsocket.lock().bind(addr)?),
70            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().bind(addr)?),
71        }
72    }
73
74    fn connect(&self, addr: SocketAddr) -> LinuxResult {
75        match self {
76            Socket::Udp(udpsocket) => Ok(udpsocket.lock().connect(addr)?),
77            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().connect(addr)?),
78        }
79    }
80
81    fn sendto(&self, buf: &[u8], addr: SocketAddr) -> LinuxResult<usize> {
82        match self {
83            // diff: must bind before sendto
84            Socket::Udp(udpsocket) => Ok(udpsocket.lock().send_to(buf, addr)?),
85            Socket::Tcp(_) => Err(LinuxError::EISCONN),
86        }
87    }
88
89    fn recvfrom(&self, buf: &mut [u8]) -> LinuxResult<(usize, Option<SocketAddr>)> {
90        match self {
91            // diff: must bind before recvfrom
92            Socket::Udp(udpsocket) => Ok(udpsocket
93                .lock()
94                .recv_from(buf)
95                .map(|res| (res.0, Some(res.1)))?),
96            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().recv(buf).map(|res| (res, None))?),
97        }
98    }
99
100    fn listen(&self) -> LinuxResult {
101        match self {
102            Socket::Udp(_) => Err(LinuxError::EOPNOTSUPP),
103            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().listen()?),
104        }
105    }
106
107    fn accept(&self) -> LinuxResult<TcpSocket> {
108        match self {
109            Socket::Udp(_) => Err(LinuxError::EOPNOTSUPP),
110            Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().accept()?),
111        }
112    }
113
114    fn shutdown(&self) -> LinuxResult {
115        match self {
116            Socket::Udp(udpsocket) => {
117                let udpsocket = udpsocket.lock();
118                udpsocket.peer_addr()?;
119                udpsocket.shutdown()?;
120                Ok(())
121            }
122
123            Socket::Tcp(tcpsocket) => {
124                let tcpsocket = tcpsocket.lock();
125                tcpsocket.peer_addr()?;
126                tcpsocket.shutdown()?;
127                Ok(())
128            }
129        }
130    }
131}
132
133impl FileLike for Socket {
134    fn read(&self, buf: &mut [u8]) -> LinuxResult<usize> {
135        self.recv(buf)
136    }
137
138    fn write(&self, buf: &[u8]) -> LinuxResult<usize> {
139        self.send(buf)
140    }
141
142    fn stat(&self) -> LinuxResult<ctypes::stat> {
143        // not really implemented
144        let st_mode = 0o140000 | 0o777u32; // S_IFSOCK | rwxrwxrwx
145        Ok(ctypes::stat {
146            st_ino: 1,
147            st_nlink: 1,
148            st_mode,
149            st_uid: 1000,
150            st_gid: 1000,
151            st_blksize: 4096,
152            ..Default::default()
153        })
154    }
155
156    fn into_any(self: Arc<Self>) -> Arc<dyn core::any::Any + Send + Sync> {
157        self
158    }
159
160    fn poll(&self) -> LinuxResult<PollState> {
161        self.poll()
162    }
163
164    fn set_nonblocking(&self, nonblock: bool) -> LinuxResult {
165        match self {
166            Socket::Udp(udpsocket) => udpsocket.lock().set_nonblocking(nonblock),
167            Socket::Tcp(tcpsocket) => tcpsocket.lock().set_nonblocking(nonblock),
168        }
169        Ok(())
170    }
171}
172
173impl From<SocketAddrV4> for ctypes::sockaddr_in {
174    fn from(addr: SocketAddrV4) -> ctypes::sockaddr_in {
175        ctypes::sockaddr_in {
176            sin_family: ctypes::AF_INET as u16,
177            sin_port: addr.port().to_be(),
178            sin_addr: ctypes::in_addr {
179                // `s_addr` is stored as BE on all machines and the array is in BE order.
180                // So the native endian conversion method is used so that it's never swapped.
181                s_addr: u32::from_ne_bytes(addr.ip().octets()),
182            },
183            sin_zero: [0; 8],
184        }
185    }
186}
187
188impl From<ctypes::sockaddr_in> for SocketAddrV4 {
189    fn from(addr: ctypes::sockaddr_in) -> SocketAddrV4 {
190        SocketAddrV4::new(
191            Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()),
192            u16::from_be(addr.sin_port),
193        )
194    }
195}
196
197fn into_sockaddr(addr: SocketAddr) -> (ctypes::sockaddr, ctypes::socklen_t) {
198    debug!("    Sockaddr: {}", addr);
199    match addr {
200        SocketAddr::V4(addr) => (
201            unsafe { *(&ctypes::sockaddr_in::from(addr) as *const _ as *const ctypes::sockaddr) },
202            size_of::<ctypes::sockaddr>() as _,
203        ),
204        SocketAddr::V6(_) => panic!("IPv6 is not supported"),
205    }
206}
207
208fn from_sockaddr(
209    addr: *const ctypes::sockaddr,
210    addrlen: ctypes::socklen_t,
211) -> LinuxResult<SocketAddr> {
212    if addr.is_null() {
213        return Err(LinuxError::EFAULT);
214    }
215    if addrlen != size_of::<ctypes::sockaddr>() as _ {
216        return Err(LinuxError::EINVAL);
217    }
218
219    let mid = unsafe { *(addr as *const ctypes::sockaddr_in) };
220    if mid.sin_family != ctypes::AF_INET as u16 {
221        return Err(LinuxError::EINVAL);
222    }
223
224    let res = SocketAddr::V4(mid.into());
225    debug!("    load sockaddr:{:#x} => {:?}", addr as usize, res);
226    Ok(res)
227}
228
229/// Create an socket for communication.
230///
231/// Return the socket file descriptor.
232pub fn sys_socket(domain: c_int, socktype: c_int, protocol: c_int) -> c_int {
233    debug!("sys_socket <= {} {} {}", domain, socktype, protocol);
234    let (domain, socktype, protocol) = (domain as u32, socktype as u32, protocol as u32);
235    syscall_body!(sys_socket, {
236        match (domain, socktype, protocol) {
237            (ctypes::AF_INET, ctypes::SOCK_STREAM, ctypes::IPPROTO_TCP)
238            | (ctypes::AF_INET, ctypes::SOCK_STREAM, 0) => {
239                Socket::Tcp(Mutex::new(TcpSocket::new())).add_to_fd_table()
240            }
241            (ctypes::AF_INET, ctypes::SOCK_DGRAM, ctypes::IPPROTO_UDP)
242            | (ctypes::AF_INET, ctypes::SOCK_DGRAM, 0) => {
243                Socket::Udp(Mutex::new(UdpSocket::new())).add_to_fd_table()
244            }
245            _ => Err(LinuxError::EINVAL),
246        }
247    })
248}
249
250/// Bind a address to a socket.
251///
252/// Return 0 if success.
253pub fn sys_bind(
254    socket_fd: c_int,
255    socket_addr: *const ctypes::sockaddr,
256    addrlen: ctypes::socklen_t,
257) -> c_int {
258    debug!(
259        "sys_bind <= {} {:#x} {}",
260        socket_fd, socket_addr as usize, addrlen
261    );
262    syscall_body!(sys_bind, {
263        let addr = from_sockaddr(socket_addr, addrlen)?;
264        Socket::from_fd(socket_fd)?.bind(addr)?;
265        Ok(0)
266    })
267}
268
269/// Connects the socket to the address specified.
270///
271/// Return 0 if success.
272pub fn sys_connect(
273    socket_fd: c_int,
274    socket_addr: *const ctypes::sockaddr,
275    addrlen: ctypes::socklen_t,
276) -> c_int {
277    debug!(
278        "sys_connect <= {} {:#x} {}",
279        socket_fd, socket_addr as usize, addrlen
280    );
281    syscall_body!(sys_connect, {
282        let addr = from_sockaddr(socket_addr, addrlen)?;
283        Socket::from_fd(socket_fd)?.connect(addr)?;
284        Ok(0)
285    })
286}
287
288/// Send a message on a socket to the address specified.
289///
290/// Return the number of bytes sent if success.
291pub fn sys_sendto(
292    socket_fd: c_int,
293    buf_ptr: *const c_void,
294    len: ctypes::size_t,
295    flag: c_int, // currently not used
296    socket_addr: *const ctypes::sockaddr,
297    addrlen: ctypes::socklen_t,
298) -> ctypes::ssize_t {
299    debug!(
300        "sys_sendto <= {} {:#x} {} {} {:#x} {}",
301        socket_fd, buf_ptr as usize, len, flag, socket_addr as usize, addrlen
302    );
303    syscall_body!(sys_sendto, {
304        if buf_ptr.is_null() {
305            return Err(LinuxError::EFAULT);
306        }
307        let addr = from_sockaddr(socket_addr, addrlen)?;
308        let buf = unsafe { core::slice::from_raw_parts(buf_ptr as *const u8, len) };
309        Socket::from_fd(socket_fd)?.sendto(buf, addr)
310    })
311}
312
313/// Send a message on a socket to the address connected.
314///
315/// Return the number of bytes sent if success.
316pub fn sys_send(
317    socket_fd: c_int,
318    buf_ptr: *const c_void,
319    len: ctypes::size_t,
320    flag: c_int, // currently not used
321) -> ctypes::ssize_t {
322    debug!(
323        "sys_sendto <= {} {:#x} {} {}",
324        socket_fd, buf_ptr as usize, len, flag
325    );
326    syscall_body!(sys_send, {
327        if buf_ptr.is_null() {
328            return Err(LinuxError::EFAULT);
329        }
330        let buf = unsafe { core::slice::from_raw_parts(buf_ptr as *const u8, len) };
331        Socket::from_fd(socket_fd)?.send(buf)
332    })
333}
334
335/// Receive a message on a socket and get its source address.
336///
337/// Return the number of bytes received if success.
338pub unsafe fn sys_recvfrom(
339    socket_fd: c_int,
340    buf_ptr: *mut c_void,
341    len: ctypes::size_t,
342    flag: c_int, // currently not used
343    socket_addr: *mut ctypes::sockaddr,
344    addrlen: *mut ctypes::socklen_t,
345) -> ctypes::ssize_t {
346    debug!(
347        "sys_recvfrom <= {} {:#x} {} {} {:#x} {:#x}",
348        socket_fd, buf_ptr as usize, len, flag, socket_addr as usize, addrlen as usize
349    );
350    syscall_body!(sys_recvfrom, {
351        if buf_ptr.is_null() || socket_addr.is_null() || addrlen.is_null() {
352            return Err(LinuxError::EFAULT);
353        }
354        let socket = Socket::from_fd(socket_fd)?;
355        let buf = unsafe { core::slice::from_raw_parts_mut(buf_ptr as *mut u8, len) };
356
357        let res = socket.recvfrom(buf)?;
358        if let Some(addr) = res.1 {
359            unsafe {
360                (*socket_addr, *addrlen) = into_sockaddr(addr);
361            }
362        }
363        Ok(res.0)
364    })
365}
366
367/// Receive a message on a socket.
368///
369/// Return the number of bytes received if success.
370pub fn sys_recv(
371    socket_fd: c_int,
372    buf_ptr: *mut c_void,
373    len: ctypes::size_t,
374    flag: c_int, // currently not used
375) -> ctypes::ssize_t {
376    debug!(
377        "sys_recv <= {} {:#x} {} {}",
378        socket_fd, buf_ptr as usize, len, flag
379    );
380    syscall_body!(sys_recv, {
381        if buf_ptr.is_null() {
382            return Err(LinuxError::EFAULT);
383        }
384        let buf = unsafe { core::slice::from_raw_parts_mut(buf_ptr as *mut u8, len) };
385        Socket::from_fd(socket_fd)?.recv(buf)
386    })
387}
388
389/// Listen for connections on a socket
390///
391/// Return 0 if success.
392pub fn sys_listen(
393    socket_fd: c_int,
394    backlog: c_int, // currently not used
395) -> c_int {
396    debug!("sys_listen <= {} {}", socket_fd, backlog);
397    syscall_body!(sys_listen, {
398        Socket::from_fd(socket_fd)?.listen()?;
399        Ok(0)
400    })
401}
402
403/// Accept for connections on a socket
404///
405/// Return file descriptor for the accepted socket if success.
406pub unsafe fn sys_accept(
407    socket_fd: c_int,
408    socket_addr: *mut ctypes::sockaddr,
409    socket_len: *mut ctypes::socklen_t,
410) -> c_int {
411    debug!(
412        "sys_accept <= {} {:#x} {:#x}",
413        socket_fd, socket_addr as usize, socket_len as usize
414    );
415    syscall_body!(sys_accept, {
416        if socket_addr.is_null() || socket_len.is_null() {
417            return Err(LinuxError::EFAULT);
418        }
419        let socket = Socket::from_fd(socket_fd)?;
420        let new_socket = socket.accept()?;
421        let addr = new_socket.peer_addr()?;
422        let new_fd = Socket::add_to_fd_table(Socket::Tcp(Mutex::new(new_socket)))?;
423        unsafe {
424            (*socket_addr, *socket_len) = into_sockaddr(addr);
425        }
426        Ok(new_fd)
427    })
428}
429
430/// Shut down a full-duplex connection.
431///
432/// Return 0 if success.
433pub fn sys_shutdown(
434    socket_fd: c_int,
435    flag: c_int, // currently not used
436) -> c_int {
437    debug!("sys_shutdown <= {} {}", socket_fd, flag);
438    syscall_body!(sys_shutdown, {
439        Socket::from_fd(socket_fd)?.shutdown()?;
440        Ok(0)
441    })
442}
443
/// Query addresses for a domain name.
///
/// The lookup behaves as follows:
/// - Only IPv4 is supported. IPv6 results, whether from DNS or an IPv6
///   literal `nodename`, are skipped.
/// - A null `nodename` resolves to the IPv4 loopback address.
/// - The port is parsed from `servname` as a decimal `u16`. It is 0 when
///   `servname` is null or not a valid number.
/// - `hints` is ignored. Every result is reported as a TCP stream entry with
///   zero `ai_flags` and a null `ai_canonname`.
/// - At most `MAXADDRS` entries are returned.
///
/// On success `*res` points to a linked list that must be released with
/// [`sys_freeaddrinfo`]. When no address is found, 0 is returned and `*res`
/// is left untouched.
///
/// Return address number if success.
pub unsafe fn sys_getaddrinfo(
    nodename: *const c_char,
    servname: *const c_char,
    _hints: *const ctypes::addrinfo,
    res: *mut *mut ctypes::addrinfo,
) -> c_int {
    let name = char_ptr_to_str(nodename);
    let port = char_ptr_to_str(servname);
    debug!("sys_getaddrinfo <= {:?} {:?}", name, port);
    syscall_body!(sys_getaddrinfo, {
        if nodename.is_null() && servname.is_null() {
            return Ok(0);
        }
        if res.is_null() {
            return Err(LinuxError::EFAULT);
        }

        let port = port.map_or(0, |p| p.parse::<u16>().unwrap_or(0));
        let ip_addrs = if let Ok(domain) = name {
            if let Ok(a) = domain.parse::<IpAddr>() {
                vec![a]
            } else {
                axnet::dns_query(domain)?
            }
        } else {
            vec![Ipv4Addr::LOCALHOST.into()]
        };

        // Only IPv4 is supported: skip other families instead of panicking on
        // user-controlled input (an IPv6 literal or AAAA record).
        let v4_addrs: Vec<Ipv4Addr> = ip_addrs
            .into_iter()
            .filter_map(|ip| match ip {
                IpAddr::V4(v4) => Some(v4),
                IpAddr::V6(_) => None,
            })
            .take(ctypes::MAXADDRS as usize)
            .collect();
        let len = v4_addrs.len();
        if len == 0 {
            return Ok(0);
        }

        let entries: Vec<ctypes::aibuf> = v4_addrs
            .iter()
            .enumerate()
            .map(|(i, &ip)| ctypes::aibuf {
                ai: ctypes::addrinfo {
                    ai_family: ctypes::AF_INET as _,
                    // TODO: This is a hard-code part, only return TCP parameters
                    ai_socktype: ctypes::SOCK_STREAM as _,
                    ai_protocol: ctypes::IPPROTO_TCP as _,
                    ai_addrlen: size_of::<ctypes::sockaddr_in>() as _,
                    ai_addr: core::ptr::null_mut(),
                    ai_canonname: core::ptr::null_mut(),
                    ai_next: core::ptr::null_mut(),
                    ai_flags: 0,
                },
                sa: ctypes::aibuf_sa {
                    sin: SocketAddrV4::new(ip, port).into(),
                },
                slot: i as i16,
                lock: [0],
                ref_: 0,
            })
            .collect();
        // A boxed slice owns an allocation of exactly `len` elements, matching
        // the `(len, len)` reconstruction in `sys_freeaddrinfo`. It is leaked
        // *before* any internal pointers are taken, so it never moves afterwards.
        let out: &mut [ctypes::aibuf] = alloc::boxed::Box::leak(entries.into_boxed_slice());

        // Point each entry at its own embedded address and chain it to the next.
        for i in 0..len {
            // SAFETY: `sin` is the union field initialized above; only its
            // address is taken.
            out[i].ai.ai_addr =
                unsafe { core::ptr::addr_of_mut!(out[i].sa.sin) as *mut ctypes::sockaddr };
            if i + 1 < len {
                out[i].ai.ai_next = core::ptr::addr_of_mut!(out[i + 1].ai);
            }
        }

        // The head entry records how many entries `sys_freeaddrinfo` must free.
        out[0].ref_ = len as i16;
        // SAFETY: `res` is non-null (checked above).
        unsafe { *res = core::ptr::addr_of_mut!(out[0].ai) };
        Ok(len)
    })
}
521
522/// Free queried `addrinfo` struct
523pub unsafe fn sys_freeaddrinfo(res: *mut ctypes::addrinfo) {
524    if res.is_null() {
525        return;
526    }
527    let aibuf_ptr = res as *mut ctypes::aibuf;
528    let len = unsafe { *aibuf_ptr }.ref_ as usize;
529    assert!(unsafe { *aibuf_ptr }.slot == 0);
530    assert!(len > 0);
531    let vec = unsafe { Vec::from_raw_parts(aibuf_ptr, len, len) }; // TODO: lock
532    drop(vec);
533}
534
535/// Get current address to which the socket sockfd is bound.
536pub unsafe fn sys_getsockname(
537    sock_fd: c_int,
538    addr: *mut ctypes::sockaddr,
539    addrlen: *mut ctypes::socklen_t,
540) -> c_int {
541    debug!(
542        "sys_getsockname <= {} {:#x} {:#x}",
543        sock_fd, addr as usize, addrlen as usize
544    );
545    syscall_body!(sys_getsockname, {
546        if addr.is_null() || addrlen.is_null() {
547            return Err(LinuxError::EFAULT);
548        }
549        if unsafe { *addrlen } < size_of::<ctypes::sockaddr>() as u32 {
550            return Err(LinuxError::EINVAL);
551        }
552        unsafe {
553            (*addr, *addrlen) = into_sockaddr(Socket::from_fd(sock_fd)?.local_addr()?);
554        }
555        Ok(0)
556    })
557}
558
559/// Get peer address to which the socket sockfd is connected.
560pub unsafe fn sys_getpeername(
561    sock_fd: c_int,
562    addr: *mut ctypes::sockaddr,
563    addrlen: *mut ctypes::socklen_t,
564) -> c_int {
565    debug!(
566        "sys_getpeername <= {} {:#x} {:#x}",
567        sock_fd, addr as usize, addrlen as usize
568    );
569    syscall_body!(sys_getpeername, {
570        if addr.is_null() || addrlen.is_null() {
571            return Err(LinuxError::EFAULT);
572        }
573        if unsafe { *addrlen } < size_of::<ctypes::sockaddr>() as u32 {
574            return Err(LinuxError::EINVAL);
575        }
576        unsafe {
577            (*addr, *addrlen) = into_sockaddr(Socket::from_fd(sock_fd)?.peer_addr()?);
578        }
579        Ok(0)
580    })
581}