use crate::{
buffer::LimitedVec,
gas::ChargeError,
pages::{GearPage, WasmPage, WasmPagesAmount},
};
use alloc::format;
use byteorder::{ByteOrder, LittleEndian};
use core::{
fmt,
fmt::Debug,
ops::{Deref, DerefMut},
};
use numerated::{
interval::{Interval, TryFromRangeError},
tree::IntervalsTree,
};
use scale_info::{
scale::{self, Decode, Encode, EncodeLike, Input, Output},
TypeInfo,
};
/// A memory interval described by its start `offset` and its `size`.
///
/// Its compact byte form (see [`MemoryInterval::to_bytes`]) is 8 bytes:
/// `offset` followed by `size`, each as a little-endian `u32`.
#[derive(Clone, Copy, Eq, PartialEq, Encode, Decode)]
pub struct MemoryInterval {
    /// Start offset of the interval.
    pub offset: u32,
    /// Length of the interval.
    pub size: u32,
}

impl MemoryInterval {
    /// Serializes the interval into 8 bytes: little-endian `offset`, then little-endian `size`.
    #[inline]
    pub fn to_bytes(&self) -> [u8; 8] {
        let mut out = [0u8; 8];
        let (offset_part, size_part) = out.split_at_mut(4);
        LittleEndian::write_u32(offset_part, self.offset);
        LittleEndian::write_u32(size_part, self.size);
        out
    }

    /// Parses an interval from the 8-byte layout produced by [`MemoryInterval::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns `"bytes size != 8"` when `bytes` is not exactly 8 bytes long.
    #[inline]
    pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, &'static str> {
        if bytes.len() == 8 {
            let (offset_part, size_part) = bytes.split_at(4);
            Ok(MemoryInterval {
                offset: LittleEndian::read_u32(offset_part),
                size: LittleEndian::read_u32(size_part),
            })
        } else {
            Err("bytes size != 8")
        }
    }
}
impl From<(u32, u32)> for MemoryInterval {
fn from(val: (u32, u32)) -> Self {
MemoryInterval {
offset: val.0,
size: val.1,
}
}
}
impl From<MemoryInterval> for (u32, u32) {
fn from(val: MemoryInterval) -> Self {
(val.offset, val.size)
}
}
/// Debug output shows both fields in hexadecimal,
/// e.g. `MemoryInterval { offset: 0x10, size: 0x4 }`.
impl Debug for MemoryInterval {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MemoryInterval");
        builder.field("offset", &format_args!("{:#x}", self.offset));
        builder.field("size", &format_args!("{:#x}", self.size));
        builder.finish()
    }
}
/// Error used when a page buffer of the wrong size is requested; its message
/// reports the required size (`GearPage::SIZE`) in hex.
#[derive(Debug, Default, PartialEq, Eq, Clone, TypeInfo, derive_more::Display)]
#[display(
    fmt = "Trying to make wrong size page buffer, must be {:#x}",
    GearPage::SIZE
)]
pub struct IntoPageBufError;

/// Byte storage backing a [`PageBuf`]: a `LimitedVec` of bytes bounded by
/// `GearPage::SIZE`, with [`IntoPageBufError`] as its error type (presumably
/// reported when that bound is violated — NOTE(review): confirm in `LimitedVec`).
pub type PageBufInner = LimitedVec<u8, IntoPageBufError, { GearPage::SIZE as usize }>;

/// Buffer holding the contents of a single gear page (`GearPage::SIZE` bytes).
#[derive(Clone, PartialEq, Eq, TypeInfo)]
pub struct PageBuf(PageBufInner);
/// SCALE-encodes a page as its raw bytes, with no length prefix.
impl Encode for PageBuf {
    fn size_hint(&self) -> usize {
        // A page always encodes to exactly one page worth of bytes.
        GearPage::SIZE as usize
    }

    fn encode_to<W: Output + ?Sized>(&self, dest: &mut W) {
        let raw_bytes = self.0.inner();
        dest.write(raw_bytes);
    }
}
/// SCALE-decodes a page by filling a default-initialised buffer with raw bytes
/// from `input`, i.e. the un-prefixed format written by the `Encode` impl.
impl Decode for PageBuf {
    #[inline]
    fn decode<I: Input>(input: &mut I) -> Result<Self, scale::Error> {
        let mut page = PageBufInner::new_default();
        // `read` fails if `input` cannot supply enough bytes to fill the whole buffer.
        input.read(page.inner_mut()).map(|()| Self(page))
    }
}

