// arceos_posix_api/imp/io_mpx/select.rs

1use core::ffi::c_int;
2
3use axerrno::{LinuxError, LinuxResult};
4use axhal::time::wall_time;
5
6use crate::{ctypes, imp::fd_ops::get_file_like};
7
/// Largest number of descriptors a single `fd_set` bitmap is treated as covering.
const FD_SETSIZE: usize = 1024;
/// Width, in bits, of one `usize` word of a bitmap.
const BITS_PER_USIZE: usize = core::mem::size_of::<usize>() * 8;
/// How many `usize` words are needed to hold `FD_SETSIZE` bits.
const FD_SETSIZE_USIZES: usize = FD_SETSIZE.div_ceil(BITS_PER_USIZE);
11
12struct FdSets {
13    nfds: usize,
14    bits: [usize; FD_SETSIZE_USIZES * 3],
15}
16
17impl FdSets {
18    fn from(
19        nfds: usize,
20        read_fds: *const ctypes::fd_set,
21        write_fds: *const ctypes::fd_set,
22        except_fds: *const ctypes::fd_set,
23    ) -> Self {
24        let nfds = nfds.min(FD_SETSIZE);
25        let nfds_usizes = nfds.div_ceil(BITS_PER_USIZE);
26        let mut bits = core::mem::MaybeUninit::<[usize; FD_SETSIZE_USIZES * 3]>::uninit();
27        let bits_ptr: *mut usize = unsafe { core::mem::transmute(bits.as_mut_ptr()) };
28
29        let copy_from_fd_set = |bits_ptr: *mut usize, fds: *const ctypes::fd_set| unsafe {
30            let dst = core::slice::from_raw_parts_mut(bits_ptr, nfds_usizes);
31            if fds.is_null() {
32                dst.fill(0);
33            } else {
34                let fds_ptr = (*fds).fds_bits.as_ptr() as *const usize;
35                let src = core::slice::from_raw_parts(fds_ptr, nfds_usizes);
36                dst.copy_from_slice(src);
37            }
38        };
39
40        let bits = unsafe {
41            copy_from_fd_set(bits_ptr, read_fds);
42            copy_from_fd_set(bits_ptr.add(FD_SETSIZE_USIZES), write_fds);
43            copy_from_fd_set(bits_ptr.add(FD_SETSIZE_USIZES * 2), except_fds);
44            bits.assume_init()
45        };
46        Self { nfds, bits }
47    }
48
49    fn poll_all(
50        &self,
51        res_read_fds: *mut ctypes::fd_set,
52        res_write_fds: *mut ctypes::fd_set,
53        res_except_fds: *mut ctypes::fd_set,
54    ) -> LinuxResult<usize> {
55        let mut read_bits_ptr = self.bits.as_ptr();
56        let mut write_bits_ptr = unsafe { read_bits_ptr.add(FD_SETSIZE_USIZES) };
57        let mut execpt_bits_ptr = unsafe { read_bits_ptr.add(FD_SETSIZE_USIZES * 2) };
58        let mut i = 0;
59        let mut res_num = 0;
60        while i < self.nfds {
61            let read_bits = unsafe { *read_bits_ptr };
62            let write_bits = unsafe { *write_bits_ptr };
63            let except_bits = unsafe { *execpt_bits_ptr };
64            unsafe {
65                read_bits_ptr = read_bits_ptr.add(1);
66                write_bits_ptr = write_bits_ptr.add(1);
67                execpt_bits_ptr = execpt_bits_ptr.add(1);
68            }
69
70            let all_bits = read_bits | write_bits | except_bits;
71            if all_bits == 0 {
72                i += BITS_PER_USIZE;
73                continue;
74            }
75            let mut j = 0;
76            while j < BITS_PER_USIZE && i + j < self.nfds {
77                let bit = 1 << j;
78                if all_bits & bit == 0 {
79                    j += 1;
80                    continue;
81                }
82                let fd = i + j;
83                match get_file_like(fd as _)?.poll() {
84                    Ok(state) => {
85                        if state.readable && read_bits & bit != 0 {
86                            unsafe { set_fd_set(res_read_fds, fd) };
87                            res_num += 1;
88                        }
89                        if state.writable && write_bits & bit != 0 {
90                            unsafe { set_fd_set(res_write_fds, fd) };
91                            res_num += 1;
92                        }
93                    }
94                    Err(e) => {
95                        debug!("    except: {} {:?}", fd, e);
96                        if except_bits & bit != 0 {
97                            unsafe { set_fd_set(res_except_fds, fd) };
98                            res_num += 1;
99                        }
100                    }
101                }
102                j += 1;
103            }
104            i += BITS_PER_USIZE;
105        }
106        Ok(res_num)
107    }
108}
109
110/// Monitor multiple file descriptors, waiting until one or more of the file descriptors become "ready" for some class of I/O operation
111pub unsafe fn sys_select(
112    nfds: c_int,
113    readfds: *mut ctypes::fd_set,
114    writefds: *mut ctypes::fd_set,
115    exceptfds: *mut ctypes::fd_set,
116    timeout: *mut ctypes::timeval,
117) -> c_int {
118    debug!(
119        "sys_select <= {} {:#x} {:#x} {:#x}",
120        nfds, readfds as usize, writefds as usize, exceptfds as usize
121    );
122    syscall_body!(sys_select, {
123        if nfds < 0 {
124            return Err(LinuxError::EINVAL);
125        }
126        let nfds = (nfds as usize).min(FD_SETSIZE);
127        let deadline = unsafe { timeout.as_ref().map(|t| wall_time() + (*t).into()) };
128        let fd_sets = FdSets::from(nfds, readfds, writefds, exceptfds);
129
130        unsafe {
131            zero_fd_set(readfds, nfds);
132            zero_fd_set(writefds, nfds);
133            zero_fd_set(exceptfds, nfds);
134        }
135
136        loop {
137            #[cfg(feature = "net")]
138            axnet::poll_interfaces();
139            let res = fd_sets.poll_all(readfds, writefds, exceptfds)?;
140            if res > 0 {
141                return Ok(res);
142            }
143
144            if deadline.is_some_and(|ddl| wall_time() >= ddl) {
145                debug!("    timeout!");
146                return Ok(0);
147            }
148            crate::sys_sched_yield();
149        }
150    })
151}
152
153unsafe fn zero_fd_set(fds: *mut ctypes::fd_set, nfds: usize) {
154    if !fds.is_null() {
155        let nfds_usizes = nfds.div_ceil(BITS_PER_USIZE);
156        let dst = &mut unsafe { *fds }.fds_bits[..nfds_usizes];
157        dst.fill(0);
158    }
159}
160
161unsafe fn set_fd_set(fds: *mut ctypes::fd_set, fd: usize) {
162    if !fds.is_null() {
163        unsafe { *fds }.fds_bits[fd / BITS_PER_USIZE] |= 1 << (fd % BITS_PER_USIZE);
164    }
165}