/*
 * Resizable, Scalable, Concurrent Hash Table
 *
 * Copyright (c) 2014-2015 Thomas Graf <tgraf@suug.ch>
 * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

/**************************************************************************
 * Self Test
 **************************************************************************/

#include <linux/init.h>
#include <linux/jhash.h>
#include <linux/kernel.h>
19
#include <linux/kthread.h>
20 21 22
#include <linux/module.h>
#include <linux/rcupdate.h>
#include <linux/rhashtable.h>
23
#include <linux/semaphore.h>
24
#include <linux/slab.h>
25
#include <linux/sched.h>
26
#include <linux/vmalloc.h>
27

28
#define MAX_ENTRIES	1000000
29
#define TEST_INSERT_FAIL INT_MAX
30 31 32 33 34 35 36 37 38

static int entries = 50000;
module_param(entries, int, 0);
MODULE_PARM_DESC(entries, "Number of entries to add (default: 50000)");

static int runs = 4;
module_param(runs, int, 0);
MODULE_PARM_DESC(runs, "Number of test runs per variant (default: 4)");

39
static int max_size = 0;
40
module_param(max_size, int, 0);
41
MODULE_PARM_DESC(max_size, "Maximum table size (default: calculated)");
42 43 44 45 46 47 48 49

static bool shrinking = false;
module_param(shrinking, bool, 0);
MODULE_PARM_DESC(shrinking, "Enable automatic shrinking (default: off)");

static int size = 8;
module_param(size, int, 0);
MODULE_PARM_DESC(size, "Initial size hint of table (default: 8)");
50

51 52 53 54
static int tcount = 10;
module_param(tcount, int, 0);
MODULE_PARM_DESC(tcount, "Number of threads to spawn (default: 10)");

55 56 57 58
static bool enomem_retry = false;
module_param(enomem_retry, bool, 0);
MODULE_PARM_DESC(enomem_retry, "Retry insert even if -ENOMEM was returned (default: off)");

59 60 61 62 63
struct test_obj_val {
	int	id;
	int	tid;
};

64
struct test_obj {
65
	struct test_obj_val	value;
66 67 68
	struct rhash_head	node;
};

69 70 71 72 73 74
struct thread_data {
	int id;
	struct task_struct *task;
	struct test_obj *objs;
};

75 76
static struct test_obj array[MAX_ENTRIES];

77
static struct rhashtable_params test_rht_params = {
78 79
	.head_offset = offsetof(struct test_obj, node),
	.key_offset = offsetof(struct test_obj, value),
80
	.key_len = sizeof(struct test_obj_val),
81 82 83 84
	.hashfn = jhash,
	.nulls_base = (3U << RHT_BASE_SHIFT),
};

85 86 87
static struct semaphore prestart_sem;
static struct semaphore startup_sem = __SEMAPHORE_INITIALIZER(startup_sem, 0);

88 89 90
static int insert_retry(struct rhashtable *ht, struct rhash_head *obj,
                        const struct rhashtable_params params)
{
91
	int err, retries = -1, enomem_retries = 0;
92 93 94 95 96

	do {
		retries++;
		cond_resched();
		err = rhashtable_insert_fast(ht, obj, params);
97 98 99 100
		if (err == -ENOMEM && enomem_retry) {
			enomem_retries++;
			err = -EBUSY;
		}
101 102
	} while (err == -EBUSY);

103 104 105 106
	if (enomem_retries)
		pr_info(" %u insertions retried after -ENOMEM\n",
			enomem_retries);

107 108 109
	return err ? : retries;
}

110 111 112 113
/*
 * test_rht_lookup - verify table contents after the insertion phase
 *
 * Probes every key in [0, 2 * entries). An even key k was inserted from
 * array[k / 2] and must be found, unless that slot is marked
 * TEST_INSERT_FAIL; odd keys must be absent. A found object must carry
 * the requested id.
 *
 * Called with rcu_read_lock() held; cond_resched_rcu() briefly drops it
 * between keys.
 *
 * Returns 0, or -ENOENT / -EEXIST / -EINVAL on the first mismatch.
 */
