axsync/
mutex.rs

1//! A naïve sleeping mutex.
2
3use core::sync::atomic::{AtomicU64, Ordering};
4
5use axtask::{WaitQueue, current};
6
7/// A [`lock_api::RawMutex`] implementation.
8///
9/// When the mutex is locked, the current task will block and be put into the
10/// wait queue. When the mutex is unlocked, all tasks waiting on the queue
11/// will be woken up.
12pub struct RawMutex {
13    wq: WaitQueue,
14    owner_id: AtomicU64,
15}
16
17impl RawMutex {
18    /// Creates a [`RawMutex`].
19    #[inline(always)]
20    pub const fn new() -> Self {
21        Self {
22            wq: WaitQueue::new(),
23            owner_id: AtomicU64::new(0),
24        }
25    }
26}
27
28unsafe impl lock_api::RawMutex for RawMutex {
29    /// Initial value for an unlocked mutex.
30    ///
31    /// A “non-constant” const item is a legacy way to supply an initialized value to downstream
32    /// static items. Can hopefully be replaced with `const fn new() -> Self` at some point.
33    #[allow(clippy::declare_interior_mutable_const)]
34    const INIT: Self = RawMutex::new();
35
36    type GuardMarker = lock_api::GuardSend;
37
38    #[inline(always)]
39    fn lock(&self) {
40        let current_id = current().id().as_u64();
41        loop {
42            // Can fail to lock even if the spinlock is not locked. May be more efficient than `try_lock`
43            // when called in a loop.
44            match self.owner_id.compare_exchange_weak(
45                0,
46                current_id,
47                Ordering::Acquire,
48                Ordering::Relaxed,
49            ) {
50                Ok(_) => break,
51                Err(owner_id) => {
52                    assert_ne!(
53                        owner_id,
54                        current_id,
55                        "{} tried to acquire mutex it already owns.",
56                        current().id_name()
57                    );
58                    // Wait until the lock looks unlocked before retrying
59                    self.wq.wait_until(|| !self.is_locked());
60                }
61            }
62        }
63    }
64
65    #[inline(always)]
66    fn try_lock(&self) -> bool {
67        let current_id = current().id().as_u64();
68        // The reason for using a strong compare_exchange is explained here:
69        // https://github.com/Amanieu/parking_lot/pull/207#issuecomment-575869107
70        self.owner_id
71            .compare_exchange(0, current_id, Ordering::Acquire, Ordering::Relaxed)
72            .is_ok()
73    }
74
75    #[inline(always)]
76    unsafe fn unlock(&self) {
77        let owner_id = self.owner_id.swap(0, Ordering::Release);
78        assert_eq!(
79            owner_id,
80            current().id().as_u64(),
81            "{} tried to release mutex it doesn't own",
82            current().id_name()
83        );
84        self.wq.notify_one(true);
85    }
86
87    #[inline(always)]
88    fn is_locked(&self) -> bool {
89        self.owner_id.load(Ordering::Relaxed) != 0
90    }
91}
92
93/// An alias of [`lock_api::Mutex`].
94pub type Mutex<T> = lock_api::Mutex<RawMutex, T>;
95/// An alias of [`lock_api::MutexGuard`].
96pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutex, T>;
97
#[cfg(test)]
mod tests {
    use crate::Mutex;
    use axtask as thread;
    use std::sync::Once;

    /// Ensures the scheduler is initialised only once across tests.
    static INIT: Once = Once::new();

    /// Yields the current task roughly one time in three, standing in for an interrupt.
    fn may_interrupt() {
        let roll = rand::random::<u32>();
        if roll % 3 != 0 {
            return;
        }
        thread::yield_now();
    }

    #[test]
    fn lots_and_lots() {
        INIT.call_once(thread::init_scheduler);

        const NUM_TASKS: u32 = 10;
        const NUM_ITERS: u32 = 10_000;
        // Each of the NUM_TASKS worker pairs adds 1 + 2 per iteration.
        const EXPECTED: u32 = NUM_ITERS * NUM_TASKS * 3;
        static M: Mutex<u32> = Mutex::new(0);

        fn bump_many(step: u32) {
            for _ in 0..NUM_ITERS {
                let mut counter = M.lock();
                *counter += step;
                may_interrupt();
                drop(counter);
                may_interrupt();
            }
        }

        (0..NUM_TASKS).for_each(|_| {
            thread::spawn(|| bump_many(1));
            thread::spawn(|| bump_many(2));
        });

        println!("spawn OK");
        // Poll until every worker has finished, yielding both with and without the lock held.
        loop {
            let counter = M.lock();
            if *counter == EXPECTED {
                break;
            }
            may_interrupt();
            drop(counter);
            may_interrupt();
        }

        assert_eq!(*M.lock(), EXPECTED);
        println!("Mutex test OK");
    }
}