keos/sync/rwlock.rs

//! RwLock implementations.

use abyss::spinlock::{SpinLock, SpinLockGuard};
use core::{
    cell::UnsafeCell,
    ops::{Deref, DerefMut},
    sync::atomic::{AtomicUsize, Ordering},
};

/// A reader-writer lock
///
/// This type of lock allows a number of readers or at most one writer at any
/// point in time. The write portion of this lock typically allows modification
/// of the underlying data (exclusive access) and the read portion of this lock
/// typically allows for read-only access (shared access).
///
/// In comparison, a [`Mutex`] does not distinguish between readers or writers
/// that acquire the lock, therefore blocking any threads waiting for the lock
/// to become available. An `RwLock` will allow any number of readers to acquire
/// the lock as long as a writer is not holding the lock.
///
/// This lock is implemented by spinning, and it does not guarantee that any
/// particular priority policy between readers and writers will be used.
///
/// The type parameter `T` represents the data that this lock protects. It is
/// required that `T` satisfies [`Send`] for the lock to be shared across
/// threads. The RAII guards returned from the locking methods implement
/// [`Deref`] (and [`DerefMut`] for the `write` methods) to allow access to
/// the content of the lock.
///
/// [`Mutex`]: struct.Mutex.html
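///
/// # Examples
///
/// A minimal usage sketch (the `keos::sync::RwLock` import path is assumed
/// from this file's location; not run as a doctest):
///
/// ```ignore
/// use keos::sync::RwLock;
///
/// let lock = RwLock::new(5);
///
/// // Any number of read guards may be held at the same time.
/// {
///     let r1 = lock.read();
///     let r2 = lock.read();
///     assert_eq!(*r1, 5);
///     assert_eq!(*r2, 5);
/// } // both read guards are dropped here
///
/// // Only one write guard may exist, and only while no readers are active.
/// {
///     let mut w = lock.write();
///     *w += 1;
///     assert_eq!(*w, 6);
/// } // the write guard is dropped here
/// ```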
pub struct RwLock<T>
where
    T: ?Sized + Send,
{
    // state:
    // The bit at position `usize::BITS - 2` (`STATE_WRITER_LOCKED`) is set
    // while a writer holds the lock; the remaining low bits count the number
    // of active readers.
    state: AtomicUsize,
    owner: SpinLock<Option<(u64, &'static core::panic::Location<'static>)>>,
    data: UnsafeCell<T>,
}

const STATE_MASK: usize = 0b1 << (usize::BITS - 2);
const STATE_WRITER_LOCKED: usize = 0b1 << (usize::BITS - 2);

#[inline]
fn is_write_locked(b: usize) -> bool {
    b & STATE_MASK == STATE_WRITER_LOCKED
}

/// RAII structure used to release the exclusive write access of a lock when
/// dropped.
///
/// This structure is created by the [`write`] and [`try_write`] methods
/// on [`RwLock`].
///
/// [`write`]: struct.RwLock.html#method.write
/// [`try_write`]: struct.RwLock.html#method.try_write
/// [`RwLock`]: struct.RwLock.html
pub struct RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
    T: 'a,
{
    lock: &'a RwLock<T>,
    data: &'a mut T,
}

/// RAII structure used to release the shared read access of a lock when
/// dropped.
///
/// This structure is created by the [`read`] and [`try_read`] methods on
/// [`RwLock`].
///
/// [`read`]: struct.RwLock.html#method.read
/// [`try_read`]: struct.RwLock.html#method.try_read
/// [`RwLock`]: struct.RwLock.html
pub struct RwLockReadGuard<'a, T>
where
    T: ?Sized + Send,
    T: 'a,
{
    lock: &'a RwLock<T>,
    data: &'a T,
}

impl<'a, T> RwLockReadGuard<'a, T>
where
    T: ?Sized + Send,
    T: 'a,
{
    /// Upgrades the `RwLockReadGuard` into an `RwLockWriteGuard`.
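    ///
    /// Spins until this guard is the only remaining reader, then atomically
    /// trades the shared access for exclusive access.
    ///
    /// A minimal usage sketch (not run as a doctest):
    ///
    /// ```ignore
    /// let lock = RwLock::new(0);
    /// let reader = lock.read();
    /// // Blocks while any other reader is still active.
    /// let mut writer = reader.upgrade();
    /// *writer += 1;
    /// ```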
    #[track_caller]
    pub fn upgrade(self) -> RwLockWriteGuard<'a, T> {
        // Do not run the read guard's destructor; the reader count it holds
        // is consumed by the compare-exchange below.
        let this = core::mem::ManuallyDrop::new(self);
        let lock = unsafe { core::ptr::read(&this.lock) };
        loop {
            let mut guard = lock.owner.lock();
            // A state of exactly 1 means this guard is the sole reader:
            // swap the single reader count for the writer bit.
            if lock
                .state
                .compare_exchange(1, STATE_WRITER_LOCKED, Ordering::Acquire, Ordering::Acquire)
                .is_ok()
            {
                *guard = Some((
                    crate::thread::Current::get_tid(),
                    core::panic::Location::caller(),
                ));
                guard.unlock();

                break RwLockWriteGuard {
                    lock,
                    data: unsafe { &mut *lock.data.get() },
                };
            }
            guard.unlock();
        }
    }
}

impl<'a, T> RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
    T: 'a,
{
    /// Downgrades the `RwLockWriteGuard` into an `RwLockReadGuard`.
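    ///
    /// The exclusive access is atomically exchanged for shared access, so no
    /// other writer can slip in between the release and the re-acquisition.
    ///
    /// A minimal usage sketch (not run as a doctest):
    ///
    /// ```ignore
    /// let lock = RwLock::new(0);
    /// let mut writer = lock.write();
    /// *writer = 42;
    /// let reader = writer.downgrade();
    /// assert_eq!(*reader, 42);
    /// ```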
    pub fn downgrade(self) -> RwLockReadGuard<'a, T> {
        // Do not run the write guard's destructor; the writer bit it holds
        // is consumed by the compare-exchange below.
        let this = core::mem::ManuallyDrop::new(self);
        let lock = unsafe { core::ptr::read(&this.lock) };
        // Release ordering publishes the writes made under the write guard
        // to readers that subsequently acquire the lock.
        assert!(
            lock.state
                .compare_exchange(STATE_WRITER_LOCKED, 1, Ordering::Release, Ordering::Relaxed)
                .is_ok()
        );
        RwLockReadGuard {
            lock,
            data: unsafe { &*lock.data.get() },
        }
    }
}

impl<T> RwLock<T>
where
    T: Send,
{
    /// Creates a new instance of an `RwLock<T>` which is unlocked.
    pub const fn new(data: T) -> RwLock<T> {
        RwLock {
            state: AtomicUsize::new(0),
            owner: SpinLock::new(None),
            data: UnsafeCell::new(data),
        }
    }

    #[inline]
    fn validate_state(
        &self,
        owner: SpinLockGuard<Option<(u64, &'static core::panic::Location<'static>)>>,
    ) {
        {
            let owner = owner.expect("RwLock is in unexpected state.");
            // Acquiring a read guard while already holding the write guard on
            // the same thread would deadlock; fail loudly instead.
            if owner.0 == crate::thread::Current::get_tid() {
                panic!(
                    "Tried to acquire a ReadGuard on the thread holding the WriteGuard acquired at {:?}.",
                    owner.1
                );
            }
        }
        owner.unlock();
    }

    #[inline]
    fn read_lock(&self) {
        loop {
            let guard = self.owner.lock();
            let prev = self.state.load(Ordering::Relaxed);
            if is_write_locked(prev) {
                // A writer holds the lock: check for self-deadlock, then spin.
                self.validate_state(guard);
                core::hint::spin_loop();
            } else if self
                .state
                .compare_exchange(prev, prev + 1, Ordering::Acquire, Ordering::Acquire)
                .is_ok()
            {
                // Successfully registered as one more reader.
                guard.unlock();
                break;
            } else {
                guard.unlock();
            }
        }
    }

    /// Locks this rwlock with shared read access, blocking the current thread
    /// until it can be acquired.
    ///
    /// The calling thread will be blocked until there are no more writers which
    /// hold the lock. There may be other readers currently inside the lock when
    /// this method returns. This method does not provide any guarantees with
    /// respect to the ordering of whether contentious readers or writers will
    /// acquire the lock first.
    ///
    /// Returns an RAII guard which will release this thread's shared access
    /// once it is dropped.
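    ///
    /// A minimal usage sketch (not run as a doctest):
    ///
    /// ```ignore
    /// let lock = RwLock::new(1);
    /// let n = lock.read();
    /// assert_eq!(*n, 1);
    /// ```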
    #[inline]
    #[track_caller]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        if let Ok(guard) = self.try_read() {
            guard
        } else {
            self.read_lock();
            RwLockReadGuard {
                lock: self,
                data: unsafe { &*self.data.get() },
            }
        }
    }

    /// Attempts to acquire this rwlock with shared read access.
    ///
    /// If the access could not be granted at this time, then `Err` is returned.
    /// Otherwise, an RAII guard is returned which will release the shared
    /// access when it is dropped.
    ///
    /// This function does not block.
    ///
    /// This function does not provide any guarantees with respect to the
    /// ordering of whether contentious readers or writers will acquire the
    /// lock first.
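    ///
    /// A minimal usage sketch (not run as a doctest):
    ///
    /// ```ignore
    /// let lock = RwLock::new(1);
    /// match lock.try_read() {
    ///     Ok(n) => assert_eq!(*n, 1),
    ///     Err(_) => unreachable!("no writer holds the lock here"),
    /// }
    /// ```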
    #[inline]
    #[track_caller]
    pub fn try_read(&self) -> Result<RwLockReadGuard<'_, T>, crate::spinlock::WouldBlock> {
        loop {
            let guard = self.owner.lock();
            let prev = self.state.load(Ordering::Relaxed);
            if is_write_locked(prev) {
                // A writer holds the lock: check for self-deadlock, then bail out.
                self.validate_state(guard);
                break Err(crate::spinlock::WouldBlock);
            } else if self
                .state
                .compare_exchange(prev, prev + 1, Ordering::Acquire, Ordering::Acquire)
                .is_ok()
            {
                guard.unlock();
                break Ok(RwLockReadGuard {
                    lock: self,
                    data: unsafe { &*self.data.get() },
                });
            }
            // Raced with another state transition; retry.
            guard.unlock();
        }
    }

    #[inline]
    fn write_lock(&self) {
        loop {
            let prev = self.state.load(Ordering::Relaxed);
            if prev > 0 {
                // Readers (or another writer) still hold the lock.
                core::hint::spin_loop();
            } else if self
                .state
                .compare_exchange(0, STATE_WRITER_LOCKED, Ordering::Acquire, Ordering::Acquire)
                .is_ok()
            {
                break;
            }
        }
    }

    /// Locks this rwlock with exclusive write access, blocking the current
    /// thread until it can be acquired.
    ///
    /// This function will not return while other writers or other readers
    /// currently have access to the lock.
    ///
    /// Returns an RAII guard which will drop the write access of this rwlock
    /// when dropped.
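    ///
    /// A minimal usage sketch (not run as a doctest):
    ///
    /// ```ignore
    /// let lock = RwLock::new(1);
    /// let mut n = lock.write();
    /// *n = 2;
    /// assert_eq!(*n, 2);
    /// ```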
    #[inline]
    #[track_caller]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        if let Ok(guard) = self.try_write() {
            guard
        } else {
            // Take the owner spinlock before spinning: readers must take the
            // same spinlock to enter, so existing readers drain while new
            // readers are held off, and the holder metadata is recorded
            // before the spinlock is released.
            let mut guard = self.owner.lock();
            self.write_lock();
            *guard = Some((
                crate::thread::Current::get_tid(),
                core::panic::Location::caller(),
            ));
            guard.unlock();
            RwLockWriteGuard {
                lock: self,
                data: unsafe { &mut *self.data.get() },
            }
        }
    }

    /// Attempts to lock this rwlock with exclusive write access.
    ///
    /// If the lock could not be acquired at this time, then `Err` is returned.
    /// Otherwise, an RAII guard is returned which will release the lock when
    /// it is dropped.
    ///
    /// This function does not block.
    ///
    /// This function does not provide any guarantees with respect to the
    /// ordering of whether contentious readers or writers will acquire the
    /// lock first.
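    ///
    /// A minimal usage sketch (not run as a doctest):
    ///
    /// ```ignore
    /// let lock = RwLock::new(1);
    /// let reader = lock.read();
    /// // A live read guard makes exclusive access unavailable.
    /// assert!(lock.try_write().is_err());
    /// drop(reader);
    /// assert!(lock.try_write().is_ok());
    /// ```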
    #[track_caller]
    pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, T>, crate::spinlock::WouldBlock> {
        loop {
            let mut guard = self.owner.lock();
            let prev = self.state.load(Ordering::Relaxed);
            if prev > 0 {
                // Readers or another writer already hold the lock.
                guard.unlock();
                break Err(crate::spinlock::WouldBlock);
            } else if self
                .state
                .compare_exchange(
                    prev,
                    prev | STATE_WRITER_LOCKED,
                    Ordering::Acquire,
                    Ordering::Acquire,
                )
                .is_ok()
            {
                *guard = Some((
                    crate::thread::Current::get_tid(),
                    core::panic::Location::caller(),
                ));
                guard.unlock();

                break Ok(RwLockWriteGuard {
                    lock: self,
                    data: unsafe { &mut *self.data.get() },
                });
            }
            // Raced with another state transition; retry.
            guard.unlock();
        }
    }

    /// Returns a mutable reference to the underlying data, bypassing the lock
    /// entirely, even if the value is currently locked. Racy.
    ///
    /// # Safety
    /// The caller must guarantee that no other thread concurrently accesses
    /// the data; using this while any guard is live is a data race.
    #[inline]
    #[allow(clippy::mut_from_ref)]
    pub unsafe fn steal(&self) -> &mut T {
        unsafe { &mut *self.data.get() }
    }

    /// Consumes this RwLock, returning the underlying data.
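    ///
    /// A minimal usage sketch (not run as a doctest):
    ///
    /// ```ignore
    /// let lock = RwLock::new(10);
    /// assert_eq!(lock.into_inner(), 10);
    /// ```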
    #[inline]
    pub fn into_inner(self) -> T {
        self.data.into_inner()
    }
}

unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send {}
unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}

impl<T: Send> core::fmt::Debug for RwLock<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.debug_struct("RwLock")
            .field("state", &self.state.load(Ordering::SeqCst))
            .finish()
    }
}

impl<'a, T> Deref for RwLockReadGuard<'a, T>
where
    T: ?Sized + Send,
{
    type Target = T;
    fn deref(&self) -> &T {
        self.data
    }
}

impl<'a, T> Deref for RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
{
    type Target = T;
    fn deref(&self) -> &T {
        &*self.data
    }
}

impl<'a, T> DerefMut for RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
{
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.data
    }
}

impl<'a, T> Drop for RwLockReadGuard<'a, T>
where
    T: ?Sized + Send,
{
    #[track_caller]
    fn drop(&mut self) {
        // The writer bit must not be set while a read guard is alive.
        debug_assert_eq!(self.lock.state.load(Ordering::Acquire) & STATE_MASK, 0);
        // Unregister this reader.
        self.lock.state.fetch_sub(1, Ordering::Release);
    }
}

impl<'a, T> Drop for RwLockWriteGuard<'a, T>
where
    T: ?Sized + Send,
{
    #[track_caller]
    fn drop(&mut self) {
        // The writer bit must be set while a write guard is alive.
        debug_assert_eq!(
            self.lock.state.load(Ordering::Acquire) & STATE_MASK,
            STATE_WRITER_LOCKED
        );
        // Clear the writer bit, publishing the writes made under the guard.
        self.lock
            .state
            .fetch_and(!STATE_WRITER_LOCKED, Ordering::Release);
    }
}