static int __init test_rht_lookup(struct rhashtable *ht)
{
	unsigned int key_id;

	for (key_id = 0; key_id < entries * 2; key_id++) {
		struct test_obj_val key = { .id = key_id };
		struct test_obj *found;
		bool present = !(key_id % 2) &&
			       array[key_id / 2].value.id != TEST_INSERT_FAIL;

		found = rhashtable_lookup_fast(ht, &key, test_rht_params);

		if (present && !found) {
			pr_warn("Test failed: Could not find key %u\n", key.id);
			return -ENOENT;
		}
		if (!present && found) {
			pr_warn("Test failed: Unexpected entry found for key %u\n",
				key.id);
			return -EEXIST;
		}
		/* Here found != NULL implies present. */
		if (found && found->value.id != key_id) {
			pr_warn("Test failed: Lookup value mismatch %u!=%u\n",
				found->value.id, key_id);
			return -EINVAL;
		}

		cond_resched_rcu();
	}

	return 0;
}

147
static void test_bucket_stats(struct rhashtable *ht)
148
{
149 150
	unsigned int err, total = 0, chain_len = 0;
	struct rhashtable_iter hti;
151 152
	struct rhash_head *pos;

153
	err = rhashtable_walk_init(ht, &hti, GFP_KERNEL);
154 155 156 157
	if (err) {
		pr_warn("Test failed: allocation error");
		return;
	}
158

159 160 161 162 163
	err = rhashtable_walk_start(&hti);
	if (err && err != -EAGAIN) {
		pr_warn("Test failed: iterator failed: %d\n", err);
		return;
	}
164

165 166 167 168 169 170 171 172 173
	while ((pos = rhashtable_walk_next(&hti))) {
		if (PTR_ERR(pos) == -EAGAIN) {
			pr_info("Info: encountered resize\n");
			chain_len++;
			continue;
		} else if (IS_ERR(pos)) {
			pr_warn("Test failed: rhashtable_walk_next() error: %ld\n",
				PTR_ERR(pos));
			break;
174 175
		}

176
		total++;
177 178
	}

179 180 181 182 183
	rhashtable_walk_stop(&hti);
	rhashtable_walk_exit(&hti);

	pr_info("  Traversal complete: counted=%u, nelems=%u, entries=%d, table-jumps=%u\n",
		total, atomic_read(&ht->nelems), entries, chain_len);
184

185
	if (total != atomic_read(&ht->nelems) || total != entries)
186 187 188
		pr_warn("Test failed: Total count mismatch ^^^");
}

189
/*
 * test_rhashtable - single-threaded insert / verify / delete pass over @ht
 *
 * Inserts the first @entries objects of array[] with even keys
 * 0, 2, 4, ... It then checks the table with a walk and a lookup pass,
 * and finally deletes every inserted object again.
 *
 * Returns the elapsed time in nanoseconds. On failure it returns a
 * negative error code from an insertion, the lookup pass or a deletion.
 * In that case objects may remain in @ht; they live in the static array[],
 * so the caller only needs to destroy the table.
 */