/// A `PageBuf` is encoded only as itself.
impl EncodeLike for PageBuf {}
/// Prints only the first and the last 10 bytes of the page, to keep output short.
impl Debug for PageBuf {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const EDGE: usize = 10;
        let bytes = self.0.inner();
        let page_end = GearPage::SIZE as usize;
        let head = &bytes[0..EDGE];
        let tail = &bytes[page_end - EDGE..page_end];
        write!(f, "PageBuf({:?}..{:?})", head, tail)
    }
}
impl Deref for PageBuf {
type Target = [u8];
fn deref(&self) -> &Self::Target {
self.0.inner()
}
}
impl DerefMut for PageBuf {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.inner_mut()
}
}
impl PageBuf {
pub fn new_zeroed() -> PageBuf {
Self(PageBufInner::new_default())
}
pub fn from_inner(mut inner: PageBufInner) -> Self {
inner.extend_with(0);
Self(inner)
}
}
/// Address in the host's address space, stored as a `u64`.
pub type HostPointer = u64;
// Compile-time check: a `HostPointer` must be wide enough to hold any host `usize`.
const _: () = assert!(size_of::<HostPointer>() >= size_of::<usize>());
/// Error returned by [`Memory::read`] and [`Memory::write`].
#[derive(Debug, Clone, Eq, PartialEq, derive_more::Display)]
pub enum MemoryError {
    /// The accessed range lies (at least partly) outside the program's wasm memory.
    #[display(fmt = "Trying to access memory outside wasm program memory")]
    AccessOutOfBounds,
}
/// Access to a wasm program's memory.
///
/// `Context` is a caller-supplied value passed through to every operation;
/// what it represents is up to the implementor.
pub trait Memory<Context> {
    /// Error returned by [`Memory::grow`].
    type GrowError: Debug;

    /// Grows memory by `pages` wasm pages.
    fn grow(&self, ctx: &mut Context, pages: WasmPagesAmount) -> Result<(), Self::GrowError>;

    /// Returns the current memory size in wasm pages.
    fn size(&self, ctx: &Context) -> WasmPagesAmount;

    /// Writes `buffer` into memory at `offset`.
    fn write(&self, ctx: &mut Context, offset: u32, buffer: &[u8]) -> Result<(), MemoryError>;

    /// Reads memory at `offset` into `buffer`.
    fn read(&self, ctx: &Context, offset: u32, buffer: &mut [u8]) -> Result<(), MemoryError>;

    /// Returns the host address of the memory buffer, or `None` when memory
    /// has zero pages.
    fn get_buffer_host_addr(&self, ctx: &Context) -> Option<HostPointer> {
        let has_pages = self.size(ctx) != WasmPagesAmount::from(0);
        // SAFETY: the unsafe accessor is only reached when memory has at least one page.
        has_pages.then(|| unsafe { self.get_buffer_host_addr_unsafe(ctx) })
    }

    /// Returns the host address of the memory buffer without any checks.
    ///
    /// # Safety
    ///
    /// The exact contract is implementation-defined. The safe wrapper
    /// [`Memory::get_buffer_host_addr`] calls it only when memory size is
    /// non-zero.
    unsafe fn get_buffer_host_addr_unsafe(&self, ctx: &Context) -> HostPointer;
}
/// Tracks which wasm pages of a program's heap are allocated.
///
/// The heap is the page range `static_pages..max_pages`; pages below
/// `static_pages` are static memory and are never allocated or freed here.
#[derive(Debug)]
pub struct AllocationsContext {
    // Allocations as they were when the context was created (returned by `into_parts`).
    init_allocations: IntervalsTree<WasmPage>,
    // Current allocations, updated by `alloc`, `free` and `free_range`.
    allocations: IntervalsTree<WasmPage>,
    // Exclusive upper bound of the heap.
    max_pages: WasmPagesAmount,
    // Amount of static pages; the heap starts right after them.
    static_pages: WasmPagesAmount,
}
/// Hooks run around a memory grow performed by [`AllocationsContext::alloc`].
///
/// `before_grow_action` is called right before `Memory::grow`; the value it
/// returns is consumed by `after_grow_action` right after the grow.
#[must_use]
pub trait GrowHandler<Context> {
    /// Called before memory is grown; returns state to pass to `after_grow_action`.
    fn before_grow_action(ctx: &mut Context, mem: &mut impl Memory<Context>) -> Self;
    /// Called after memory has been grown.
    fn after_grow_action(self, ctx: &mut Context, mem: &mut impl Memory<Context>);
}

