use abyss::spinlock::{SpinLock, SpinLockGuard};

use core::{
    cell::UnsafeCell,
    ops::{Deref, DerefMut},
    sync::atomic::{AtomicUsize, Ordering},
};

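/// A readers-writer lock.
///
/// Any number of readers may hold the lock at once, or a single writer
/// with exclusive access. To help diagnose deadlocks, the lock records
/// the thread id and source location of the current writer and panics
/// when that same thread tries to take a read lock while still holding
/// the write lock.
///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` since this crate targets a
/// freestanding kernel environment):
///
/// ```ignore
/// let lock = RwLock::new(5);
///
/// // Any number of read guards may coexist.
/// {
///     let r1 = lock.read();
///     let r2 = lock.read();
///     assert_eq!(*r1 + *r2, 10);
/// } // both read guards are dropped here
///
/// // A write guard is exclusive.
/// {
///     let mut w = lock.write();
///     *w += 1;
///     assert_eq!(*w, 6);
/// }
/// ```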
pub struct RwLock<T>
where
    T: ?Sized + Send,
{
    state: AtomicUsize,
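    /// Thread id and source location of the current writer, if any, used
    /// to detect a thread read-locking while it already holds the write
    /// lock.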
    owner: SpinLock<Option<(u64, &'static core::panic::Location<'static>)>>,
    data: UnsafeCell<T>,
}

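// Bit `usize::BITS - 2` of `state` flags an active writer; the bits below
// it hold the reader count.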
const STATE_MASK: usize = 0b1 << (usize::BITS - 2);
const STATE_WRITER_LOCKED: usize = 0b1 << (usize::BITS - 2);

#[inline]
fn is_write_locked(b: usize) -> bool {
    b & STATE_MASK == STATE_WRITER_LOCKED
}

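/// RAII guard returned by [`RwLock::write`]; grants exclusive access to
/// the data and releases the write lock when dropped.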
pub struct RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
    T: 'a,
{
    lock: &'a RwLock<T>,
    data: &'a mut T,
}

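/// RAII guard returned by [`RwLock::read`]; grants shared access to the
/// data and releases the read lock when dropped.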
pub struct RwLockReadGuard<'a, T>
where
    T: ?Sized + Send,
    T: 'a,
{
    lock: &'a RwLock<T>,
    data: &'a T,
}

impl<'a, T> RwLockReadGuard<'a, T>
where
    T: ?Sized + Send,
    T: 'a,
{
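    /// Upgrades this read guard into a write guard, spinning until this
    /// is the only remaining reader.
    ///
    /// Note that two read guards upgrading concurrently will deadlock,
    /// as each spins waiting for the other to be dropped.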
    #[track_caller]
    pub fn upgrade(self) -> RwLockWriteGuard<'a, T> {
        // Consume the guard without running its destructor; the reader
        // count we hold is exchanged for the writer flag below.
        let this = core::mem::ManuallyDrop::new(self);
        let lock = unsafe { core::ptr::read(&this.lock) };
        loop {
            let mut guard = lock.owner.lock();
            // Succeeds only when we are the sole remaining reader.
            if lock
                .state
                .compare_exchange(1, STATE_WRITER_LOCKED, Ordering::Acquire, Ordering::Acquire)
                .is_ok()
            {
                *guard = Some((
                    crate::thread::Current::get_tid(),
                    core::panic::Location::caller(),
                ));
                guard.unlock();

                break RwLockWriteGuard {
                    lock,
                    data: unsafe { &mut *lock.data.get() },
                };
            }
            guard.unlock();
            core::hint::spin_loop();
        }
    }
}

impl<'a, T> RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
    T: 'a,
{
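    /// Downgrades this write guard into a read guard.
    ///
    /// The writer flag is exchanged for a single reader count in one
    /// atomic step, so no other writer can slip in between.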
    pub fn downgrade(self) -> RwLockReadGuard<'a, T> {
        // Consume the guard without running its destructor; the writer
        // flag we hold is exchanged for one reader count below.
        let this = core::mem::ManuallyDrop::new(self);
        let lock = unsafe { core::ptr::read(&this.lock) };
        // Release ordering publishes the writes made under this guard to
        // readers that acquire the lock afterwards.
        assert!(
            lock.state
                .compare_exchange(STATE_WRITER_LOCKED, 1, Ordering::Release, Ordering::Relaxed)
                .is_ok()
        );
        RwLockReadGuard {
            lock,
            data: unsafe { &*lock.data.get() },
        }
    }
}

impl<T> RwLock<T>
where
    T: Send,
{
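    /// Creates a new `RwLock` wrapping `data`, initially unlocked.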
    pub const fn new(data: T) -> RwLock<T> {
        RwLock {
            state: AtomicUsize::new(0),
            owner: SpinLock::new(None),
            data: UnsafeCell::new(data),
        }
    }

    /// Panics if the calling thread already holds the write lock and is
    /// therefore about to deadlock against itself.
    #[inline]
    fn validate_state(
        &self,
        owner: SpinLockGuard<Option<(u64, &'static core::panic::Location<'static>)>>,
    ) {
        {
            let owner = owner.expect("RwLock is in an unexpected state.");
            if owner.0 == crate::thread::Current::get_tid() {
                panic!(
                    "Tried to acquire a ReadGuard on the thread holding the WriteGuard acquired at {:?}.",
                    owner.1
                );
            }
        }
        owner.unlock();
    }

    #[inline]
    fn read_lock(&self) {
        loop {
            let guard = self.owner.lock();
            let prev = self.state.load(Ordering::Relaxed);
            if is_write_locked(prev) {
                // A writer holds the lock; check it is not us, then spin.
                self.validate_state(guard);
                core::hint::spin_loop();
            } else if self
                .state
                .compare_exchange(prev, prev + 1, Ordering::Acquire, Ordering::Acquire)
                .is_ok()
            {
                guard.unlock();
                break;
            } else {
                guard.unlock();
            }
        }
    }

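    /// Locks this `RwLock` with shared read access, spinning until it can
    /// be acquired, and returns an RAII guard that releases the access
    /// when dropped.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread already holds the write lock, since
    /// waiting for ourselves would deadlock.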
    #[inline]
    #[track_caller]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        if let Ok(guard) = self.try_read() {
            guard
        } else {
            self.read_lock();
            RwLockReadGuard {
                lock: self,
                data: unsafe { &*self.data.get() },
            }
        }
    }

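    /// Attempts to acquire this `RwLock` with shared read access, returning
    /// `Err(WouldBlock)` immediately if a writer holds the lock.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread already holds the write lock.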
    #[inline]
    #[track_caller]
    pub fn try_read(&self) -> Result<RwLockReadGuard<'_, T>, crate::spinlock::WouldBlock> {
        loop {
            let guard = self.owner.lock();
            let prev = self.state.load(Ordering::Relaxed);
            if is_write_locked(prev) {
                self.validate_state(guard);
                break Err(crate::spinlock::WouldBlock);
            } else if self
                .state
                .compare_exchange(prev, prev + 1, Ordering::Acquire, Ordering::Acquire)
                .is_ok()
            {
                guard.unlock();
                break Ok(RwLockReadGuard {
                    lock: self,
                    data: unsafe { &*self.data.get() },
                });
            }
            guard.unlock();
        }
    }

    #[inline]
    fn write_lock(&self) {
        loop {
            let prev = self.state.load(Ordering::Relaxed);
            if prev > 0 {
                // Readers or a writer still hold the lock; spin until the
                // state drains to zero.
                core::hint::spin_loop();
            } else if self
                .state
                .compare_exchange(0, STATE_WRITER_LOCKED, Ordering::Acquire, Ordering::Acquire)
                .is_ok()
            {
                break;
            }
        }
    }

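    /// Locks this `RwLock` with exclusive write access, spinning until it
    /// can be acquired, and records the caller as the owner for deadlock
    /// diagnostics. Returns an RAII guard that releases the write access
    /// when dropped.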
    #[inline]
    #[track_caller]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        if let Ok(guard) = self.try_write() {
            guard
        } else {
            // Holding `owner` keeps new readers out while we wait for the
            // existing ones to drain in `write_lock`.
            let mut guard = self.owner.lock();
            self.write_lock();
            *guard = Some((
                crate::thread::Current::get_tid(),
                core::panic::Location::caller(),
            ));
            guard.unlock();
            RwLockWriteGuard {
                lock: self,
                data: unsafe { &mut *self.data.get() },
            }
        }
    }

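    /// Attempts to acquire this `RwLock` with exclusive write access,
    /// returning `Err(WouldBlock)` immediately if any reader or writer
    /// holds the lock. On success, records the caller as the owner.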
    #[track_caller]
    pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, T>, crate::spinlock::WouldBlock> {
        loop {
            let mut guard = self.owner.lock();
            let prev = self.state.load(Ordering::Relaxed);
            if prev > 0 {
                guard.unlock();
                break Err(crate::spinlock::WouldBlock);
            } else if self
                .state
                .compare_exchange(
                    prev,
                    prev | STATE_WRITER_LOCKED,
                    Ordering::Acquire,
                    Ordering::Acquire,
                )
                .is_ok()
            {
                *guard = Some((
                    crate::thread::Current::get_tid(),
                    core::panic::Location::caller(),
                ));
                guard.unlock();

                break Ok(RwLockWriteGuard {
                    lock: self,
                    data: unsafe { &mut *self.data.get() },
                });
            }
            guard.unlock();
        }
    }
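
    /// Returns a mutable reference to the underlying data without taking
    /// the lock.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that no other reference to the data is
    /// alive, as this bypasses the lock entirely.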
    #[inline]
    #[allow(clippy::mut_from_ref)]
    pub unsafe fn steal(&self) -> &mut T {
        unsafe { &mut *self.data.get() }
    }

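    /// Consumes this `RwLock`, returning the underlying data.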
    #[inline]
    pub fn into_inner(self) -> T {
        self.data.into_inner()
    }
}

// SAFETY: the lock protocol guarantees exclusive access for writers and
// shared access for readers, so the lock may be shared across threads
// whenever `T: Send`.
unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send {}
unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}

impl<T: Send> core::fmt::Debug for RwLock<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.debug_struct("RwLock")
            .field("state", &self.state.load(Ordering::SeqCst))
            .finish()
    }
}

impl<'a, T> Deref for RwLockReadGuard<'a, T>
where
    T: ?Sized + Send,
{
    type Target = T;
    fn deref(&self) -> &T {
        self.data
    }
}

impl<'a, T> Deref for RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
{
    type Target = T;
    fn deref(&self) -> &T {
        &*self.data
    }
}

impl<'a, T> DerefMut for RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
{
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.data
    }
}

impl<'a, T> Drop for RwLockReadGuard<'a, T>
where
    T: ?Sized + Send,
{
    #[track_caller]
    fn drop(&mut self) {
        // No writer can be active while a read guard exists.
        debug_assert_eq!(self.lock.state.load(Ordering::Acquire) & STATE_MASK, 0);
        self.lock.state.fetch_sub(1, Ordering::Release);
    }
}

impl<'a, T> Drop for RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
{
    #[track_caller]
    fn drop(&mut self) {
        debug_assert_eq!(
            self.lock.state.load(Ordering::Acquire) & STATE_MASK,
            STATE_WRITER_LOCKED
        );
        // Release ordering publishes the writes made under this guard.
        self.lock
            .state
            .fetch_and(!STATE_WRITER_LOCKED, Ordering::Release);
    }
}