1use core::sync::atomic::{AtomicU64, Ordering};
4
5use axtask::{WaitQueue, current};
6
7pub struct RawMutex {
13 wq: WaitQueue,
14 owner_id: AtomicU64,
15}
16
17impl RawMutex {
18 #[inline(always)]
20 pub const fn new() -> Self {
21 Self {
22 wq: WaitQueue::new(),
23 owner_id: AtomicU64::new(0),
24 }
25 }
26}
27
28unsafe impl lock_api::RawMutex for RawMutex {
29 const INIT: Self = RawMutex::new();
30
31 type GuardMarker = lock_api::GuardSend;
32
33 #[inline(always)]
34 fn lock(&self) {
35 let current_id = current().id().as_u64();
36 loop {
37 match self.owner_id.compare_exchange_weak(
40 0,
41 current_id,
42 Ordering::Acquire,
43 Ordering::Relaxed,
44 ) {
45 Ok(_) => break,
46 Err(owner_id) => {
47 assert_ne!(
48 owner_id,
49 current_id,
50 "{} tried to acquire mutex it already owns.",
51 current().id_name()
52 );
53 self.wq.wait_until(|| !self.is_locked());
55 }
56 }
57 }
58 }
59
60 #[inline(always)]
61 fn try_lock(&self) -> bool {
62 let current_id = current().id().as_u64();
63 self.owner_id
66 .compare_exchange(0, current_id, Ordering::Acquire, Ordering::Relaxed)
67 .is_ok()
68 }
69
70 #[inline(always)]
71 unsafe fn unlock(&self) {
72 let owner_id = self.owner_id.swap(0, Ordering::Release);
73 assert_eq!(
74 owner_id,
75 current().id().as_u64(),
76 "{} tried to release mutex it doesn't own",
77 current().id_name()
78 );
79 self.wq.notify_one(true);
80 }
81
82 #[inline(always)]
83 fn is_locked(&self) -> bool {
84 self.owner_id.load(Ordering::Relaxed) != 0
85 }
86}
87
88pub type Mutex<T> = lock_api::Mutex<RawMutex, T>;
90pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutex, T>;
92
#[cfg(test)]
mod tests {
    use crate::Mutex;
    use axtask as thread;
    use std::sync::Once;

    /// Ensures the scheduler is initialized at most once per test process.
    static INIT: Once = Once::new();

    /// Yields the current task whenever a random `u32` is divisible by 3
    /// (roughly one call in three), to shake up task interleavings.
    fn may_interrupt() {
        let roll = rand::random::<u32>();
        if roll % 3 == 0 {
            thread::yield_now();
        }
    }

    /// Spawns many tasks that increment a shared counter under the mutex and
    /// checks that no increment is lost.
    #[test]
    fn lots_and_lots() {
        INIT.call_once(thread::init_scheduler);

        const NUM_TASKS: u32 = 10;
        const NUM_ITERS: u32 = 10_000;
        // Each round spawns one task adding 1 and one adding 2 per iteration.
        const EXPECTED: u32 = NUM_ITERS * NUM_TASKS * 3;
        static M: Mutex<u32> = Mutex::new(0);

        /// Adds `delta` to the shared counter `NUM_ITERS` times.
        fn inc(delta: u32) {
            for _ in 0..NUM_ITERS {
                {
                    let mut guard = M.lock();
                    *guard += delta;
                    // Possibly yield while still holding the lock.
                    may_interrupt();
                }
                // Possibly yield again after the lock has been released.
                may_interrupt();
            }
        }

        for _ in 0..NUM_TASKS {
            for delta in [1, 2] {
                thread::spawn(move || inc(delta));
            }
        }

        println!("spawn OK");
        // Poll the counter until every spawned task has finished its work.
        loop {
            {
                let guard = M.lock();
                if *guard == EXPECTED {
                    break;
                }
                may_interrupt();
            }
            may_interrupt();
        }

        assert_eq!(*M.lock(), EXPECTED);
        println!("Mutex test OK");
    }
}