/// [`GrowHandler`] that does nothing before or after growing.
pub struct NoopGrowHandler;

impl<Context> GrowHandler<Context> for NoopGrowHandler {
    fn before_grow_action(_ctx: &mut Context, _mem: &mut impl Memory<Context>) -> Self {
        NoopGrowHandler
    }
    fn after_grow_action(self, _ctx: &mut Context, _mem: &mut impl Memory<Context>) {}
}
/// Inconsistent memory parameters passed to [`AllocationsContext::try_new`].
#[derive(Debug, Clone, PartialEq, Eq, derive_more::Display)]
pub enum MemorySetupError {
    /// `memory_size` is greater than `max_pages`.
    #[display(fmt = "Memory size {memory_size:?} must be less than or equal to {max_pages:?}")]
    MemorySizeExceedsMaxPages {
        /// Requested memory size.
        memory_size: WasmPagesAmount,
        /// Maximum allowed memory size.
        max_pages: WasmPagesAmount,
    },
    /// `memory_size` is smaller than `static_pages`.
    #[display(fmt = "Memory size {memory_size:?} must be at least {static_pages:?}")]
    InsufficientMemorySize {
        /// Requested memory size.
        memory_size: WasmPagesAmount,
        /// Amount of static pages.
        static_pages: WasmPagesAmount,
    },
    /// `stack_end` is greater than `static_pages`.
    #[display(fmt = "Stack end {stack_end:?} is out of static memory 0..{static_pages:?}")]
    StackEndOutOfStaticMemory {
        /// Provided stack end.
        stack_end: WasmPage,
        /// Amount of static pages.
        static_pages: WasmPagesAmount,
    },
    /// An allocated page is outside `static_pages..memory_size`.
    #[display(
        fmt = "Allocated page {page:?} is out of allowed memory interval {static_pages:?}..{memory_size:?}"
    )]
    AllocatedPageOutOfAllowedInterval {
        /// The offending page (first or last allocated page).
        page: WasmPage,
        /// Amount of static pages.
        static_pages: WasmPagesAmount,
        /// Requested memory size.
        memory_size: WasmPagesAmount,
    },
}
/// Error returned by [`AllocationsContext`] allocation operations.
#[derive(Debug, Clone, Eq, PartialEq, derive_more::Display, derive_more::From)]
pub enum AllocError {
    /// `alloc` found no room: the heap is empty or no free gap fits the request.
    #[display(fmt = "Trying to allocate more wasm program memory than allowed")]
    ProgramAllocOutOfBounds,
    /// `free` got a page outside the heap or one that is not allocated.
    #[display(fmt = "{_0:?} cannot be freed by the current program")]
    InvalidFree(WasmPage),
    /// `free_range` got an inclusive range (`start`, `end`) not inside the heap.
    #[display(fmt = "Invalid range {_0:?}..={_1:?} for free_range")]
    InvalidFreeRange(WasmPage, WasmPage),
    /// Charging gas for growing memory failed.
    #[from]
    #[display(fmt = "{_0}")]
    GasCharge(ChargeError),
}
impl AllocationsContext {
    /// Creates a new allocations context after validating its parameters.
    ///
    /// * `memory_size` - memory size at creation; every allocated page must be below it.
    /// * `allocations` - already allocated pages; also kept unchanged as the initial
    ///   allocations returned by [`AllocationsContext::into_parts`].
    /// * `static_pages` - amount of static pages; the heap starts right after them.
    /// * `stack_end` - optional stack end, must not exceed `static_pages`.
    /// * `max_pages` - exclusive upper bound of the heap.
    ///
    /// # Errors
    ///
    /// Returns a [`MemorySetupError`] when the parameters are inconsistent.
    pub fn try_new(
        memory_size: WasmPagesAmount,
        allocations: IntervalsTree<WasmPage>,
        static_pages: WasmPagesAmount,
        stack_end: Option<WasmPage>,
        max_pages: WasmPagesAmount,
    ) -> Result<Self, MemorySetupError> {
        Self::validate_memory_params(
            memory_size,
            &allocations,
            static_pages,
            stack_end,
            max_pages,
        )?;

        Ok(Self {
            init_allocations: allocations.clone(),
            allocations,
            max_pages,
            static_pages,
        })
    }

