diff --git a/drivers/staging/android/ashmem.h b/drivers/staging/android/ashmem.h index 06f7f91ac22b..373437f6a279 100644 --- a/drivers/staging/android/ashmem.h +++ b/drivers/staging/android/ashmem.h @@ -34,5 +34,6 @@ enum { bool is_ashmem_file(struct file *file); int ashmem_area_name(struct file *file, char *name); long ashmem_area_size(struct file *file); +struct file *ashmem_area_vmfile(struct file *file); #endif /* _LINUX_ASHMEM_H */ diff --git a/drivers/staging/android/ashmem_rust.rs b/drivers/staging/android/ashmem_rust.rs index 30c886c63246..de64200442b4 100644 --- a/drivers/staging/android/ashmem_rust.rs +++ b/drivers/staging/android/ashmem_rust.rs @@ -724,3 +724,31 @@ unsafe extern "C" fn ashmem_area_size(file: *mut bindings::file) -> isize { Err(_err) => 0, } } + +/// # Safety +/// +/// The caller must ensure that `file` is valid for the duration of this function. +/// +/// If this function returns a non-NULL pointer to a file structure, the refcount for that +/// file will be incremented by 1. It is the caller's responsibility to decrement the refcount +/// when the file is no longer needed. +#[no_mangle] +unsafe extern "C" fn ashmem_area_vmfile(file: *mut bindings::file) -> *mut bindings::file { + // SAFETY: file is valid for the duration of this function. + let ashmem = match unsafe { get_ashmem_area(file) } { + Ok(area) => area, + Err(_err) => return null_mut(), + }; + + let asma = &mut *ashmem.inner.lock(); + match asma.file.as_ref() { + Some(shmem_file) => { + let shmem_file_ptr = shmem_file.file().as_ptr(); + // SAFETY: file is valid for the duration of the function, which means shmem file is + // also valid at this point. + unsafe { bindings::get_file(shmem_file_ptr) }; + shmem_file_ptr + } + None => null_mut(), + } +} diff --git a/drivers/staging/android/ashmem_rust_exports.c b/drivers/staging/android/ashmem_rust_exports.c index 3957a3b208c0..373110484c16 100644 --- a/drivers/staging/android/ashmem_rust_exports.c +++ b/drivers/staging/android/ashmem_rust_exports.c @@ -19,3 +19,4 @@ EXPORT_SYMBOL_GPL(is_ashmem_file); EXPORT_SYMBOL_GPL(ashmem_area_name); EXPORT_SYMBOL_GPL(ashmem_area_size); +EXPORT_SYMBOL_GPL(ashmem_area_vmfile);