diff --git a/drivers/virt/coco/pkvm-guest/arm-pkvm-guest.c b/drivers/virt/coco/pkvm-guest/arm-pkvm-guest.c index 3f97b00ff000..fcd6823fa3a3 100644 --- a/drivers/virt/coco/pkvm-guest/arm-pkvm-guest.c +++ b/drivers/virt/coco/pkvm-guest/arm-pkvm-guest.c @@ -18,12 +18,11 @@ #include static size_t pkvm_granule; +static bool pkvm_func_range; -static int arm_smccc_do_one_page(u32 func_id, phys_addr_t phys) +static int __arm_smccc_do(u32 func_id, phys_addr_t phys, int numgranules) { - phys_addr_t end = phys + PAGE_SIZE; - - while (phys < end) { + while (numgranules--) { struct arm_smccc_res res; arm_smccc_1_1_invoke(func_id, phys, 0, 0, &res); @@ -36,33 +35,53 @@ static int arm_smccc_do_one_page(u32 func_id, phys_addr_t phys) return 0; } -static int __set_memory_range(u32 func_id, unsigned long start, int numpages) +static int __arm_smccc_do_range(u32 func_id, phys_addr_t phys, int numgranules) { - void *addr = (void *)start, *end = addr + numpages * PAGE_SIZE; + while (numgranules) { + struct arm_smccc_res res; - while (addr < end) { - int err; + arm_smccc_1_1_invoke(func_id, phys, numgranules, 0, &res); + if (res.a0 != SMCCC_RET_SUCCESS) + return -EPERM; - err = arm_smccc_do_one_page(func_id, virt_to_phys(addr)); - if (err) - return err; - - addr += PAGE_SIZE; + phys += pkvm_granule * res.a1; + numgranules -= res.a1; } return 0; } +/* + * Apply func_id on the range [phys : phys + numpages * PAGE_SIZE) + * + * Use with cautious. Boundaries of the range will be aligned with pkvm_granule. + * This might lead to overshooting when pkvm_granule > PAGE_SIZE. + */ +static int arm_smccc_do_range(u32 func_id, phys_addr_t phys, int numpages, + bool func_has_range) +{ + size_t size = numpages * PAGE_SIZE; + int numgranules; + + numgranules = DIV_ROUND_UP(size, pkvm_granule); + phys = ALIGN_DOWN(phys, pkvm_granule); + + if (func_has_range) + return __arm_smccc_do_range(func_id, phys, numgranules); + + return __arm_smccc_do(func_id, phys, numgranules); +} + static int pkvm_set_memory_encrypted(unsigned long addr, int numpages) { - return __set_memory_range(ARM_SMCCC_VENDOR_HYP_KVM_MEM_UNSHARE_FUNC_ID, - addr, numpages); + return arm_smccc_do_range(ARM_SMCCC_VENDOR_HYP_KVM_MEM_UNSHARE_FUNC_ID, + virt_to_phys((void *)addr), numpages, pkvm_func_range); } static int pkvm_set_memory_decrypted(unsigned long addr, int numpages) { - return __set_memory_range(ARM_SMCCC_VENDOR_HYP_KVM_MEM_SHARE_FUNC_ID, - addr, numpages); + return arm_smccc_do_range(ARM_SMCCC_VENDOR_HYP_KVM_MEM_SHARE_FUNC_ID, + virt_to_phys((void *)addr), numpages, pkvm_func_range); } static const struct arm64_mem_crypt_ops pkvm_crypt_ops = { @@ -73,8 +92,10 @@ static const struct arm64_mem_crypt_ops pkvm_crypt_ops = { static int mmio_guard_ioremap_hook(phys_addr_t phys, size_t size, pgprot_t *prot) { - phys_addr_t end; pteval_t protval = pgprot_val(*prot); + u32 func_id = pkvm_func_range ? + ARM_SMCCC_VENDOR_HYP_KVM_MMIO_RGUARD_MAP_FUNC_ID : + ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_MAP_FUNC_ID; /* * We only expect MMIO emulation for regions mapped with device @@ -83,15 +104,8 @@ static int mmio_guard_ioremap_hook(phys_addr_t phys, size_t size, if (protval != PROT_DEVICE_nGnRE && protval != PROT_DEVICE_nGnRnE) return 0; - phys = PAGE_ALIGN_DOWN(phys); - end = phys + PAGE_ALIGN(size); - - while (phys < end) { - const int func_id = ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_MAP_FUNC_ID; - - WARN_ON_ONCE(arm_smccc_do_one_page(func_id, phys)); - phys += PAGE_SIZE; - } + WARN_ON_ONCE(arm_smccc_do_range(func_id, phys, DIV_ROUND_UP(size, PAGE_SIZE), + pkvm_func_range)); return 0; } @@ -137,6 +151,7 @@ void pkvm_init_hyp_services(void) return; pkvm_granule = res.a0; + pkvm_func_range = !!res.a1; if (kvm_arm_hyp_service_available(ARM_SMCCC_KVM_FUNC_MEM_SHARE) && kvm_arm_hyp_service_available(ARM_SMCCC_KVM_FUNC_MEM_UNSHARE))