    /// Checks that the memory parameters are consistent, in this order:
    /// 1. `memory_size <= max_pages`;
    /// 2. `static_pages <= memory_size`;
    /// 3. `stack_end` (if any) `<= static_pages`;
    /// 4. the last allocated page is below `memory_size`;
    /// 5. the first allocated page is not below `static_pages`.
    ///
    /// Checks 1 and 2 together imply `static_pages <= max_pages`, which `alloc` relies on.
    fn validate_memory_params(
        memory_size: WasmPagesAmount,
        allocations: &IntervalsTree<WasmPage>,
        static_pages: WasmPagesAmount,
        stack_end: Option<WasmPage>,
        max_pages: WasmPagesAmount,
    ) -> Result<(), MemorySetupError> {
        if memory_size > max_pages {
            return Err(MemorySetupError::MemorySizeExceedsMaxPages {
                memory_size,
                max_pages,
            });
        }

        if static_pages > memory_size {
            return Err(MemorySetupError::InsufficientMemorySize {
                memory_size,
                static_pages,
            });
        }

        if let Some(stack_end) = stack_end {
            if stack_end > static_pages {
                return Err(MemorySetupError::StackEndOutOfStaticMemory {
                    stack_end,
                    static_pages,
                });
            }
        }

        if let Some(page) = allocations.end() {
            if page >= memory_size {
                return Err(MemorySetupError::AllocatedPageOutOfAllowedInterval {
                    page,
                    static_pages,
                    memory_size,
                });
            }
        }

        if let Some(page) = allocations.start() {
            if page < static_pages {
                return Err(MemorySetupError::AllocatedPageOutOfAllowedInterval {
                    page,
                    static_pages,
                    memory_size,
                });
            }
        }

        Ok(())
    }

    /// Allocates `pages` contiguous wasm pages inside the heap `static_pages..max_pages`
    /// and returns the first allocated page.
    ///
    /// The interval is placed at the start of the first gap yielded by
    /// `IntervalsTree::voids` that is large enough. If the interval ends past the
    /// current memory size, memory is grown to cover it: `charge_gas_for_grow` is
    /// called with the number of pages to grow, then `G::before_grow_action`,
    /// `Memory::grow` and `G::after_grow_action` run in that order.
    ///
    /// Requesting zero pages allocates nothing and returns the first heap page
    /// (legacy behaviour), unless the heap is empty.
    ///
    /// # Errors
    ///
    /// * [`AllocError::ProgramAllocOutOfBounds`] - the heap is empty or no gap fits `pages`.
    /// * [`AllocError::GasCharge`] - `charge_gas_for_grow` failed; allocations are unchanged.
    ///
    /// # Panics
    ///
    /// If `static_pages > max_pages` (prevented by [`AllocationsContext::try_new`]),
    /// or if `Memory::grow` returns an error.
    pub fn alloc<Context, G: GrowHandler<Context>>(
        &mut self,
        ctx: &mut Context,
        mem: &mut impl Memory<Context>,
        pages: WasmPagesAmount,
        charge_gas_for_grow: impl FnOnce(WasmPagesAmount) -> Result<(), ChargeError>,
    ) -> Result<WasmPage, AllocError> {
        let heap = match Interval::try_from(self.static_pages..self.max_pages) {
            Ok(interval) => interval,
            Err(TryFromRangeError::IncorrectRange) => {
                // `try_new` validates `static_pages <= memory_size <= max_pages`, and no
                // method in this impl modifies these fields afterwards.
                let err_msg = format!(
                    "AllocationsContext::alloc: Must be self.static_pages <= self.max_pages. \
                     This is guaranteed by `AllocationsContext::try_new`. \
                     Static pages - {:?}, max pages - {:?}",
                    self.static_pages, self.max_pages
                );
                log::error!("{err_msg}");
                unreachable!("{err_msg}")
            }
            Err(TryFromRangeError::EmptyRange) => {
                // All memory is static: nothing can be allocated, not even zero pages.
                return Err(AllocError::ProgramAllocOutOfBounds);
            }
        };

        // Legacy behaviour: a zero-page request succeeds without allocating anything.
        if pages == WasmPagesAmount::from(0) {
            return Ok(heap.start());
        }

        // Take the start of the first free gap that can hold `pages` contiguous pages.
        let interval = self
            .allocations
            .voids(heap)
            .find_map(|void| {
                Interval::<WasmPage>::with_len(void.start(), u32::from(pages))
                    .ok()
                    .and_then(|interval| (interval.end() <= void.end()).then_some(interval))
            })
            .ok_or(AllocError::ProgramAllocOutOfBounds)?;

        // If memory already covers the interval, the range is empty or reversed and
        // `try_from` fails, so no grow (and no gas charge) happens.
        if let Ok(grow) = Interval::<WasmPage>::try_from(mem.size(ctx)..interval.end().inc()) {
            charge_gas_for_grow(grow.len())?;
            let grow_handler = G::before_grow_action(ctx, mem);
            // NOTE(review): growing is treated as infallible because the target size never
            // exceeds `max_pages`; presumably the memory's own limit is at least `max_pages`.
            mem.grow(ctx, grow.len()).unwrap_or_else(|err| {
                let err_msg = format!(
                    "AllocationsContext::alloc: Failed to grow memory. \
                     Got error - {err:?}",
                );
                log::error!("{err_msg}");
                unreachable!("{err_msg}")
            });
            grow_handler.after_grow_action(ctx, mem);
        }

        self.allocations.insert(interval);

        Ok(interval.start())
    }