static s64 __init test_rhashtable(struct rhashtable *ht)
{
	unsigned int i, insert_retries = 0;
	s64 start, end;
	int err;

	/*
	 * Insertion Test:
	 * Insert entries into table with all keys even numbers
	 */
	pr_info("  Adding %d keys\n", entries);
	start = ktime_get_ns();
	for (i = 0; i < entries; i++) {
		struct test_obj *obj = &array[i];

		obj->value.id = i * 2;
		err = insert_retry(ht, &obj->node, test_rht_params);
		if (err > 0)
			insert_retries += err;
		else if (err)
			return err;
	}

	if (insert_retries)
		pr_info("  %u insertions retried due to memory pressure\n",
			insert_retries);

	test_bucket_stats(ht);
	rcu_read_lock();
	err = test_rht_lookup(ht);
	rcu_read_unlock();
	/* A failed lookup pass fails the whole run. */
	if (err)
		return err;

	test_bucket_stats(ht);

	pr_info("  Deleting %d keys\n", entries);
	for (i = 0; i < entries; i++) {
		struct test_obj_val key = {
			.id = i * 2,
		};
		struct test_obj *obj;

		if (array[i].value.id != TEST_INSERT_FAIL) {
			obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
			/* Report instead of BUG_ON(): a test must not crash the kernel. */
			if (!obj) {
				pr_warn("Test failed: Could not find key %d for deletion\n",
					key.id);
				return -ENOENT;
			}

			err = rhashtable_remove_fast(ht, &obj->node,
						     test_rht_params);
			if (err) {
				pr_warn("Test failed: rhashtable_remove_fast() error: %d\n",
					err);
				return err;
			}
		}

		cond_resched();
	}

	end = ktime_get_ns();
	pr_info("  Duration of test: %lld ns\n", end - start);

	return end - start;
}

246 247
static struct rhashtable ht;

248 249 250 251 252 253
static int thread_lookup_test(struct thread_data *tdata)
{
	int i, err = 0;

	for (i = 0; i < entries; i++) {
		struct test_obj *obj;
254 255 256 257
		struct test_obj_val key = {
			.id = i,
			.tid = tdata->id,
		};
258 259

		obj = rhashtable_lookup_fast(&ht, &key, test_rht_params);
260 261
		if (obj && (tdata->objs[i].value.id == TEST_INSERT_FAIL)) {
			pr_err("  found unexpected object %d-%d\n", key.tid, key.id);
262
			err++;
263 264
		} else if (!obj && (tdata->objs[i].value.id != TEST_INSERT_FAIL)) {
			pr_err("  object %d-%d not found!\n", key.tid, key.id);
265
			err++;
266 267 268
		} else if (obj && memcmp(&obj->value, &key, sizeof(key))) {
			pr_err("  wrong object returned (got %d-%d, expected %d-%d)\n",
			       obj->value.tid, obj->value.id, key.tid, key.id);
269 270
			err++;
		}
271 272

		cond_resched();
273 274 275 276 277 278
	}
	return err;
}

static int threadfunc(void *data)
{
279
	int i, step, err = 0, insert_retries = 0;
280 281 282 283 284 285 286
	struct thread_data *tdata = data;

	up(&prestart_sem);
	if (down_interruptible(&startup_sem))
		pr_err("  thread[%d]: down_interruptible failed\n", tdata->id);

	for (i = 0; i < entries; i++) {
287 288
		tdata->objs[i].value.id = i;
		tdata->objs[i].value.tid = tdata->id;
289 290 291
		err = insert_retry(&ht, &tdata->objs[i].node, test_rht_params);
		if (err > 0) {
			insert_retries += err;
292 293 294 295 296 297
		} else if (err) {
			pr_err("  thread[%d]: rhashtable_insert_fast failed\n",
			       tdata->id);
			goto out;
		}
	}
298 299 300
	if (insert_retries)
		pr_info("  thread[%d]: %u insertions retried due to memory pressure\n",
			tdata->id, insert_retries);
301 302 303 304 305 306 307 308 309 310

	err = thread_lookup_test(tdata);
	if (err) {
		pr_err("  thread[%d]: rhashtable_lookup_test failed\n",
		       tdata->id);
		goto out;
	}

	for (step = 10; step > 0; step--) {
		for (i = 0; i < entries; i += step) {
311
			if (tdata->objs[i].value.id == TEST_INSERT_FAIL)
312 313 314 315 316 317 318 319
				continue;
			err = rhashtable_remove_fast(&ht, &tdata->objs[i].node,
			                             test_rht_params);
			if (err) {
				pr_err("  thread[%d]: rhashtable_remove_fast failed\n",
				       tdata->id);
				goto out;
			}
320
			tdata->objs[i].value.id = TEST_INSERT_FAIL;
321 322

			cond_resched();
323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
		}
		err = thread_lookup_test(tdata);
		if (err) {
			pr_err("  thread[%d]: rhashtable_lookup_test (2) failed\n",
			       tdata->id);
			goto out;
		}
	}
out:
	while (!kthread_should_stop()) {
		set_current_state(TASK_INTERRUPTIBLE);
		schedule();
	}
	return err;
}

