arceos_posix_api/imp/pthread/
mod.rs

1use alloc::{boxed::Box, collections::BTreeMap, sync::Arc};
2use core::cell::UnsafeCell;
3use core::ffi::{c_int, c_void};
4
5use axerrno::{LinuxError, LinuxResult};
6use axtask::AxTaskRef;
7use spin::RwLock;
8
9use crate::ctypes;
10
11pub mod mutex;
12
13lazy_static::lazy_static! {
14    static ref TID_TO_PTHREAD: RwLock<BTreeMap<u64, ForceSendSync<ctypes::pthread_t>>> = {
15        let mut map = BTreeMap::new();
16        let main_task = axtask::current();
17        let main_tid = main_task.id().as_u64();
18        let main_thread = Pthread {
19            inner: main_task.as_task_ref().clone(),
20            retval: Arc::new(Packet {
21                result: UnsafeCell::new(core::ptr::null_mut()),
22            }),
23        };
24        let ptr = Box::into_raw(Box::new(main_thread)) as *mut c_void;
25        map.insert(main_tid, ForceSendSync(ptr));
26        RwLock::new(map)
27    };
28}
29
/// Slot through which a thread hands its result to whoever joins it.
///
/// The value sits in an [`UnsafeCell`], so it is read and written through a
/// raw pointer without any locking of its own.
struct Packet<T> {
    /// The stored result; accessed via `UnsafeCell::get`.
    result: UnsafeCell<T>,
}

// SAFETY (review note): both impls are unconditional — there is no bound on
// `T` and the cell performs no synchronisation. Soundness therefore rests
// entirely on the users of `Packet` never accessing `result` concurrently.
unsafe impl<T> Sync for Packet<T> {}
unsafe impl<T> Send for Packet<T> {}
36
37pub struct Pthread {
38    inner: AxTaskRef,
39    retval: Arc<Packet<*mut c_void>>,
40}
41
42impl Pthread {
43    fn create(
44        _attr: *const ctypes::pthread_attr_t,
45        start_routine: extern "C" fn(arg: *mut c_void) -> *mut c_void,
46        arg: *mut c_void,
47    ) -> LinuxResult<ctypes::pthread_t> {
48        let arg_wrapper = ForceSendSync(arg);
49
50        let my_packet: Arc<Packet<*mut c_void>> = Arc::new(Packet {
51            result: UnsafeCell::new(core::ptr::null_mut()),
52        });
53        let their_packet = my_packet.clone();
54
55        let main = move || {
56            let arg = arg_wrapper;
57            let ret = start_routine(arg.0);
58            unsafe { *their_packet.result.get() = ret };
59            drop(their_packet);
60        };
61
62        let task_inner = axtask::spawn(main);
63        let tid = task_inner.id().as_u64();
64        let thread = Pthread {
65            inner: task_inner,
66            retval: my_packet,
67        };
68        let ptr = Box::into_raw(Box::new(thread)) as *mut c_void;
69        TID_TO_PTHREAD.write().insert(tid, ForceSendSync(ptr));
70        Ok(ptr)
71    }
72
73    fn current_ptr() -> *mut Pthread {
74        let tid = axtask::current().id().as_u64();
75        match TID_TO_PTHREAD.read().get(&tid) {
76            None => core::ptr::null_mut(),
77            Some(ptr) => ptr.0 as *mut Pthread,
78        }
79    }
80
81    fn current() -> Option<&'static Pthread> {
82        unsafe { core::ptr::NonNull::new(Self::current_ptr()).map(|ptr| ptr.as_ref()) }
83    }
84
85    fn exit_current(retval: *mut c_void) -> ! {
86        let thread = Self::current().expect("fail to get current thread");
87        unsafe { *thread.retval.result.get() = retval };
88        axtask::exit(0);
89    }
90
91    fn join(ptr: ctypes::pthread_t) -> LinuxResult<*mut c_void> {
92        if core::ptr::eq(ptr, Self::current_ptr() as _) {
93            return Err(LinuxError::EDEADLK);
94        }
95
96        let thread = unsafe { Box::from_raw(ptr as *mut Pthread) };
97        thread.inner.join();
98        let tid = thread.inner.id().as_u64();
99        let retval = unsafe { *thread.retval.result.get() };
100        TID_TO_PTHREAD.write().remove(&tid);
101        drop(thread);
102        Ok(retval)
103    }
104}
105
106/// Returns the `pthread` struct of current thread.
107pub fn sys_pthread_self() -> ctypes::pthread_t {
108    Pthread::current().expect("fail to get current thread") as *const Pthread as _
109}
110
111/// Create a new thread with the given entry point and argument.
112///
113/// If successful, it stores the pointer to the newly created `struct __pthread`
114/// in `res` and returns 0.
115pub unsafe fn sys_pthread_create(
116    res: *mut ctypes::pthread_t,
117    attr: *const ctypes::pthread_attr_t,
118    start_routine: extern "C" fn(arg: *mut c_void) -> *mut c_void,
119    arg: *mut c_void,
120) -> c_int {
121    debug!(
122        "sys_pthread_create <= {:#x}, {:#x}",
123        start_routine as usize, arg as usize
124    );
125    syscall_body!(sys_pthread_create, {
126        let ptr = Pthread::create(attr, start_routine, arg)?;
127        unsafe { core::ptr::write(res, ptr) };
128        Ok(0)
129    })
130}
131
132/// Exits the current thread. The value `retval` will be returned to the joiner.
133pub fn sys_pthread_exit(retval: *mut c_void) -> ! {
134    debug!("sys_pthread_exit <= {:#x}", retval as usize);
135    Pthread::exit_current(retval);
136}
137
138/// Waits for the given thread to exit, and stores the return value in `retval`.
139pub unsafe fn sys_pthread_join(thread: ctypes::pthread_t, retval: *mut *mut c_void) -> c_int {
140    debug!("sys_pthread_join <= {:#x}", retval as usize);
141    syscall_body!(sys_pthread_join, {
142        let ret = Pthread::join(thread)?;
143        if !retval.is_null() {
144            unsafe { core::ptr::write(retval, ret) };
145        }
146        Ok(0)
147    })
148}
149
/// Transparent wrapper that marks any value as `Send` and `Sync`.
///
/// Used in this module to move raw pointers into closures and to store them
/// in the global registry.
#[derive(Copy, Clone)]
struct ForceSendSync<T>(T);

// SAFETY (review note): both impls are unconditional and bypass the
// compiler's thread-safety checks for the wrapped value; every user of the
// wrapper is responsible for making cross-thread access to it sound.
unsafe impl<T> Sync for ForceSendSync<T> {}
unsafe impl<T> Send for ForceSendSync<T> {}