1use core::sync::atomic::{AtomicU64, Ordering};
4
5use axtask::{WaitQueue, current};
6
7pub struct RawMutex {
13 wq: WaitQueue,
14 owner_id: AtomicU64,
15}
16
17impl RawMutex {
18 #[inline(always)]
20 pub const fn new() -> Self {
21 Self {
22 wq: WaitQueue::new(),
23 owner_id: AtomicU64::new(0),
24 }
25 }
26}
27
28unsafe impl lock_api::RawMutex for RawMutex {
29 #[allow(clippy::declare_interior_mutable_const)]
34 const INIT: Self = RawMutex::new();
35
36 type GuardMarker = lock_api::GuardSend;
37
38 #[inline(always)]
39 fn lock(&self) {
40 let current_id = current().id().as_u64();
41 loop {
42 match self.owner_id.compare_exchange_weak(
45 0,
46 current_id,
47 Ordering::Acquire,
48 Ordering::Relaxed,
49 ) {
50 Ok(_) => break,
51 Err(owner_id) => {
52 assert_ne!(
53 owner_id,
54 current_id,
55 "{} tried to acquire mutex it already owns.",
56 current().id_name()
57 );
58 self.wq.wait_until(|| !self.is_locked());
60 }
61 }
62 }
63 }
64
65 #[inline(always)]
66 fn try_lock(&self) -> bool {
67 let current_id = current().id().as_u64();
68 self.owner_id
71 .compare_exchange(0, current_id, Ordering::Acquire, Ordering::Relaxed)
72 .is_ok()
73 }
74
75 #[inline(always)]
76 unsafe fn unlock(&self) {
77 let owner_id = self.owner_id.swap(0, Ordering::Release);
78 assert_eq!(
79 owner_id,
80 current().id().as_u64(),
81 "{} tried to release mutex it doesn't own",
82 current().id_name()
83 );
84 self.wq.notify_one(true);
85 }
86
87 #[inline(always)]
88 fn is_locked(&self) -> bool {
89 self.owner_id.load(Ordering::Relaxed) != 0
90 }
91}
92
93pub type Mutex<T> = lock_api::Mutex<RawMutex, T>;
95pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutex, T>;
97
#[cfg(test)]
mod tests {
    use crate::Mutex;
    use axtask as thread;
    use std::sync::Once;

    /// Guards the one-time scheduler initialisation.
    static INIT: Once = Once::new();

    /// Yields to the scheduler whenever a random `u32` is divisible by three.
    fn may_interrupt() {
        let roll = rand::random::<u32>();
        if roll % 3 == 0 {
            thread::yield_now();
        }
    }

    /// Stress test: many tasks increment one shared counter under the mutex,
    /// and the main task polls until the counter reaches the expected total.
    #[test]
    fn lots_and_lots() {
        INIT.call_once(thread::init_scheduler);

        const NUM_TASKS: u32 = 10;
        const NUM_ITERS: u32 = 10_000;
        // Each round spawns one task adding 1 and one adding 2 per iteration.
        const EXPECTED_TOTAL: u32 = NUM_ITERS * NUM_TASKS * 3;
        static M: Mutex<u32> = Mutex::new(0);

        /// Adds `delta` to the shared counter `NUM_ITERS` times, possibly
        /// yielding both while holding the lock and after releasing it.
        fn inc(delta: u32) {
            for _ in 0..NUM_ITERS {
                let mut counter = M.lock();
                *counter += delta;
                may_interrupt();
                drop(counter);
                may_interrupt();
            }
        }

        (0..NUM_TASKS).for_each(|_| {
            thread::spawn(|| inc(1));
            thread::spawn(|| inc(2));
        });

        println!("spawn OK");
        loop {
            let counter = M.lock();
            if *counter == EXPECTED_TOTAL {
                // The guard is released when it goes out of scope on `break`.
                break;
            }
            may_interrupt();
            drop(counter);
            may_interrupt();
        }

        assert_eq!(*M.lock(), EXPECTED_TOTAL);
        println!("Mutex test OK");
    }
}