diff --git a/include/net/xfrm.h b/include/net/xfrm.h index 865ecb643fa4..cd4ff247e079 100644 --- a/include/net/xfrm.h +++ b/include/net/xfrm.h @@ -233,6 +233,7 @@ struct xfrm_state { /* Data for encapsulator */ struct xfrm_encap_tmpl *encap; + struct sock __rcu *encap_sk; /* NAT keepalive */ u32 nat_keepalive_interval; /* seconds */ diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c index cbe4c6fc8b8e..f0a7f06df3ad 100644 --- a/net/ipv4/esp4.c +++ b/net/ipv4/esp4.c @@ -120,16 +120,47 @@ static void esp_ssg_unref(struct xfrm_state *x, void *tmp, struct sk_buff *skb) } #ifdef CONFIG_INET_ESPINTCP +struct esp_tcp_sk { + struct sock *sk; + struct rcu_head rcu; +}; + +static void esp_free_tcp_sk(struct rcu_head *head) +{ + struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu); + + sock_put(esk->sk); + kfree(esk); +} + static struct sock *esp_find_tcp_sk(struct xfrm_state *x) { struct xfrm_encap_tmpl *encap = x->encap; struct net *net = xs_net(x); + struct esp_tcp_sk *esk; __be16 sport, dport; + struct sock *nsk; struct sock *sk; + sk = rcu_dereference(x->encap_sk); + if (sk && sk->sk_state == TCP_ESTABLISHED) + return sk; + spin_lock_bh(&x->lock); sport = encap->encap_sport; dport = encap->encap_dport; + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (sk && sk == nsk) { + esk = kmalloc(sizeof(*esk), GFP_ATOMIC); + if (!esk) { + spin_unlock_bh(&x->lock); + return ERR_PTR(-ENOMEM); + } + RCU_INIT_POINTER(x->encap_sk, NULL); + esk->sk = sk; + call_rcu(&esk->rcu, esp_free_tcp_sk); + } spin_unlock_bh(&x->lock); sk = inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, x->id.daddr.a4, @@ -142,6 +173,20 @@ static struct sock *esp_find_tcp_sk(struct xfrm_state *x) return ERR_PTR(-EINVAL); } + spin_lock_bh(&x->lock); + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (encap->encap_sport != sport || + encap->encap_dport != dport) { + sock_put(sk); + sk = nsk ?: ERR_PTR(-EREMCHG); + } else if (sk == nsk) { + sock_put(sk); + } else { + rcu_assign_pointer(x->encap_sk, sk); + } + spin_unlock_bh(&x->lock); + return sk; } @@ -166,8 +211,6 @@ static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb) err = espintcp_push_skb(sk, skb); bh_unlock_sock(sk); - sock_put(sk); - out: rcu_read_unlock(); return err; @@ -351,8 +394,6 @@ static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x, if (IS_ERR(sk)) return ERR_CAST(sk); - sock_put(sk); - *lenp = htons(len); esph = (struct ip_esp_hdr *)(lenp + 1); diff --git a/net/ipv6/esp6.c b/net/ipv6/esp6.c index 62d17d7f6d9a..3810cfbc4410 100644 --- a/net/ipv6/esp6.c +++ b/net/ipv6/esp6.c @@ -137,16 +137,47 @@ static void esp_ssg_unref(struct xfrm_state *x, void *tmp, struct sk_buff *skb) } #ifdef CONFIG_INET6_ESPINTCP +struct esp_tcp_sk { + struct sock *sk; + struct rcu_head rcu; +}; + +static void esp_free_tcp_sk(struct rcu_head *head) +{ + struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu); + + sock_put(esk->sk); + kfree(esk); +} + static struct sock *esp6_find_tcp_sk(struct xfrm_state *x) { struct xfrm_encap_tmpl *encap = x->encap; struct net *net = xs_net(x); + struct esp_tcp_sk *esk; __be16 sport, dport; + struct sock *nsk; struct sock *sk; + sk = rcu_dereference(x->encap_sk); + if (sk && sk->sk_state == TCP_ESTABLISHED) + return sk; + spin_lock_bh(&x->lock); sport = encap->encap_sport; dport = encap->encap_dport; + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (sk && sk == nsk) { + esk = kmalloc(sizeof(*esk), GFP_ATOMIC); + if (!esk) { + spin_unlock_bh(&x->lock); + return ERR_PTR(-ENOMEM); + } + RCU_INIT_POINTER(x->encap_sk, NULL); + esk->sk = sk; + call_rcu(&esk->rcu, esp_free_tcp_sk); + } spin_unlock_bh(&x->lock); sk = __inet6_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, &x->id.daddr.in6, @@ -159,6 +190,20 @@ static struct sock *esp6_find_tcp_sk(struct xfrm_state *x) return ERR_PTR(-EINVAL); } + spin_lock_bh(&x->lock); + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (encap->encap_sport != sport || + encap->encap_dport != dport) { + sock_put(sk); + sk = nsk ?: ERR_PTR(-EREMCHG); + } else if (sk == nsk) { + sock_put(sk); + } else { + rcu_assign_pointer(x->encap_sk, sk); + } + spin_unlock_bh(&x->lock); + return sk; } @@ -183,8 +228,6 @@ static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb) err = espintcp_push_skb(sk, skb); bh_unlock_sock(sk); - sock_put(sk); - out: rcu_read_unlock(); return err; @@ -381,8 +424,6 @@ static struct ip_esp_hdr *esp6_output_tcp_encap(struct xfrm_state *x, if (IS_ERR(sk)) return ERR_CAST(sk); - sock_put(sk); - *lenp = htons(len); esph = (struct ip_esp_hdr *)(lenp + 1); diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index 2f4c348fe6c1..b839cfe31296 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -773,6 +773,9 @@ int __xfrm_state_delete(struct xfrm_state *x) xfrm_nat_keepalive_state_updated(x); spin_unlock(&net->xfrm.xfrm_state_lock); + if (x->encap_sk) + sock_put(rcu_dereference_raw(x->encap_sk)); + xfrm_dev_state_delete(x); /* All xfrm_state objects are created by xfrm_state_alloc.