axstd/thread/multi.rs
1//! Thread APIs for multi-threading configuration.
2
3extern crate alloc;
4
5use crate::io;
6use alloc::{string::String, sync::Arc};
7use core::{cell::UnsafeCell, num::NonZeroU64};
8
9use arceos_api::task::{self as api, AxTaskHandle};
10use axerrno::ax_err_type;
11
/// A unique identifier for a running thread.
///
/// The identifier wraps a [`NonZeroU64`]. It is cheap to copy, can be compared
/// for equality and hashed, so it is usable as a key in hash-based
/// collections.
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub struct ThreadId(NonZeroU64);
15
16/// A handle to a thread.
17pub struct Thread {
18 id: ThreadId,
19}
20
21impl ThreadId {
22 /// This returns a numeric identifier for the thread identified by this
23 /// `ThreadId`.
24 pub fn as_u64(&self) -> NonZeroU64 {
25 self.0
26 }
27}
28
29impl Thread {
30 fn from_id(id: u64) -> Self {
31 Self {
32 id: ThreadId(NonZeroU64::new(id).unwrap()),
33 }
34 }
35
36 /// Gets the thread's unique identifier.
37 pub fn id(&self) -> ThreadId {
38 self.id
39 }
40}
41
/// Thread factory, which can be used in order to configure the properties of
/// a new thread.
///
/// Methods can be chained on it in order to configure it. The [`Default`]
/// value has nothing configured (no name, no explicit stack size).
#[derive(Debug, Default)]
pub struct Builder {
    // Name for the thread-to-be; when unset, an empty string is passed to the
    // task spawner.
    name: Option<String>,
    // Stack size in bytes for the spawned thread; when unset,
    // `arceos_api::config::TASK_STACK_SIZE` is used.
    stack_size: Option<usize>,
}
53
54impl Builder {
55 /// Generates the base configuration for spawning a thread, from which
56 /// configuration methods can be chained.
57 pub const fn new() -> Builder {
58 Builder {
59 name: None,
60 stack_size: None,
61 }
62 }
63
64 /// Names the thread-to-be.
65 pub fn name(mut self, name: String) -> Builder {
66 self.name = Some(name);
67 self
68 }
69
70 /// Sets the size of the stack (in bytes) for the new thread.
71 pub fn stack_size(mut self, size: usize) -> Builder {
72 self.stack_size = Some(size);
73 self
74 }
75
76 /// Spawns a new thread by taking ownership of the `Builder`, and returns an
77 /// [`io::Result`] to its [`JoinHandle`].
78 ///
79 /// The spawned thread may outlive the caller (unless the caller thread
80 /// is the main thread; the whole process is terminated when the main
81 /// thread finishes). The join handle can be used to block on
82 /// termination of the spawned thread.
83 pub fn spawn<F, T>(self, f: F) -> io::Result<JoinHandle<T>>
84 where
85 F: FnOnce() -> T,
86 F: Send + 'static,
87 T: Send + 'static,
88 {
89 unsafe { self.spawn_unchecked(f) }
90 }
91
92 unsafe fn spawn_unchecked<F, T>(self, f: F) -> io::Result<JoinHandle<T>>
93 where
94 F: FnOnce() -> T,
95 F: Send + 'static,
96 T: Send + 'static,
97 {
98 let name = self.name.unwrap_or_default();
99 let stack_size = self
100 .stack_size
101 .unwrap_or(arceos_api::config::TASK_STACK_SIZE);
102
103 let my_packet = Arc::new(Packet {
104 result: UnsafeCell::new(None),
105 });
106 let their_packet = my_packet.clone();
107
108 let main = move || {
109 let ret = f();
110 // SAFETY: `their_packet` as been built just above and moved by the
111 // closure (it is an Arc<...>) and `my_packet` will be stored in the
112 // same `JoinHandle` as this closure meaning the mutation will be
113 // safe (not modify it and affect a value far away).
114 unsafe { *their_packet.result.get() = Some(ret) };
115 drop(their_packet);
116 };
117
118 let task = api::ax_spawn(main, name, stack_size);
119 Ok(JoinHandle {
120 thread: Thread::from_id(task.id()),
121 native: task,
122 packet: my_packet,
123 })
124 }
125}
126
127/// Gets a handle to the thread that invokes it.
128pub fn current() -> Thread {
129 let id = api::ax_current_task_id();
130 Thread::from_id(id)
131}
132
133/// Spawns a new thread, returning a [`JoinHandle`] for it.
134///
135/// The join handle provides a [`join`] method that can be used to join the
136/// spawned thread.
137///
138/// The default task name is an empty string. The default thread stack size is
139/// [`arceos_api::config::TASK_STACK_SIZE`].
140///
141/// [`join`]: JoinHandle::join
142pub fn spawn<T, F>(f: F) -> JoinHandle<T>
143where
144 F: FnOnce() -> T + Send + 'static,
145 T: Send + 'static,
146{
147 Builder::new().spawn(f).expect("failed to spawn thread")
148}
149
/// Result slot shared between a spawned thread and its join handle.
///
/// The spawned closure stores its return value here; the joiner takes it out
/// after the thread has exited.
struct Packet<T> {
    // `None` until the spawned closure has stored its return value.
    result: UnsafeCell<Option<T>>,
}

// SAFETY: `UnsafeCell` makes `Packet` `!Sync` by default, so `Sync` has to be
// implemented by hand. Sharing a `Packet` across threads means the `T` is
// written on one thread and taken out on another, i.e. the value is *moved*
// between threads, which is only sound when `T: Send`. Access to the cell is
// ordered by the join boundary: the joiner only touches it after waiting for
// the thread to exit.
unsafe impl<T: Send> Sync for Packet<T> {}
155
/// An owned permission to join on a thread (block on its termination).
///
/// A `JoinHandle` *detaches* the associated thread when it is dropped, which
/// means that there is no longer any handle to the thread and no way to `join`
/// on it.
pub struct JoinHandle<T> {
    // Task handle returned by `ax_spawn`; `join` waits on it.
    native: AxTaskHandle,
    // Thread handle exposed through `thread()`.
    thread: Thread,
    // Result slot shared with the spawned closure, which stores its return
    // value there; `join` takes the value out.
    packet: Arc<Packet<T>>,
}

// SAFETY: within this file a `JoinHandle<T>` is only created by
// `Builder::spawn_unchecked`, which requires `T: Send`, and the only way to
// reach the `T` is `join`, which takes the handle by value. A shared
// reference therefore exposes nothing but `thread()`.
// NOTE(review): this also relies on `AxTaskHandle` being safe to send and
// share between threads — confirm against `arceos_api`.
unsafe impl<T> Send for JoinHandle<T> {}
unsafe impl<T> Sync for JoinHandle<T> {}
169
170impl<T> JoinHandle<T> {
171 /// Extracts a handle to the underlying thread.
172 pub fn thread(&self) -> &Thread {
173 &self.thread
174 }
175
176 /// Waits for the associated thread to finish.
177 ///
178 /// This function will return immediately if the associated thread has
179 /// already finished.
180 pub fn join(mut self) -> io::Result<T> {
181 api::ax_wait_for_exit(self.native).ok_or_else(|| ax_err_type!(BadState))?;
182 Arc::get_mut(&mut self.packet)
183 .unwrap()
184 .result
185 .get_mut()
186 .take()
187 .ok_or_else(|| ax_err_type!(BadState))
188 }
189}