    /// Frees a single allocated `page`.
    ///
    /// # Errors
    ///
    /// Returns [`AllocError::InvalidFree`] if `page` is outside `static_pages..max_pages`
    /// or is not currently allocated.
    pub fn free(&mut self, page: WasmPage) -> Result<(), AllocError> {
        if page < self.static_pages || page >= self.max_pages || !self.allocations.contains(page) {
            Err(AllocError::InvalidFree(page))
        } else {
            self.allocations.remove(page);
            Ok(())
        }
    }

    /// Frees all pages of `interval` (both ends inclusive).
    ///
    /// Only the bounds are validated. Whether the pages inside `interval` are
    /// allocated is not checked, so ranges with unallocated pages are accepted.
    ///
    /// # Errors
    ///
    /// Returns [`AllocError::InvalidFreeRange`] if `interval` is not fully inside
    /// `static_pages..max_pages`.
    pub fn free_range(&mut self, interval: Interval<WasmPage>) -> Result<(), AllocError> {
        if interval.start() < self.static_pages || interval.end() >= self.max_pages {
            Err(AllocError::InvalidFreeRange(
                interval.start(),
                interval.end(),
            ))
        } else {
            self.allocations.remove(interval);
            Ok(())
        }
    }

    /// Consumes the context and returns `(static_pages, initial_allocations, current_allocations)`.
    pub fn into_parts(
        self,
    ) -> (
        WasmPagesAmount,
        IntervalsTree<WasmPage>,
        IntervalsTree<WasmPage>,
    ) {
        (self.static_pages, self.init_allocations, self.allocations)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    use core::{cell::Cell, iter};

    // Test `Memory` implementation that only tracks its size in wasm pages.
    struct TestMemory(Cell<WasmPagesAmount>);

    impl TestMemory {
        fn new(amount: WasmPagesAmount) -> Self {
            Self(Cell::new(amount))
        }
    }

    impl Memory<()> for TestMemory {
        type GrowError = ();

        // Adds `pages` to the tracked size; fails when `add` returns `None`.
        fn grow(&self, _ctx: &mut (), pages: WasmPagesAmount) -> Result<(), Self::GrowError> {
            let new_pages_amount = self.0.get().add(pages).ok_or(())?;
            self.0.set(new_pages_amount);
            Ok(())
        }

        fn size(&self, _ctx: &()) -> WasmPagesAmount {
            self.0.get()
        }

        // Reads, writes and host addresses are not exercised by these tests.
        fn write(&self, _ctx: &mut (), _offset: u32, _buffer: &[u8]) -> Result<(), MemoryError> {
            unimplemented!()
        }

        fn read(&self, _ctx: &(), _offset: u32, _buffer: &mut [u8]) -> Result<(), MemoryError> {
            unimplemented!()
        }

        unsafe fn get_buffer_host_addr_unsafe(&self, _ctx: &()) -> HostPointer {
            unimplemented!()
        }
    }

    // Smoke test: building a `PageBuf` and formatting it with `Debug` must not panic.
    #[test]
    fn page_buf() {
        let _ = env_logger::try_init();

        let mut data = PageBufInner::filled_with(199u8);
        data.inner_mut()[1] = 2;
        let page_buf = PageBuf::from_inner(data);
        log::debug!("page buff = {:?}", page_buf);
    }

    // `free` rejects pages that are outside the heap or not allocated;
    // `free_range` accepts an in-heap range even if some pages in it are free.
    #[test]
    fn free_fails() {
        // Empty heap: nothing can be freed.
        let mut ctx =
            AllocationsContext::try_new(0.into(), Default::default(), 0.into(), None, 0.into())
                .unwrap();
        assert_eq!(ctx.free(1.into()), Err(AllocError::InvalidFree(1.into())));

        // Page 1 is outside the heap `0..1`.
        let mut ctx = AllocationsContext::try_new(
            1.into(),
            [WasmPage::from(0)].into_iter().collect(),
            0.into(),
            None,
            1.into(),
        )
        .unwrap();
        assert_eq!(ctx.free(1.into()), Err(AllocError::InvalidFree(1.into())));

        // Range `1..4` contains the unallocated page 2 and is still accepted.
        let mut ctx = AllocationsContext::try_new(
            4.into(),
            [WasmPage::from(1), WasmPage::from(3)].into_iter().collect(),
            1.into(),
            None,
            4.into(),
        )
        .unwrap();
        let interval = Interval::<WasmPage>::try_from(1u16..4).unwrap();
        assert_eq!(ctx.free_range(interval), Ok(()));
    }

    // Allocates `pages` and asserts the returned start page equals `expected`.
    #[track_caller]
    fn alloc_ok(ctx: &mut AllocationsContext, mem: &mut TestMemory, pages: u16, expected: u16) {
        let res = ctx.alloc::<(), NoopGrowHandler>(&mut (), mem, pages.into(), |_| Ok(()));
        assert_eq!(res, Ok(expected.into()));
    }

    // Allocates `pages` and asserts the allocation fails with `err`.
    #[track_caller]
    fn alloc_err(ctx: &mut AllocationsContext, mem: &mut TestMemory, pages: u16, err: AllocError) {
        let res = ctx.alloc::<(), NoopGrowHandler>(&mut (), mem, pages.into(), |_| Ok(()));
        assert_eq!(res, Err(err));
    }

    // Deterministic allocate/free scenario over the heap `16..256`.
    #[test]
    fn alloc() {
        let _ = env_logger::try_init();

        let mut ctx = AllocationsContext::try_new(
            256.into(),
            Default::default(),
            16.into(),
            None,
            256.into(),
        )
        .unwrap();
        let mut mem = TestMemory::new(16.into());
        alloc_ok(&mut ctx, &mut mem, 16, 16);
        // Zero pages returns the heap start without allocating.
        alloc_ok(&mut ctx, &mut mem, 0, 16);

        // Fill the rest of the heap in 16-page chunks: pages 32..256.
        (2..16).for_each(|i| alloc_ok(&mut ctx, &mut mem, 16, i * 16));

        // The heap is now full.
        alloc_err(&mut ctx, &mut mem, 16, AllocError::ProgramAllocOutOfBounds);

        // A freed single page is reused.
        ctx.free(137.into()).unwrap();
        alloc_ok(&mut ctx, &mut mem, 1, 137);

        // Two adjacent freed pages form one gap that fits a 2-page allocation.
        ctx.free(117.into()).unwrap();
        ctx.free(118.into()).unwrap();
        alloc_ok(&mut ctx, &mut mem, 2, 117);

        // Same, but freed with `free_range`.
        let interval = Interval::<WasmPage>::try_from(117..119).unwrap();
        ctx.free_range(interval).unwrap();
        alloc_ok(&mut ctx, &mut mem, 2, 117);

        // Two non-adjacent free pages cannot hold a 2-page allocation.
        ctx.free(117.into()).unwrap();
        ctx.free(158.into()).unwrap();
        alloc_err(&mut ctx, &mut mem, 2, AllocError::ProgramAllocOutOfBounds);
    }

    // One case per `MemorySetupError` variant, plus a valid configuration.
    #[test]
    fn memory_params_validation() {
        assert_eq!(
            AllocationsContext::validate_memory_params(
                4.into(),
                &iter::once(WasmPage::from(2)).collect(),
                2.into(),
                Some(2.into()),
                4.into(),
            ),
            Ok(())
        );

        assert_eq!(
            AllocationsContext::validate_memory_params(
                4.into(),
                &Default::default(),
                2.into(),
                Some(2.into()),
                3.into(),
            ),
            Err(MemorySetupError::MemorySizeExceedsMaxPages {
                memory_size: 4.into(),
                max_pages: 3.into()
            })
        );

        assert_eq!(
            AllocationsContext::validate_memory_params(
                1.into(),
                &Default::default(),
                2.into(),
                Some(1.into()),
                4.into(),
            ),
            Err(MemorySetupError::InsufficientMemorySize {
                memory_size: 1.into(),
                static_pages: 2.into()
            })
        );

        assert_eq!(
            AllocationsContext::validate_memory_params(
                4.into(),
                &Default::default(),
                2.into(),
                Some(3.into()),
                4.into(),
            ),
            Err(MemorySetupError::StackEndOutOfStaticMemory {
                stack_end: 3.into(),
                static_pages: 2.into()
            })
        );

        // First allocated page is below `static_pages`.
        assert_eq!(
            AllocationsContext::validate_memory_params(
                4.into(),
                &[WasmPage::from(1), WasmPage::from(3)].into_iter().collect(),
                2.into(),
                Some(2.into()),
                4.into(),
            ),
            Err(MemorySetupError::AllocatedPageOutOfAllowedInterval {
                page: 1.into(),
                static_pages: 2.into(),
                memory_size: 4.into()
            })
        );

        // Last allocated page is not below `memory_size`.
        assert_eq!(
            AllocationsContext::validate_memory_params(
                4.into(),
                &[WasmPage::from(2), WasmPage::from(4)].into_iter().collect(),
                2.into(),
                Some(2.into()),
                4.into(),
            ),
            Err(MemorySetupError::AllocatedPageOutOfAllowedInterval {
                page: 4.into(),
                static_pages: 2.into(),
                memory_size: 4.into()
            })
        );

        assert_eq!(
            AllocationsContext::validate_memory_params(
                13.into(),
                &iter::once(WasmPage::from(1)).collect(),
                10.into(),
                None,
                13.into()
            ),
            Err(MemorySetupError::AllocatedPageOutOfAllowedInterval {
                page: 1.into(),
                static_pages: 10.into(),
                memory_size: 13.into()
            })
        );

        // Boundary values at `WasmPagesAmount::UPPER`.
        assert_eq!(
            AllocationsContext::validate_memory_params(
                13.into(),
                &iter::once(WasmPage::from(1)).collect(),
                WasmPagesAmount::UPPER,
                None,
                13.into()
            ),
            Err(MemorySetupError::InsufficientMemorySize {
                memory_size: 13.into(),
                static_pages: WasmPagesAmount::UPPER
            })
        );

        assert_eq!(
            AllocationsContext::validate_memory_params(
                WasmPagesAmount::UPPER,
                &iter::once(WasmPage::from(1)).collect(),
                10.into(),
                None,
                WasmPagesAmount::UPPER,
            ),
            Err(MemorySetupError::AllocatedPageOutOfAllowedInterval {
                page: 1.into(),
                static_pages: 10.into(),
                memory_size: WasmPagesAmount::UPPER
            })
        );
    }

    mod property_tests {
        use super::*;
        use proptest::{
            arbitrary::any,
            collection::size_range,
            prop_oneof, proptest,
            strategy::{Just, Strategy},
            test_runner::Config as ProptestConfig,
        };

        // A single operation applied to an `AllocationsContext`.
        #[derive(Debug, Clone)]
        enum Action {
            Alloc { pages: WasmPagesAmount },
            Free { page: WasmPage },
            FreeRange { page: WasmPage, size: u8 },
        }

        // Up to 1023 random actions; allocations request at most 32 pages.
        fn actions() -> impl Strategy<Value = Vec<Action>> {
            let action = prop_oneof![
                wasm_pages_amount_with_range(0, 32).prop_map(|pages| Action::Alloc { pages }),
                wasm_page().prop_map(|page| Action::Free { page }),
                (wasm_page(), any::<u8>())
                    .prop_map(|(page, size)| Action::FreeRange { page, size }),
            ];
            proptest::collection::vec(action, 0..1024)
        }

        // Random allocation set with pages drawn from `start..=end` (inclusive).
        fn allocations(start: u16, end: u16) -> impl Strategy<Value = IntervalsTree<WasmPage>> {
            proptest::collection::btree_set(wasm_page_with_range(start, end), size_range(0..1024))
                .prop_map(|pages| pages.into_iter().collect::<IntervalsTree<WasmPage>>())
        }

        // Random page in `start..=end` (inclusive).
        fn wasm_page_with_range(start: u16, end: u16) -> impl Strategy<Value = WasmPage> {
            (start..=end).prop_map(WasmPage::from)
        }

        fn wasm_page() -> impl Strategy<Value = WasmPage> {
            wasm_page_with_range(0, u16::MAX)
        }

        // Random amount in `start..=end`; `u16::MAX + 1` maps to `WasmPagesAmount::UPPER`.
        fn wasm_pages_amount_with_range(
            start: u32,
            end: u32,
        ) -> impl Strategy<Value = WasmPagesAmount> {
            (start..=end).prop_map(|x| {
                if x == u16::MAX as u32 + 1 {
                    WasmPagesAmount::UPPER
                } else {
                    WasmPagesAmount::from(x as u16)
                }
            })
        }

        fn wasm_pages_amount() -> impl Strategy<Value = WasmPagesAmount> {
            wasm_pages_amount_with_range(0, u16::MAX as u32 + 1)
        }

        // Valid `AllocationsContext::try_new` inputs.
        #[derive(Debug)]
        struct MemoryParams {
            max_pages: WasmPagesAmount,
            mem_size: WasmPagesAmount,
            static_pages: WasmPagesAmount,
            allocations: IntervalsTree<WasmPage>,
        }

        // Generates `static_pages < mem_size <= max_pages` with allocations inside
        // `static_pages..mem_size`, i.e. parameters that pass validation.
        fn combined_memory_params() -> impl Strategy<Value = MemoryParams> {
            wasm_pages_amount()
                .prop_flat_map(|max_pages| {
                    let mem_size = wasm_pages_amount_with_range(0, u32::from(max_pages));
                    (Just(max_pages), mem_size)
                })
                .prop_flat_map(|(max_pages, mem_size)| {
                    let static_pages = wasm_pages_amount_with_range(0, u32::from(mem_size));
                    (Just(max_pages), Just(mem_size), static_pages)
                })
                .prop_filter(
                    "filter out cases where allocation region has zero size",
                    |(_max_pages, mem_size, static_pages)| static_pages < mem_size,
                )
                .prop_flat_map(|(max_pages, mem_size, static_pages)| {
                    // NOTE: despite its name, this is the last allowed page (inclusive),
                    // because `allocations` draws pages from `start..=end`.
                    let end_exclusive = u32::from(mem_size) - 1;
                    (
                        Just(max_pages),
                        Just(mem_size),
                        Just(static_pages),
                        allocations(u32::from(static_pages) as u16, end_exclusive as u16),
                    )
                })
                .prop_map(
                    |(max_pages, mem_size, static_pages, allocations)| MemoryParams {
                        max_pages,
                        mem_size,
                        static_pages,
                        allocations,
                    },
                )
        }

        fn proptest_config() -> ProptestConfig {
            ProptestConfig {
                cases: 1024,
                ..Default::default()
            }
        }

        // Freeing may only fail with the free-specific error variants.
        #[track_caller]
        fn assert_free_error(err: AllocError) {
            match err {
                AllocError::InvalidFree(_) => {}
                AllocError::InvalidFreeRange(_, _) => {}
                err => panic!("{err:?}"),
            }
        }

        proptest! {
            #![proptest_config(proptest_config())]
            // Applies random actions and checks allocator invariants after each one.
            #[test]
            fn alloc(
                mem_params in combined_memory_params(),
                actions in actions(),
            ) {
                let _ = env_logger::try_init();

                let MemoryParams{max_pages, mem_size, static_pages, allocations} = mem_params;
                let mut ctx = AllocationsContext::try_new(mem_size, allocations, static_pages, None, max_pages).unwrap();
                let mut mem = TestMemory::new(mem_size);

                for action in actions {
                    match action {
                        Action::Alloc { pages } => {
                            match ctx.alloc::<_, NoopGrowHandler>(&mut (), &mut mem, pages, |_| Ok(())) {
                                // Failure is only allowed when growing by `pages` would exceed `max_pages`.
                                Err(AllocError::ProgramAllocOutOfBounds) => {
                                    let x = mem.size(&()).add(pages);
                                    assert!(x.is_none() || x.unwrap() > max_pages);
                                }
                                // Success: the interval is inside the heap and covered by memory.
                                Ok(page) => {
                                    assert!(pages == WasmPagesAmount::from(0) || (page >= static_pages && page < max_pages));
                                    assert!(mem.size(&()) <= max_pages);
                                    assert!(WasmPagesAmount::from(page).add(pages).unwrap() <= mem.size(&()));
                                }
                                Err(err) => panic!("{err:?}"),
                            }
                        }
                        Action::Free { page } => {
                            if let Err(err) = ctx.free(page) {
                                assert_free_error(err);
                            }
                        }
                        Action::FreeRange { page, size } => {
                            if let Ok(interval) = Interval::<WasmPage>::with_len(page, size as u32) {
                                let _ = ctx.free_range(interval).map_err(assert_free_error);
                            }
                        }
                    }
                }
            }
        }
    }
}