axsync/
mutex.rs

1//! A naïve sleeping mutex.
2
3use core::sync::atomic::{AtomicU64, Ordering};
4
5use axtask::{WaitQueue, current};
6
7/// A [`lock_api::RawMutex`] implementation.
8///
9/// When the mutex is locked, the current task will block and be put into the
10/// wait queue. When the mutex is unlocked, all tasks waiting on the queue
11/// will be woken up.
12pub struct RawMutex {
13    wq: WaitQueue,
14    owner_id: AtomicU64,
15}
16
17impl RawMutex {
18    /// Creates a [`RawMutex`].
19    #[inline(always)]
20    pub const fn new() -> Self {
21        Self {
22            wq: WaitQueue::new(),
23            owner_id: AtomicU64::new(0),
24        }
25    }
26}
27
28unsafe impl lock_api::RawMutex for RawMutex {
29    const INIT: Self = RawMutex::new();
30
31    type GuardMarker = lock_api::GuardSend;
32
33    #[inline(always)]
34    fn lock(&self) {
35        let current_id = current().id().as_u64();
36        loop {
37            // Can fail to lock even if the spinlock is not locked. May be more efficient than `try_lock`
38            // when called in a loop.
39            match self.owner_id.compare_exchange_weak(
40                0,
41                current_id,
42                Ordering::Acquire,
43                Ordering::Relaxed,
44            ) {
45                Ok(_) => break,
46                Err(owner_id) => {
47                    assert_ne!(
48                        owner_id,
49                        current_id,
50                        "{} tried to acquire mutex it already owns.",
51                        current().id_name()
52                    );
53                    // Wait until the lock looks unlocked before retrying
54                    self.wq.wait_until(|| !self.is_locked());
55                }
56            }
57        }
58    }
59
60    #[inline(always)]
61    fn try_lock(&self) -> bool {
62        let current_id = current().id().as_u64();
63        // The reason for using a strong compare_exchange is explained here:
64        // https://github.com/Amanieu/parking_lot/pull/207#issuecomment-575869107
65        self.owner_id
66            .compare_exchange(0, current_id, Ordering::Acquire, Ordering::Relaxed)
67            .is_ok()
68    }
69
70    #[inline(always)]
71    unsafe fn unlock(&self) {
72        let owner_id = self.owner_id.swap(0, Ordering::Release);
73        assert_eq!(
74            owner_id,
75            current().id().as_u64(),
76            "{} tried to release mutex it doesn't own",
77            current().id_name()
78        );
79        self.wq.notify_one(true);
80    }
81
82    #[inline(always)]
83    fn is_locked(&self) -> bool {
84        self.owner_id.load(Ordering::Relaxed) != 0
85    }
86}
87
88/// An alias of [`lock_api::Mutex`].
89pub type Mutex<T> = lock_api::Mutex<RawMutex, T>;
90/// An alias of [`lock_api::MutexGuard`].
91pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutex, T>;
92
#[cfg(test)]
mod tests {
    use crate::Mutex;
    use axtask as thread;
    use std::sync::Once;

    /// Guards scheduler initialization so it runs exactly once per process.
    static INIT: Once = Once::new();

    /// Yields roughly one time in three to simulate being preempted at an
    /// arbitrary point.
    fn may_interrupt() {
        let roll: u32 = rand::random();
        if roll % 3 == 0 {
            thread::yield_now();
        }
    }

    /// Many tasks concurrently add to a shared counter under the mutex; the
    /// final total proves that no update was lost.
    #[test]
    fn lots_and_lots() {
        INIT.call_once(thread::init_scheduler);

        const NUM_TASKS: u32 = 10;
        const NUM_ITERS: u32 = 10_000;
        // Every round spawns one task adding 1 and one adding 2, each running
        // NUM_ITERS iterations, so each round contributes NUM_ITERS * 3.
        const EXPECTED: u32 = NUM_ITERS * NUM_TASKS * 3;
        static M: Mutex<u32> = Mutex::new(0);

        fn inc(delta: u32) {
            for _ in 0..NUM_ITERS {
                let mut guard = M.lock();
                *guard += delta;
                may_interrupt();
                drop(guard);
                may_interrupt();
            }
        }

        for _ in 0..NUM_TASKS {
            for delta in [1, 2] {
                thread::spawn(move || inc(delta));
            }
        }

        println!("spawn OK");
        // Poll the counter until all spawned tasks have finished their work.
        loop {
            let guard = M.lock();
            if *guard == EXPECTED {
                break;
            }
            may_interrupt();
            drop(guard);
            may_interrupt();
        }

        assert_eq!(*M.lock(), EXPECTED);
        println!("Mutex test OK");
    }
}