// memory_set/set.rs

1use alloc::collections::BTreeMap;
2#[allow(unused_imports)] // this is a weird false alarm
3use alloc::vec::Vec;
4use core::fmt;
5
6use memory_addr::{AddrRange, MemoryAddr};
7
8use crate::{MappingBackend, MappingError, MappingResult, MemoryArea};
9
/// A container that maintains memory mappings ([`MemoryArea`]).
///
/// Areas are stored in a [`BTreeMap`] keyed by their start address, so they
/// are always ordered; the mapping operations keep the stored areas
/// non-overlapping (see `map`, which rejects or unmaps overlaps first).
pub struct MemorySet<B: MappingBackend> {
    // All areas, keyed by start address; sorted and pairwise non-overlapping.
    areas: BTreeMap<B::Addr, MemoryArea<B>>,
}
14
15impl<B: MappingBackend> MemorySet<B> {
16    /// Creates a new memory set.
17    pub const fn new() -> Self {
18        Self {
19            areas: BTreeMap::new(),
20        }
21    }
22
23    /// Returns the number of memory areas in the memory set.
24    pub fn len(&self) -> usize {
25        self.areas.len()
26    }
27
28    /// Returns `true` if the memory set contains no memory areas.
29    pub fn is_empty(&self) -> bool {
30        self.areas.is_empty()
31    }
32
33    /// Returns the iterator over all memory areas.
34    pub fn iter(&self) -> impl Iterator<Item = &MemoryArea<B>> {
35        self.areas.values()
36    }
37
38    /// Returns whether the given address range overlaps with any existing area.
39    pub fn overlaps(&self, range: AddrRange<B::Addr>) -> bool {
40        if let Some((_, before)) = self.areas.range(..range.start).last() {
41            if before.va_range().overlaps(range) {
42                return true;
43            }
44        }
45        if let Some((_, after)) = self.areas.range(range.start..).next() {
46            if after.va_range().overlaps(range) {
47                return true;
48            }
49        }
50        false
51    }
52
53    /// Finds the memory area that contains the given address.
54    pub fn find(&self, addr: B::Addr) -> Option<&MemoryArea<B>> {
55        let candidate = self.areas.range(..=addr).last().map(|(_, a)| a);
56        candidate.filter(|a| a.va_range().contains(addr))
57    }
58
59    /// Finds a free area that can accommodate the given size.
60    ///
61    /// The search starts from the given `hint` address, and the area should be
62    /// within the given `limit` range.
63    ///
64    /// Returns the start address of the free area. Returns `None` if no such
65    /// area is found.
66    pub fn find_free_area(
67        &self,
68        hint: B::Addr,
69        size: usize,
70        limit: AddrRange<B::Addr>,
71    ) -> Option<B::Addr> {
72        // brute force: try each area's end address as the start.
73        let mut last_end = hint.max(limit.start);
74        if let Some((_, area)) = self.areas.range(..last_end).last() {
75            last_end = last_end.max(area.end());
76        }
77        for (&addr, area) in self.areas.range(last_end..) {
78            if last_end.checked_add(size).is_some_and(|end| end <= addr) {
79                return Some(last_end);
80            }
81            last_end = area.end();
82        }
83        if last_end
84            .checked_add(size)
85            .is_some_and(|end| end <= limit.end)
86        {
87            Some(last_end)
88        } else {
89            None
90        }
91    }
92
93    /// Add a new memory mapping.
94    ///
95    /// The mapping is represented by a [`MemoryArea`].
96    ///
97    /// If the new area overlaps with any existing area, the behavior is
98    /// determined by the `unmap_overlap` parameter. If it is `true`, the
99    /// overlapped regions will be unmapped first. Otherwise, it returns an
100    /// error.
101    pub fn map(
102        &mut self,
103        area: MemoryArea<B>,
104        page_table: &mut B::PageTable,
105        unmap_overlap: bool,
106    ) -> MappingResult {
107        if area.va_range().is_empty() {
108            return Err(MappingError::InvalidParam);
109        }
110
111        if self.overlaps(area.va_range()) {
112            if unmap_overlap {
113                self.unmap(area.start(), area.size(), page_table)?;
114            } else {
115                return Err(MappingError::AlreadyExists);
116            }
117        }
118
119        area.map_area(page_table)?;
120        assert!(self.areas.insert(area.start(), area).is_none());
121        Ok(())
122    }
123
124    /// Remove memory mappings within the given address range.
125    ///
126    /// All memory areas that are fully contained in the range will be removed
127    /// directly. If the area intersects with the boundary, it will be shrinked.
128    /// If the unmapped range is in the middle of an existing area, it will be
129    /// split into two areas.
130    pub fn unmap(
131        &mut self,
132        start: B::Addr,
133        size: usize,
134        page_table: &mut B::PageTable,
135    ) -> MappingResult {
136        let range =
137            AddrRange::try_from_start_size(start, size).ok_or(MappingError::InvalidParam)?;
138        if range.is_empty() {
139            return Ok(());
140        }
141
142        let end = range.end;
143
144        // Unmap entire areas that are contained by the range.
145        self.areas.retain(|_, area| {
146            if area.va_range().contained_in(range) {
147                area.unmap_area(page_table).unwrap();
148                false
149            } else {
150                true
151            }
152        });
153
154        // Shrink right if the area intersects with the left boundary.
155        if let Some((&before_start, before)) = self.areas.range_mut(..start).last() {
156            let before_end = before.end();
157            if before_end > start {
158                if before_end <= end {
159                    // the unmapped area is at the end of `before`.
160                    before.shrink_right(start.sub_addr(before_start), page_table)?;
161                } else {
162                    // the unmapped area is in the middle `before`, need to split.
163                    let right_part = before.split(end).unwrap();
164                    before.shrink_right(start.sub_addr(before_start), page_table)?;
165                    assert_eq!(right_part.start().into(), Into::<usize>::into(end));
166                    self.areas.insert(end, right_part);
167                }
168            }
169        }
170
171        // Shrink left if the area intersects with the right boundary.
172        if let Some((&after_start, after)) = self.areas.range_mut(start..).next() {
173            let after_end = after.end();
174            if after_start < end {
175                // the unmapped area is at the start of `after`.
176                let mut new_area = self.areas.remove(&after_start).unwrap();
177                new_area.shrink_left(after_end.sub_addr(end), page_table)?;
178                assert_eq!(new_area.start().into(), Into::<usize>::into(end));
179                self.areas.insert(end, new_area);
180            }
181        }
182
183        Ok(())
184    }
185
186    /// Remove all memory areas and the underlying mappings.
187    pub fn clear(&mut self, page_table: &mut B::PageTable) -> MappingResult {
188        for (_, area) in self.areas.iter() {
189            area.unmap_area(page_table)?;
190        }
191        self.areas.clear();
192        Ok(())
193    }
194
195    /// Change the flags of memory mappings within the given address range.
196    ///
197    /// `update_flags` is a function that receives old flags and processes
198    /// new flags (e.g., some flags can not be changed through this interface).
199    /// It returns [`None`] if there is no bit to change.
200    ///
201    /// Memory areas will be skipped according to `update_flags`. Memory areas
202    /// that are fully contained in the range or contains the range or
203    /// intersects with the boundary will be handled similarly to `munmap`.
204    pub fn protect(
205        &mut self,
206        start: B::Addr,
207        size: usize,
208        update_flags: impl Fn(B::Flags) -> Option<B::Flags>,
209        page_table: &mut B::PageTable,
210    ) -> MappingResult {
211        let end = start.checked_add(size).ok_or(MappingError::InvalidParam)?;
212        let mut to_insert = Vec::new();
213        for (&area_start, area) in self.areas.iter_mut() {
214            let area_end = area.end();
215
216            if let Some(new_flags) = update_flags(area.flags()) {
217                if area_start >= end {
218                    // [ prot ]
219                    //          [ area ]
220                    break;
221                } else if area_end <= start {
222                    //          [ prot ]
223                    // [ area ]
224                    // Do nothing
225                } else if area_start >= start && area_end <= end {
226                    // [   prot   ]
227                    //   [ area ]
228                    area.protect_area(new_flags, page_table)?;
229                    area.set_flags(new_flags);
230                } else if area_start < start && area_end > end {
231                    //        [ prot ]
232                    // [ left | area | right ]
233                    let right_part = area.split(end).unwrap();
234                    area.set_end(start);
235
236                    let mut middle_part =
237                        MemoryArea::new(start, size, area.flags(), area.backend().clone());
238                    middle_part.protect_area(new_flags, page_table)?;
239                    middle_part.set_flags(new_flags);
240
241                    to_insert.push((right_part.start(), right_part));
242                    to_insert.push((middle_part.start(), middle_part));
243                } else if area_end > end {
244                    // [    prot ]
245                    //   [  area | right ]
246                    let right_part = area.split(end).unwrap();
247                    area.protect_area(new_flags, page_table)?;
248                    area.set_flags(new_flags);
249
250                    to_insert.push((right_part.start(), right_part));
251                } else {
252                    //        [ prot    ]
253                    // [ left |  area ]
254                    let mut right_part = area.split(start).unwrap();
255                    right_part.protect_area(new_flags, page_table)?;
256                    right_part.set_flags(new_flags);
257
258                    to_insert.push((right_part.start(), right_part));
259                }
260            }
261        }
262        self.areas.extend(to_insert);
263        Ok(())
264    }
265}
266
267impl<B: MappingBackend> fmt::Debug for MemorySet<B>
268where
269    B::Addr: fmt::Debug,
270    B::Flags: fmt::Debug,
271{
272    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
273        f.debug_list().entries(self.areas.values()).finish()
274    }
275}