diff --git a/arch/arm64/kvm/hyp/nvhe/mem_protect.c b/arch/arm64/kvm/hyp/nvhe/mem_protect.c index 745c52300a95..0ed9140e999b 100644 --- a/arch/arm64/kvm/hyp/nvhe/mem_protect.c +++ b/arch/arm64/kvm/hyp/nvhe/mem_protect.c @@ -984,10 +984,14 @@ static enum pkvm_page_state host_get_mmio_page_state(kvm_pte_t pte, u64 addr) return state | pkvm_getstate(prot); } +enum host_check_page_state_flags { + HOST_CHECK_NULL_REFCNT = BIT(0), + HOST_CHECK_IS_MEMORY = BIT(1), +}; + static int ___host_check_page_state_range(u64 addr, u64 size, enum pkvm_page_state state, - struct memblock_region *reg, - bool check_null_refcount) + enum host_check_page_state_flags flags) { struct check_walk_data d = { .desired = state, @@ -995,6 +999,16 @@ static int ___host_check_page_state_range(u64 addr, u64 size, }; u64 end = addr + size; struct hyp_page *p; + struct memblock_region *reg; + struct kvm_mem_range range; + + /* Can't check the state of both MMIO and memory regions at once */ + reg = find_mem_range(addr, &range); + if (!reg && (flags & HOST_CHECK_IS_MEMORY)) + return -EINVAL; + + if (!is_in_mem_range(end - 1, &range)) + return -EINVAL; hyp_assert_lock_held(&host_mmu.lock); @@ -1009,7 +1023,7 @@ static int ___host_check_page_state_range(u64 addr, u64 size, p = hyp_phys_to_page(addr); if (p->host_state != state) return -EPERM; - if (check_null_refcount && hyp_refcount_get(p->refcount)) + if ((flags & HOST_CHECK_NULL_REFCNT) && hyp_refcount_get(p->refcount)) return -EINVAL; } @@ -1025,17 +1039,13 @@ static int ___host_check_page_state_range(u64 addr, u64 size, static int __host_check_page_state_range(u64 addr, u64 size, enum pkvm_page_state state) { - struct memblock_region *reg; - struct kvm_mem_range range; - u64 end = addr + size; + enum host_check_page_state_flags flags = HOST_CHECK_IS_MEMORY; - /* Can't check the state of both MMIO and memory regions at once */ - reg = find_mem_range(addr, &range); - if (!is_in_mem_range(end - 1, &range)) - return -EINVAL; + if (state == PKVM_PAGE_OWNED) + flags |= HOST_CHECK_NULL_REFCNT; /* Check the refcount of PAGE_OWNED pages as those may be used for DMA. */ - return ___host_check_page_state_range(addr, size, state, reg, state == PKVM_PAGE_OWNED); + return ___host_check_page_state_range(addr, size, state, flags); } static int __host_set_page_state_range(u64 addr, u64 size, @@ -1578,7 +1588,7 @@ int __pkvm_host_donate_hyp_locked(u64 pfn, u64 nr_pages, enum kvm_pgtable_prot p hyp_lock_component(); - ret = __host_check_page_state_range(phys, size, PKVM_PAGE_OWNED); + ret = ___host_check_page_state_range(phys, size, PKVM_PAGE_OWNED, HOST_CHECK_NULL_REFCNT); if (ret) goto unlock; if (IS_ENABLED(CONFIG_PKVM_STRICT_CHECKS)) { @@ -1619,7 +1629,7 @@ int __pkvm_hyp_donate_host(u64 pfn, u64 nr_pages) if (ret) goto unlock; if (IS_ENABLED(CONFIG_PKVM_STRICT_CHECKS)) { - ret = __host_check_page_state_range(phys, size, PKVM_NOPAGE); + ret = ___host_check_page_state_range(phys, size, PKVM_NOPAGE, 0); if (ret) goto unlock; } @@ -1636,18 +1646,12 @@ unlock: int __pkvm_host_donate_ffa(u64 pfn, u64 nr_pages) { u64 size, phys = hyp_pfn_to_phys(pfn), end; - struct kvm_mem_range range; - struct memblock_region *reg; int ret; if (check_shl_overflow(nr_pages, PAGE_SHIFT, &size) || check_add_overflow(phys, size, &end)) return -EINVAL; - reg = find_mem_range(phys, &range); - if (!reg || !is_in_mem_range(end - 1, &range)) - return -EPERM; - host_lock_component(); ret = __host_check_page_state_range(phys, size, PKVM_PAGE_OWNED); @@ -1663,18 +1667,12 @@ unlock: int __pkvm_host_reclaim_ffa(u64 pfn, u64 nr_pages) { u64 size, phys = hyp_pfn_to_phys(pfn), end; - struct memblock_region *reg; - struct kvm_mem_range range; int ret; if (check_shl_overflow(nr_pages, PAGE_SHIFT, &size) || check_add_overflow(phys, size, &end)) return -EINVAL; - reg = find_mem_range(phys, &range); - if (!reg || !is_in_mem_range(end - 1, &range)) - return -EPERM; - host_lock_component(); ret = __host_check_page_state_range(phys, size, PKVM_NOPAGE); @@ -1741,8 +1739,8 @@ int module_change_host_page_prot(u64 pfn, enum kvm_pgtable_prot prot, u64 nr_pag } } else { /* The entire range must be pristine. */ - ret = ___host_check_page_state_range( - addr, nr_pages << PAGE_SHIFT, PKVM_PAGE_OWNED, reg, true); + ret = ___host_check_page_state_range(addr, nr_pages << PAGE_SHIFT, + PKVM_PAGE_OWNED, HOST_CHECK_NULL_REFCNT); if (ret) goto unlock; } @@ -1903,13 +1901,13 @@ static int __pkvm_use_dma_locked(phys_addr_t phys_addr, size_t size, ret = ___host_check_page_state_range(addr, PAGE_SIZE, PKVM_PAGE_TAINTED, - reg, false); + 0); /* Page already tainted */ if (!ret) continue; ret = ___host_check_page_state_range(addr, PAGE_SIZE, PKVM_PAGE_OWNED, - reg, false); + 0); if (ret) return ret; } @@ -2273,7 +2271,7 @@ int __pkvm_host_donate_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu, u64 n host_lock_component(); guest_lock_component(vm); - ret = __host_check_page_state_range(phys, size, PKVM_PAGE_OWNED); + ret = ___host_check_page_state_range(phys, size, PKVM_PAGE_OWNED, HOST_CHECK_NULL_REFCNT); if (ret) goto unlock; ret = __guest_check_page_state_range(vcpu, ipa, size, PKVM_NOPAGE); @@ -2357,7 +2355,8 @@ int __pkvm_host_donate_sglist_guest(struct pkvm_hyp_vcpu *vcpu) goto unlock; } - ret = __host_check_page_state_range(phys, size, PKVM_PAGE_OWNED); + ret = ___host_check_page_state_range(phys, size, PKVM_PAGE_OWNED, + HOST_CHECK_NULL_REFCNT); if (ret) goto unlock; @@ -2424,7 +2423,8 @@ int __pkvm_host_donate_sglist_hyp(struct pkvm_sglist_page *sglist, size_t nr_pag goto unlock; } - ret = __host_check_page_state_range(phys, size, PKVM_PAGE_OWNED); + ret = ___host_check_page_state_range(phys, size, PKVM_PAGE_OWNED, + HOST_CHECK_NULL_REFCNT); if (ret) goto unlock; @@ -2536,7 +2536,8 @@ int __pkvm_host_reclaim_page(struct pkvm_hyp_vm *vm, u64 pfn, u64 ipa, u8 order) switch ((int)guest_get_page_state(pte, ipa)) { case PKVM_PAGE_OWNED: - WARN_ON(__host_check_page_state_range(phys, page_size, PKVM_NOPAGE)); + WARN_ON(___host_check_page_state_range(phys, page_size, PKVM_NOPAGE, + HOST_CHECK_IS_MEMORY)); /* No vCPUs of the guest can run, doing this prior to stage-2 unmap is OK */ hyp_poison_page(phys, page_size); psci_mem_protect_dec(1 << order);