diff --git a/kernel/bpf/lpm_trie.c b/kernel/bpf/lpm_trie.c index 584e02227671bc6d26423a07338fabd556f761d0..d7ea9621851650da7cc2ec51fb8b550a57081353 100644 --- a/kernel/bpf/lpm_trie.c +++ b/kernel/bpf/lpm_trie.c @@ -591,9 +591,100 @@ static void trie_free(struct bpf_map *map) raw_spin_unlock(&trie->lock); } -static int trie_get_next_key(struct bpf_map *map, void *key, void *next_key) +static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key) { - return -ENOTSUPP; + struct lpm_trie *trie = container_of(map, struct lpm_trie, map); + struct bpf_lpm_trie_key *key = _key, *next_key = _next_key; + struct lpm_trie_node *node, *next_node = NULL, *parent; + struct lpm_trie_node **node_stack = NULL; + struct lpm_trie_node __rcu **root; + int err = 0, stack_ptr = -1; + unsigned int next_bit; + size_t matchlen; + + /* The get_next_key follows postorder. For the 4 node example in + * the top of this file, the trie_get_next_key() returns the following + * one after another: + * 192.168.0.0/24 + * 192.168.1.0/24 + * 192.168.128.0/24 + * 192.168.0.0/16 + * + * The idea is to return more specific keys before less specific ones. + */ + + /* Empty trie */ + if (!rcu_dereference(trie->root)) + return -ENOENT; + + /* For invalid key, find the leftmost node in the trie */ + if (!key || key->prefixlen > trie->max_prefixlen) { + root = &trie->root; + goto find_leftmost; + } + + node_stack = kmalloc(trie->max_prefixlen * sizeof(struct lpm_trie_node *), + GFP_USER | __GFP_NOWARN); + if (!node_stack) + return -ENOMEM; + + /* Try to find the exact node for the given key */ + for (node = rcu_dereference(trie->root); node;) { + node_stack[++stack_ptr] = node; + matchlen = longest_prefix_match(trie, node, key); + if (node->prefixlen != matchlen || + node->prefixlen == key->prefixlen) + break; + + next_bit = extract_bit(key->data, node->prefixlen); + node = rcu_dereference(node->child[next_bit]); + } + if (!node || node->prefixlen != key->prefixlen || + (node->flags & LPM_TREE_NODE_FLAG_IM)) { + root = &trie->root; + goto find_leftmost; + } + + /* The node with the exactly-matching key has been found, + * find the first node in postorder after the matched node. + */ + node = node_stack[stack_ptr]; + while (stack_ptr > 0) { + parent = node_stack[stack_ptr - 1]; + if (rcu_dereference(parent->child[0]) == node && + rcu_dereference(parent->child[1])) { + root = &parent->child[1]; + goto find_leftmost; + } + if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) { + next_node = parent; + goto do_copy; + } + + node = parent; + stack_ptr--; + } + + /* did not find anything */ + err = -ENOENT; + goto free_stack; + +find_leftmost: + /* Find the leftmost non-intermediate node, all intermediate nodes + * have exact two children, so this function will never return NULL. + */ + for (node = rcu_dereference(*root); node;) { + if (!(node->flags & LPM_TREE_NODE_FLAG_IM)) + next_node = node; + node = rcu_dereference(node->child[0]); + } +do_copy: + next_key->prefixlen = next_node->prefixlen; + memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key, data), + next_node->data, trie->data_size); +free_stack: + kfree(node_stack); + return err; } const struct bpf_map_ops trie_map_ops = { diff --git a/tools/testing/selftests/bpf/test_lpm_map.c b/tools/testing/selftests/bpf/test_lpm_map.c index f61480641b6e2a690760ea841db8843c5c87f08d..081510853c6d10a805f0c6e2c3b94be9b8036a63 100644 --- a/tools/testing/selftests/bpf/test_lpm_map.c +++ b/tools/testing/selftests/bpf/test_lpm_map.c @@ -521,6 +521,126 @@ static void test_lpm_delete(void) close(map_fd); } +static void test_lpm_get_next_key(void) +{ + struct bpf_lpm_trie_key *key_p, *next_key_p; + size_t key_size; + __u32 value = 0; + int map_fd; + + key_size = sizeof(*key_p) + sizeof(__u32); + key_p = alloca(key_size); + next_key_p = alloca(key_size); + + map_fd = bpf_create_map(BPF_MAP_TYPE_LPM_TRIE, key_size, sizeof(value), + 100, BPF_F_NO_PREALLOC); + assert(map_fd >= 0); + + /* empty tree. get_next_key should return ENOENT */ + assert(bpf_map_get_next_key(map_fd, NULL, key_p) == -1 && + errno == ENOENT); + + /* get and verify the first key, get the second one should fail. */ + key_p->prefixlen = 16; + inet_pton(AF_INET, "192.168.0.0", key_p->data); + assert(bpf_map_update_elem(map_fd, key_p, &value, 0) == 0); + + memset(key_p, 0, key_size); + assert(bpf_map_get_next_key(map_fd, NULL, key_p) == 0); + assert(key_p->prefixlen == 16 && key_p->data[0] == 192 && + key_p->data[1] == 168); + + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == -1 && + errno == ENOENT); + + /* no exact matching key should get the first one in post order. */ + key_p->prefixlen = 8; + assert(bpf_map_get_next_key(map_fd, NULL, key_p) == 0); + assert(key_p->prefixlen == 16 && key_p->data[0] == 192 && + key_p->data[1] == 168); + + /* add one more element (total two) */ + key_p->prefixlen = 24; + inet_pton(AF_INET, "192.168.0.0", key_p->data); + assert(bpf_map_update_elem(map_fd, key_p, &value, 0) == 0); + + memset(key_p, 0, key_size); + assert(bpf_map_get_next_key(map_fd, NULL, key_p) == 0); + assert(key_p->prefixlen == 24 && key_p->data[0] == 192 && + key_p->data[1] == 168 && key_p->data[2] == 0); + + memset(next_key_p, 0, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == 0); + assert(next_key_p->prefixlen == 16 && next_key_p->data[0] == 192 && + next_key_p->data[1] == 168); + + memcpy(key_p, next_key_p, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == -1 && + errno == ENOENT); + + /* Add one more element (total three) */ + key_p->prefixlen = 24; + inet_pton(AF_INET, "192.168.128.0", key_p->data); + assert(bpf_map_update_elem(map_fd, key_p, &value, 0) == 0); + + memset(key_p, 0, key_size); + assert(bpf_map_get_next_key(map_fd, NULL, key_p) == 0); + assert(key_p->prefixlen == 24 && key_p->data[0] == 192 && + key_p->data[1] == 168 && key_p->data[2] == 0); + + memset(next_key_p, 0, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == 0); + assert(next_key_p->prefixlen == 24 && next_key_p->data[0] == 192 && + next_key_p->data[1] == 168 && next_key_p->data[2] == 128); + + memcpy(key_p, next_key_p, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == 0); + assert(next_key_p->prefixlen == 16 && next_key_p->data[0] == 192 && + next_key_p->data[1] == 168); + + memcpy(key_p, next_key_p, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == -1 && + errno == ENOENT); + + /* Add one more element (total four) */ + key_p->prefixlen = 24; + inet_pton(AF_INET, "192.168.1.0", key_p->data); + assert(bpf_map_update_elem(map_fd, key_p, &value, 0) == 0); + + memset(key_p, 0, key_size); + assert(bpf_map_get_next_key(map_fd, NULL, key_p) == 0); + assert(key_p->prefixlen == 24 && key_p->data[0] == 192 && + key_p->data[1] == 168 && key_p->data[2] == 0); + + memset(next_key_p, 0, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == 0); + assert(next_key_p->prefixlen == 24 && next_key_p->data[0] == 192 && + next_key_p->data[1] == 168 && next_key_p->data[2] == 1); + + memcpy(key_p, next_key_p, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == 0); + assert(next_key_p->prefixlen == 24 && next_key_p->data[0] == 192 && + next_key_p->data[1] == 168 && next_key_p->data[2] == 128); + + memcpy(key_p, next_key_p, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == 0); + assert(next_key_p->prefixlen == 16 && next_key_p->data[0] == 192 && + next_key_p->data[1] == 168); + + memcpy(key_p, next_key_p, key_size); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == -1 && + errno == ENOENT); + + /* no exact matching key should return the first one in post order */ + key_p->prefixlen = 22; + inet_pton(AF_INET, "192.168.1.0", key_p->data); + assert(bpf_map_get_next_key(map_fd, key_p, next_key_p) == 0); + assert(next_key_p->prefixlen == 24 && next_key_p->data[0] == 192 && + next_key_p->data[1] == 168 && next_key_p->data[2] == 0); + + close(map_fd); +} + int main(void) { struct rlimit limit = { RLIM_INFINITY, RLIM_INFINITY }; @@ -545,6 +665,8 @@ int main(void) test_lpm_delete(); + test_lpm_get_next_key(); + printf("test_lpm: OK\n"); return 0; }