339 340
static int __init test_rht_init(void)
{
341
	int i, err, started_threads = 0, failed_threads = 0;
342
	u64 total_time = 0;
343 344
	struct thread_data *tdata;
	struct test_obj *objs;
345

346
	entries = min(entries, MAX_ENTRIES);
347

348
	test_rht_params.automatic_shrinking = shrinking;
349
	test_rht_params.max_size = max_size ? : roundup_pow_of_two(entries);
350
	test_rht_params.nelem_hint = size;
351

352 353
	pr_info("Running rhashtable test nelem=%d, max_size=%d, shrinking=%d\n",
		size, max_size, shrinking);
354

355 356
	for (i = 0; i < runs; i++) {
		s64 time;
357

358
		pr_info("Test %02d:\n", i);
359
		memset(&array, 0, sizeof(array));
360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
		err = rhashtable_init(&ht, &test_rht_params);
		if (err < 0) {
			pr_warn("Test failed: Unable to initialize hashtable: %d\n",
				err);
			continue;
		}

		time = test_rhashtable(&ht);
		rhashtable_destroy(&ht);
		if (time < 0) {
			pr_warn("Test failed: return code %lld\n", time);
			return -EINVAL;
		}

		total_time += time;
	}

377 378
	do_div(total_time, runs);
	pr_info("Average test time: %llu\n", total_time);
379

380 381 382 383 384 385 386 387 388 389 390 391 392 393 394
	if (!tcount)
		return 0;

	pr_info("Testing concurrent rhashtable access from %d threads\n",
	        tcount);
	sema_init(&prestart_sem, 1 - tcount);
	tdata = vzalloc(tcount * sizeof(struct thread_data));
	if (!tdata)
		return -ENOMEM;
	objs  = vzalloc(tcount * entries * sizeof(struct test_obj));
	if (!objs) {
		vfree(tdata);
		return -ENOMEM;
	}

395 396
	test_rht_params.max_size = max_size ? :
	                           roundup_pow_of_two(tcount * entries);
397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432
	err = rhashtable_init(&ht, &test_rht_params);
	if (err < 0) {
		pr_warn("Test failed: Unable to initialize hashtable: %d\n",
			err);
		vfree(tdata);
		vfree(objs);
		return -EINVAL;
	}
	for (i = 0; i < tcount; i++) {
		tdata[i].id = i;
		tdata[i].objs = objs + i * entries;
		tdata[i].task = kthread_run(threadfunc, &tdata[i],
		                            "rhashtable_thrad[%d]", i);
		if (IS_ERR(tdata[i].task))
			pr_err(" kthread_run failed for thread %d\n", i);
		else
			started_threads++;
	}
	if (down_interruptible(&prestart_sem))
		pr_err("  down interruptible failed\n");
	for (i = 0; i < tcount; i++)
		up(&startup_sem);
	for (i = 0; i < tcount; i++) {
		if (IS_ERR(tdata[i].task))
			continue;
		if ((err = kthread_stop(tdata[i].task))) {
			pr_warn("Test failed: thread %d returned: %d\n",
			        i, err);
			failed_threads++;
		}
	}
	pr_info("Started %d threads, %d failed\n",
	        started_threads, failed_threads);
	rhashtable_destroy(&ht);
	vfree(tdata);
	vfree(objs);
433
	return 0;
434 435
}

436 437 438 439
static void __exit test_rht_exit(void)
{
}

440
module_init(test_rht_init);
441
module_exit(test_rht_exit);
442 443

MODULE_LICENSE("GPL v2");