axstd/thread/multi.rs

1//! Thread APIs for multi-threading configuration.
2
3extern crate alloc;
4
5use crate::io;
6use alloc::{string::String, sync::Arc};
7use core::{cell::UnsafeCell, num::NonZeroU64};
8
9use arceos_api::task::{self as api, AxTaskHandle};
10use axerrno::ax_err_type;
11
/// A unique identifier for a running thread.
///
/// Wraps the numeric task ID reported by the task API. The ID is stored as a
/// `NonZeroU64`, so `Option<ThreadId>` takes no extra space. `ThreadId` is
/// hashable, so it can be used as a key in maps and sets.
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub struct ThreadId(NonZeroU64);
15
16/// A handle to a thread.
17pub struct Thread {
18    id: ThreadId,
19}
20
21impl ThreadId {
22    /// This returns a numeric identifier for the thread identified by this
23    /// `ThreadId`.
24    pub fn as_u64(&self) -> NonZeroU64 {
25        self.0
26    }
27}
28
29impl Thread {
30    fn from_id(id: u64) -> Self {
31        Self {
32            id: ThreadId(NonZeroU64::new(id).unwrap()),
33        }
34    }
35
36    /// Gets the thread's unique identifier.
37    pub fn id(&self) -> ThreadId {
38        self.id
39    }
40}
41
/// Thread factory, which can be used in order to configure the properties of
/// a new thread.
///
/// Methods can be chained on it in order to configure it. `Builder::default()`
/// produces the same configuration as [`Builder::new`]: no name and no
/// explicit stack size.
#[derive(Debug, Default)]
pub struct Builder {
    // A name for the thread-to-be; passed to the task API when spawning.
    name: Option<String>,
    // The size of the stack for the spawned thread in bytes; `None` selects
    // the platform default at spawn time.
    stack_size: Option<usize>,
}
53
54impl Builder {
55    /// Generates the base configuration for spawning a thread, from which
56    /// configuration methods can be chained.
57    pub const fn new() -> Builder {
58        Builder {
59            name: None,
60            stack_size: None,
61        }
62    }
63
64    /// Names the thread-to-be.
65    pub fn name(mut self, name: String) -> Builder {
66        self.name = Some(name);
67        self
68    }
69
70    /// Sets the size of the stack (in bytes) for the new thread.
71    pub fn stack_size(mut self, size: usize) -> Builder {
72        self.stack_size = Some(size);
73        self
74    }
75
76    /// Spawns a new thread by taking ownership of the `Builder`, and returns an
77    /// [`io::Result`] to its [`JoinHandle`].
78    ///
79    /// The spawned thread may outlive the caller (unless the caller thread
80    /// is the main thread; the whole process is terminated when the main
81    /// thread finishes). The join handle can be used to block on
82    /// termination of the spawned thread.
83    pub fn spawn<F, T>(self, f: F) -> io::Result<JoinHandle<T>>
84    where
85        F: FnOnce() -> T,
86        F: Send + 'static,
87        T: Send + 'static,
88    {
89        unsafe { self.spawn_unchecked(f) }
90    }
91
92    unsafe fn spawn_unchecked<F, T>(self, f: F) -> io::Result<JoinHandle<T>>
93    where
94        F: FnOnce() -> T,
95        F: Send + 'static,
96        T: Send + 'static,
97    {
98        let name = self.name.unwrap_or_default();
99        let stack_size = self
100            .stack_size
101            .unwrap_or(arceos_api::config::TASK_STACK_SIZE);
102
103        let my_packet = Arc::new(Packet {
104            result: UnsafeCell::new(None),
105        });
106        let their_packet = my_packet.clone();
107
108        let main = move || {
109            let ret = f();
110            // SAFETY: `their_packet` as been built just above and moved by the
111            // closure (it is an Arc<...>) and `my_packet` will be stored in the
112            // same `JoinHandle` as this closure meaning the mutation will be
113            // safe (not modify it and affect a value far away).
114            unsafe { *their_packet.result.get() = Some(ret) };
115            drop(their_packet);
116        };
117
118        let task = api::ax_spawn(main, name, stack_size);
119        Ok(JoinHandle {
120            thread: Thread::from_id(task.id()),
121            native: task,
122            packet: my_packet,
123        })
124    }
125}
126
127/// Gets a handle to the thread that invokes it.
128pub fn current() -> Thread {
129    let id = api::ax_current_task_id();
130    Thread::from_id(id)
131}
132
133/// Spawns a new thread, returning a [`JoinHandle`] for it.
134///
135/// The join handle provides a [`join`] method that can be used to join the
136/// spawned thread.
137///
138/// The default task name is an empty string. The default thread stack size is
139/// [`arceos_api::config::TASK_STACK_SIZE`].
140///
141/// [`join`]: JoinHandle::join
142pub fn spawn<T, F>(f: F) -> JoinHandle<T>
143where
144    F: FnOnce() -> T + Send + 'static,
145    T: Send + 'static,
146{
147    Builder::new().spawn(f).expect("failed to spawn thread")
148}
149
/// Shared slot through which a spawned thread hands its return value back to
/// the [`JoinHandle`] that joins it.
struct Packet<T> {
    // Filled in by the spawned thread's closure and taken out by `join`.
    result: UnsafeCell<Option<T>>,
}

// SAFETY: `UnsafeCell` opts out of `Sync`, so `Sync` has to be restored by
// hand. The cell is written only by the spawned thread. It is read only by
// `JoinHandle::join` after that function has obtained exclusive access via
// `Arc::get_mut`, so the contents are never touched by two threads at once.
// The value does move from the spawned thread to the joining one, which is
// exactly what `T: Send` allows. The bound makes that requirement explicit
// instead of asserting `Sync` for every `T`.
unsafe impl<T: Send> Sync for Packet<T> {}
155
156/// An owned permission to join on a thread (block on its termination).
157///
158/// A `JoinHandle` *detaches* the associated thread when it is dropped, which
159/// means that there is no longer any handle to the thread and no way to `join`
160/// on it.
161pub struct JoinHandle<T> {
162    native: AxTaskHandle,
163    thread: Thread,
164    packet: Arc<Packet<T>>,
165}
166
167unsafe impl<T> Send for JoinHandle<T> {}
168unsafe impl<T> Sync for JoinHandle<T> {}
169
170impl<T> JoinHandle<T> {
171    /// Extracts a handle to the underlying thread.
172    pub fn thread(&self) -> &Thread {
173        &self.thread
174    }
175
176    /// Waits for the associated thread to finish.
177    ///
178    /// This function will return immediately if the associated thread has
179    /// already finished.
180    pub fn join(mut self) -> io::Result<T> {
181        api::ax_wait_for_exit(self.native).ok_or_else(|| ax_err_type!(BadState))?;
182        Arc::get_mut(&mut self.packet)
183            .unwrap()
184            .result
185            .get_mut()
186            .take()
187            .ok_or_else(|| ax_err_type!(BadState))
188    }
189}