diff --git a/arch/arm64/include/asm/kvm_host.h b/arch/arm64/include/asm/kvm_host.h index 94887579722d..76ad0f2f54d6 100644 --- a/arch/arm64/include/asm/kvm_host.h +++ b/arch/arm64/include/asm/kvm_host.h @@ -272,15 +272,24 @@ struct kvm_pinned_page { struct rb_node node; struct page *page; u64 ipa; + u64 __subtree_last; bool dirty; + u8 order; }; +struct kvm_pinned_page +*kvm_pinned_pages_iter_first(struct rb_root_cached *root, u64 start, u64 end); +struct kvm_pinned_page +*kvm_pinned_pages_iter_next(struct kvm_pinned_page *ppage, u64 start, u64 end); +void kvm_pinned_pages_remove(struct kvm_pinned_page *ppage, + struct rb_root_cached *root); + typedef unsigned int pkvm_handle_t; struct kvm_protected_vm { pkvm_handle_t handle; struct kvm_hyp_memcache stage2_teardown_mc; - struct rb_root pinned_pages; + struct rb_root_cached pinned_pages; gpa_t pvmfw_load_addr; bool enabled; }; diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c index 2e04ae29aa5a..05997453905b 100644 --- a/arch/arm64/kvm/mmu.c +++ b/arch/arm64/kvm/mmu.c @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -291,6 +292,20 @@ static void invalidate_icache_guest_page(void *va, size_t size) __invalidate_icache_guest_page(va, size); } +static u64 __pinned_page_start(struct kvm_pinned_page *ppage) +{ + return ppage->ipa; +} + +static u64 __pinned_page_end(struct kvm_pinned_page *ppage) +{ + return ppage->ipa + (1 << (ppage->order + PAGE_SHIFT)) - 1; +} + +INTERVAL_TREE_DEFINE(struct kvm_pinned_page, node, u64, __subtree_last, + __pinned_page_start, __pinned_page_end, /* emtpy */, + kvm_pinned_pages); + static int pkvm_unmap_guest(struct kvm *kvm, struct kvm_pinned_page *ppage) { struct mm_struct *mm = kvm->mm; @@ -303,7 +318,7 @@ static int pkvm_unmap_guest(struct kvm *kvm, struct kvm_pinned_page *ppage) return ret; unpin_user_pages_dirty_lock(&ppage->page, 1, ppage->dirty); - rb_erase(&ppage->node, &kvm->arch.pkvm.pinned_pages); + kvm_pinned_pages_remove(ppage, &kvm->arch.pkvm.pinned_pages); kfree(ppage); /* @@ -318,40 +333,17 @@ static int pkvm_unmap_guest(struct kvm *kvm, struct kvm_pinned_page *ppage) return 0; } -static struct rb_node *find_first_ppage_node(struct rb_root *root, u64 ipa) -{ - struct rb_node *node = root->rb_node, *prev = NULL; - struct kvm_pinned_page *ppage; - - while (node) { - ppage = rb_entry(node, struct kvm_pinned_page, node); - if (ppage->ipa == ipa) - return node; - prev = node; - node = (ipa < ppage->ipa) ? node->rb_left : node->rb_right; - } - - return prev; -} - -#define for_ppage_node_in_range(kvm, start, end, __node, __tmp) \ - for (__node = find_first_ppage_node(&(kvm)->arch.pkvm.pinned_pages, start); \ - __node && ({ __tmp = rb_next(__node); 1; }); \ - __node = __tmp) \ - if (rb_entry(__node, struct kvm_pinned_page, node)->ipa < start) \ - continue; \ - else if (rb_entry(__node, struct kvm_pinned_page, node)->ipa >= end) \ - break; \ - else +#define for_ppage_node_in_range(kvm, start, end, __ppage, __tmp) \ + for (__ppage = kvm_pinned_pages_iter_first(&(kvm)->arch.pkvm.pinned_pages, start, end); \ + __ppage && ({ __tmp = kvm_pinned_pages_iter_next(__ppage, start, end); 1; }); \ + __ppage = __tmp) static int pkvm_unmap_range(struct kvm *kvm, u64 start, u64 end) { - struct kvm_pinned_page *ppage; - struct rb_node *node, *tmp; + struct kvm_pinned_page *tmp, *ppage; int ret; - for_ppage_node_in_range(kvm, start, end, node, tmp) { - ppage = rb_entry(node, struct kvm_pinned_page, node); + for_ppage_node_in_range(kvm, start, end, ppage, tmp) { ret = pkvm_unmap_guest(kvm, ppage); if (ret) return ret; @@ -429,7 +421,6 @@ void kvm_stage2_unmap_range(struct kvm_s2_mmu *mmu, phys_addr_t start, static void pkvm_stage2_flush(struct kvm *kvm) { struct kvm_pinned_page *ppage; - struct rb_node *node; /* * Contrary to stage2_apply_range(), we don't need to check @@ -437,8 +428,8 @@ static void pkvm_stage2_flush(struct kvm *kvm) * from a vcpu thread, and the list is only ever freed on VM * destroy (which only occurs when all vcpu are gone). */ - for (node = rb_first(&kvm->arch.pkvm.pinned_pages); node; node = rb_next(node)) { - ppage = rb_entry(node, struct kvm_pinned_page, node); + for (ppage = kvm_pinned_pages_iter_first(&kvm->arch.pkvm.pinned_pages, 0, ~(0UL)); + ppage; ppage = kvm_pinned_pages_iter_next(ppage, 0, ~(0UL))) { __clean_dcache_guest_page(page_address(ppage->page), PAGE_SIZE); cond_resched_rwlock_write(&kvm->mmu_lock); } @@ -1041,7 +1032,7 @@ int kvm_init_stage2_mmu(struct kvm *kvm, struct kvm_s2_mmu *mmu, unsigned long t int cpu, err; struct kvm_pgtable *pgt; - kvm->arch.pkvm.pinned_pages = RB_ROOT; + kvm->arch.pkvm.pinned_pages = RB_ROOT_CACHED; mmu->arch = &kvm->arch; /* @@ -1330,12 +1321,10 @@ int kvm_phys_addr_ioremap(struct kvm *kvm, phys_addr_t guest_ipa, static int pkvm_wp_range(struct kvm *kvm, u64 start, u64 end) { - struct kvm_pinned_page *ppage; - struct rb_node *node, *tmp; + struct kvm_pinned_page *tmp, *ppage; int ret; - for_ppage_node_in_range(kvm, start, end, node, tmp) { - ppage = rb_entry(node, struct kvm_pinned_page, node); + for_ppage_node_in_range(kvm, start, end, ppage, tmp) { ret = kvm_call_hyp_nvhe(__pkvm_host_wrprotect_guest, kvm->arch.pkvm.handle, ppage->ipa >> PAGE_SHIFT); @@ -1643,66 +1632,22 @@ static int pkvm_host_map_guest(u64 pfn, u64 gfn, enum kvm_pgtable_prot prot) return (ret == -EPERM) ? -EAGAIN : ret; } -#define node_ppage(__node) \ - (container_of(__node, struct kvm_pinned_page, node)) - -static int cmp_ppage_ipa(const void *key, const struct rb_node *node) +static struct kvm_pinned_page *find_ppage(struct kvm *kvm, u64 ipa) { - s64 a_ipa = (s64)key; - s64 b_ipa = (s64)node_ppage(node)->ipa; - - if (a_ipa < b_ipa) - return -1; - if (a_ipa > b_ipa) - return 1; - return 0; -} - -static int cmp_ppages(struct rb_node *node, const struct rb_node *parent) -{ - return cmp_ppage_ipa((void *)(node_ppage(node))->ipa, parent); -} - -static struct kvm_pinned_page * -find_ppage_or_above(struct kvm *kvm, phys_addr_t ipa) -{ - struct rb_node *node = kvm->arch.pkvm.pinned_pages.rb_node; - - while (node) { - int ret = cmp_ppage_ipa((void *)ipa, node); - - if (!ret) { - break; - } else if (ret > 0) { - node = node->rb_right; - } else if (ret < 0) { - if (!node->rb_left) - break; - node = node->rb_left; - } - } - - if (!node) - return NULL; - - return node_ppage(node); + return kvm_pinned_pages_iter_first(&kvm->arch.pkvm.pinned_pages, + ipa, ipa + PAGE_SIZE - 1); } static int insert_ppage(struct kvm *kvm, struct kvm_pinned_page *ppage) { - if (rb_find_add(&ppage->node, &kvm->arch.pkvm.pinned_pages, cmp_ppages)) + if (find_ppage(kvm, ppage->ipa)) return -EEXIST; + kvm_pinned_pages_insert(ppage, &kvm->arch.pkvm.pinned_pages); + return 0; } -static struct kvm_pinned_page *find_ppage(struct kvm *kvm, u64 ipa) -{ - struct rb_node *node = rb_find((void *)ipa, &kvm->arch.pkvm.pinned_pages, cmp_ppage_ipa); - - return node ? container_of(node, struct kvm_pinned_page, node) : NULL; -} - static int pkvm_relax_perms(struct kvm *kvm, u64 pfn, u64 gfn, enum kvm_pgtable_prot prot) { @@ -1807,6 +1752,7 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, ppage->page = page; ppage->ipa = fault_ipa; + ppage->order = 0; ppage->dirty = kvm->arch.pkvm.enabled; WARN_ON(insert_ppage(kvm, ppage)); write_unlock(&kvm->mmu_lock); @@ -1842,11 +1788,11 @@ int pkvm_mem_abort_range(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, size_t si idx = srcu_read_lock(&vcpu->kvm->srcu); read_lock(&vcpu->kvm->mmu_lock); - ppage = find_ppage_or_above(vcpu->kvm, fault_ipa); - + ppage = kvm_pinned_pages_iter_first(&vcpu->kvm->arch.pkvm.pinned_pages, + fault_ipa, ipa_end); while (fault_ipa < ipa_end) { if (ppage && ppage->ipa == fault_ipa) { - ppage = node_ppage(rb_next(&ppage->node)); + ppage = kvm_pinned_pages_iter_next(ppage, fault_ipa, ipa_end); } else { gfn_t gfn = gpa_to_gfn(fault_ipa); struct kvm_memory_slot *memslot; @@ -1870,7 +1816,8 @@ int pkvm_mem_abort_range(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, size_t si * We had to release the mmu_lock so let's update the * reference. */ - ppage = find_ppage_or_above(vcpu->kvm, fault_ipa + PAGE_SIZE); + ppage = kvm_pinned_pages_iter_first(&vcpu->kvm->arch.pkvm.pinned_pages, + fault_ipa + PAGE_SIZE, ipa_end); } fault_ipa += PAGE_SIZE; diff --git a/arch/arm64/kvm/pkvm.c b/arch/arm64/kvm/pkvm.c index af896518d1c2..8e64e1303a9c 100644 --- a/arch/arm64/kvm/pkvm.c +++ b/arch/arm64/kvm/pkvm.c @@ -40,14 +40,6 @@ static unsigned int *hyp_memblock_nr_ptr = &kvm_nvhe_sym(hyp_memblock_nr); phys_addr_t hyp_mem_base; phys_addr_t hyp_mem_size; -static int rb_ppage_cmp(const void *key, const struct rb_node *node) -{ - struct kvm_pinned_page *p = container_of(node, struct kvm_pinned_page, node); - phys_addr_t ipa = (phys_addr_t)key; - - return (ipa < p->ipa) ? -1 : (ipa > p->ipa); -} - static int cmp_hyp_memblock(const void *p1, const void *p2) { const struct memblock_region *r1 = p1; @@ -246,7 +238,6 @@ static void __pkvm_destroy_hyp_vm(struct kvm *host_kvm) struct mm_struct *mm = current->mm; struct kvm_pinned_page *ppage; struct kvm_vcpu *host_vcpu; - struct rb_node *node; unsigned long pages = 0; unsigned long idx; @@ -255,9 +246,10 @@ static void __pkvm_destroy_hyp_vm(struct kvm *host_kvm) WARN_ON(kvm_call_hyp_nvhe(__pkvm_start_teardown_vm, host_kvm->arch.pkvm.handle)); - node = rb_first(&host_kvm->arch.pkvm.pinned_pages); - while (node) { - ppage = rb_entry(node, struct kvm_pinned_page, node); + ppage = kvm_pinned_pages_iter_first(&host_kvm->arch.pkvm.pinned_pages, 0, ~(0UL)); + while (ppage) { + struct kvm_pinned_page *next; + WARN_ON(kvm_call_hyp_nvhe(__pkvm_reclaim_dying_guest_page, host_kvm->arch.pkvm.handle, page_to_pfn(ppage->page), @@ -265,10 +257,11 @@ static void __pkvm_destroy_hyp_vm(struct kvm *host_kvm) cond_resched(); unpin_user_pages_dirty_lock(&ppage->page, 1, ppage->dirty); - node = rb_next(node); - rb_erase(&ppage->node, &host_kvm->arch.pkvm.pinned_pages); + next = kvm_pinned_pages_iter_next(ppage, 0, ~(0UL)); + kvm_pinned_pages_remove(ppage, &host_kvm->arch.pkvm.pinned_pages); kfree(ppage); pages++; + ppage = next; } account_locked_vm(mm, pages, false); @@ -472,20 +465,18 @@ void pkvm_host_reclaim_page(struct kvm *host_kvm, phys_addr_t ipa) { struct kvm_pinned_page *ppage; struct mm_struct *mm = current->mm; - struct rb_node *node; write_lock(&host_kvm->mmu_lock); - node = rb_find((void *)ipa, &host_kvm->arch.pkvm.pinned_pages, - rb_ppage_cmp); - if (node) - rb_erase(node, &host_kvm->arch.pkvm.pinned_pages); + ppage = kvm_pinned_pages_iter_first(&host_kvm->arch.pkvm.pinned_pages, + ipa, ipa + PAGE_SIZE - 1); + if (ppage) + kvm_pinned_pages_remove(ppage, &host_kvm->arch.pkvm.pinned_pages); write_unlock(&host_kvm->mmu_lock); - WARN_ON(!node); - if (!node) + WARN_ON(!ppage); + if (!ppage) return; - ppage = container_of(node, struct kvm_pinned_page, node); account_locked_vm(mm, 1, false); unpin_user_pages_dirty_lock(&ppage->page, 1, true); kfree(ppage);