auth.c 21.4 KB
Newer Older
L
Linus Torvalds 已提交
1 2 3 4 5 6 7 8 9 10
/*
 * linux/net/sunrpc/auth.c
 *
 * Generic RPC client authentication API.
 *
 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
 */

#include <linux/types.h>
#include <linux/sched.h>
11
#include <linux/cred.h>
L
Linus Torvalds 已提交
12 13 14
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/errno.h>
15
#include <linux/hash.h>
L
Linus Torvalds 已提交
16
#include <linux/sunrpc/clnt.h>
C
Chuck Lever 已提交
17
#include <linux/sunrpc/gss_api.h>
L
Linus Torvalds 已提交
18 19
#include <linux/spinlock.h>

J
Jeff Layton 已提交
20
#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
/* Debug facility used by the dprintk() calls in this file. */
# define RPCDBG_FACILITY	RPCDBG_AUTH
#endif

24 25 26 27 28 29 30 31 32
/*
 * Default number of hash bits for a credential cache (1 << 4 = 16 buckets).
 * Overridable through the auth_hashtable_size module parameter.
 */
#define RPC_CREDCACHE_DEFAULT_HASHBITS	(4)
/*
 * Per-rpc_auth credential cache. Insertions and removals on the hash
 * chains are done under @lock; lookups may also walk them under RCU.
 */
struct rpc_cred_cache {
	struct hlist_head	*hashtable;	/* 1 << hashbits hash chains */
	unsigned int		hashbits;	/* log2 of the number of chains */
	spinlock_t		lock;		/* serializes hashtable updates */
};

/* Hash bits used for newly created credential caches. */
static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;

33 34 35
/*
 * Table of registered authentication flavors, indexed by flavor number.
 * AUTH_NULL and AUTH_UNIX are built in; every other slot starts empty and
 * is filled at run time by rpcauth_register() (rpcauth_get_authops() will
 * request_module() a flavor whose slot is still empty).
 */
static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
	[RPC_AUTH_NULL]	= (const struct rpc_authops __force __rcu *)&authnull_ops,
	[RPC_AUTH_UNIX]	= (const struct rpc_authops __force __rcu *)&authunix_ops,
	NULL,	/* remaining flavors may come from loadable modules */
};

39
/*
 * Global LRU of cached credentials whose only remaining reference is the
 * one held by their cache's hash table, plus its length. Both are updated
 * under rpc_credcache_lock.
 */
static LIST_HEAD(cred_unused);
static unsigned long number_cred_unused;

/* Largest accepted auth_hashtable_size, expressed in hash bits. */
#define MAX_HASHTABLE_BITS (14)
43
static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
44 45 46 47 48 49 50
{
	unsigned long num;
	unsigned int nbits;
	int ret;

	if (!val)
		goto out_inval;
D
Daniel Walter 已提交
51
	ret = kstrtoul(val, 0, &num);
52
	if (ret)
53
		goto out_inval;
54
	nbits = fls(num - 1);
55 56 57 58 59 60 61 62
	if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
		goto out_inval;
	*(unsigned int *)kp->arg = nbits;
	return 0;
out_inval:
	return -EINVAL;
}

63
static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
64 65 66 67 68 69 70 71 72
{
	unsigned int nbits;

	nbits = *(unsigned int *)kp->arg;
	return sprintf(buffer, "%u", 1U << nbits);
}

/*
 * Type-check hook for the "hashtbl_sz" parameter type: the backing
 * variable must be an unsigned int. module_param_named() invokes this as
 * a statement and supplies the terminating semicolon itself, so the macro
 * must not add one.
 */
#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int)

73
static const struct kernel_param_ops param_ops_hashtbl_sz = {
74 75 76 77
	.set = param_set_hashtbl_sz,
	.get = param_get_hashtbl_sz,
};

78 79 80
/*
 * auth_hashtable_size: requested bucket count for new credential caches.
 * Stored internally as a bit count in auth_hashbits by the hashtbl_sz
 * parameter ops.
 */
module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");

81 82 83 84
/*
 * Soft cap on the global number of unused cached credentials. When it is
 * exceeded, rpcauth_cache_enforce_limit() asks the pruner to scan up to
 * 100 unused entries. The default of ULONG_MAX effectively disables it.
 */
static unsigned long auth_max_cred_cachesize = ULONG_MAX;
module_param(auth_max_cred_cachesize, ulong, 0644);
MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");

L
Linus Torvalds 已提交
85 86
/*
 * Map a (pseudo)flavor to the flavor whose ops handle it. Values above
 * RPC_AUTH_MAXFLAVOR are GSS pseudoflavors and are served by RPC_AUTH_GSS;
 * anything else maps to itself.
 */
static u32
pseudoflavor_to_flavor(u32 flavor)
{
	return flavor > RPC_AUTH_MAXFLAVOR ? RPC_AUTH_GSS : flavor;
}

int
93
rpcauth_register(const struct rpc_authops *ops)
L
Linus Torvalds 已提交
94
{
95
	const struct rpc_authops *old;
L
Linus Torvalds 已提交
96 97 98 99
	rpc_authflavor_t flavor;

	if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
		return -EINVAL;
100 101 102 103
	old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops);
	if (old == NULL || old == ops)
		return 0;
	return -EPERM;
L
Linus Torvalds 已提交
104
}
105
EXPORT_SYMBOL_GPL(rpcauth_register);
L
Linus Torvalds 已提交
106 107

int
108
rpcauth_unregister(const struct rpc_authops *ops)
L
Linus Torvalds 已提交
109
{
110
	const struct rpc_authops *old;
L
Linus Torvalds 已提交
111 112 113 114
	rpc_authflavor_t flavor;

	if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
		return -EINVAL;
115 116 117 118 119

	old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL);
	if (old == ops || old == NULL)
		return 0;
	return -EPERM;
L
Linus Torvalds 已提交
120
}
121
EXPORT_SYMBOL_GPL(rpcauth_unregister);
L
Linus Torvalds 已提交
122

123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
/*
 * Look up the ops for @flavor and pin their owning module. If the slot is
 * empty, ask for "rpc-auth-<flavor>" to be loaded and look again.
 *
 * Returns NULL when the flavor is out of range, still unregistered, or its
 * module cannot be pinned. Release with rpcauth_put_authops().
 */
static const struct rpc_authops *
rpcauth_get_authops(rpc_authflavor_t flavor)
{
	const struct rpc_authops *ops;

	if (flavor >= RPC_AUTH_MAXFLAVOR)
		return NULL;

	rcu_read_lock();
	ops = rcu_dereference(auth_flavors[flavor]);
	if (!ops) {
		/* request_module() can block: leave the RCU section around it. */
		rcu_read_unlock();
		request_module("rpc-auth-%u", flavor);
		rcu_read_lock();
		ops = rcu_dereference(auth_flavors[flavor]);
	}
	if (ops && !try_module_get(ops->owner))
		ops = NULL;
	rcu_read_unlock();
	return ops;
}

/* Drop the module reference taken by rpcauth_get_authops(). */
static void
rpcauth_put_authops(const struct rpc_authops *ops)
{
	struct module *owner = ops->owner;

	module_put(owner);
}

154 155 156 157 158 159 160 161 162 163 164 165
/**
 * rpcauth_get_pseudoflavor - check if security flavor is supported
 * @flavor: a security flavor
 * @info: a GSS mech OID, quality of protection, and service value
 *
 * Loads or locates the ops for @flavor. Flavors that provide an
 * ->info2flavor() hook translate @info into a pseudoflavor; all others
 * map to @flavor itself. Returns RPC_AUTH_MAXFLAVOR when @flavor is not
 * supported locally.
 */
rpc_authflavor_t
rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
{
	const struct rpc_authops *ops = rpcauth_get_authops(flavor);
	rpc_authflavor_t result;

	if (ops == NULL)
		return RPC_AUTH_MAXFLAVOR;

	result = ops->info2flavor ? ops->info2flavor(info) : flavor;

	rpcauth_put_authops(ops);
	return result;
}
EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);

180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
/**
 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
 * @pseudoflavor: GSS pseudoflavor to translate
 * @info: filled in with the GSS mech OID, QoP and service on success
 *
 * Returns zero if the owning flavor's ->flavor2info() recognised
 * @pseudoflavor. Returns -ENOENT if no ops exist for it or they lack that
 * hook; otherwise returns the hook's error.
 */
int
rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
{
	const struct rpc_authops *ops;
	int ret = -ENOENT;

	ops = rpcauth_get_authops(pseudoflavor_to_flavor(pseudoflavor));
	if (!ops)
		return -ENOENT;

	if (ops->flavor2info)
		ret = ops->flavor2info(pseudoflavor, info);

	rpcauth_put_authops(ops);
	return ret;
}
EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);

C
Chuck Lever 已提交
208 209 210 211 212 213 214 215 216 217 218 219 220
/**
 * rpcauth_list_flavors - discover registered flavors and pseudoflavors
 * @array: array to fill in
 * @size: size of "array"
 *
 * Returns the number of array items filled in, or a negative errno:
 * -ENOMEM if @array cannot hold every registered (pseudo)flavor, or the
 * error reported by a flavor's ->list_pseudoflavors().
 *
 * The returned array is not sorted by any policy.  Callers should not
 * rely on the order of the items in the returned array.
 */
int
rpcauth_list_flavors(rpc_authflavor_t *array, int size)
{
	const struct rpc_authops *ops;
	rpc_authflavor_t flavor, pseudos[4];
	int i, len, result = 0;

	rcu_read_lock();
	for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
		ops = rcu_dereference(auth_flavors[flavor]);
		if (ops == NULL)
			continue;
		if (ops->list_pseudoflavors == NULL) {
			/*
			 * Check for room only when there is something to store,
			 * so an exactly-sized array is not reported as too small.
			 */
			if (result >= size) {
				result = -ENOMEM;
				goto out_unlock;
			}
			array[result++] = ops->au_flavor;
			continue;
		}
		len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
		if (len < 0) {
			result = len;
			goto out_unlock;
		}
		for (i = 0; i < len; i++) {
			/*
			 * Leave both loops on overflow: continuing the outer
			 * loop with a negative result would index @array with it.
			 */
			if (result >= size) {
				result = -ENOMEM;
				goto out_unlock;
			}
			array[result++] = pseudos[i];
		}
	}
out_unlock:
	rcu_read_unlock();

	dprintk("RPC:       %s returns %d\n", __func__, result);
	return result;
}
EXPORT_SYMBOL_GPL(rpcauth_list_flavors);

L
Linus Torvalds 已提交
259
struct rpc_auth *
260
rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
L
Linus Torvalds 已提交
261
{
262
	struct rpc_auth	*auth = ERR_PTR(-EINVAL);
263
	const struct rpc_authops *ops;
264
	u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor);
L
Linus Torvalds 已提交
265

266 267
	ops = rpcauth_get_authops(flavor);
	if (ops == NULL)
268 269
		goto out;

270
	auth = ops->create(args, clnt);
271 272

	rpcauth_put_authops(ops);
273 274
	if (IS_ERR(auth))
		return auth;
L
Linus Torvalds 已提交
275
	if (clnt->cl_auth)
276
		rpcauth_release(clnt->cl_auth);
L
Linus Torvalds 已提交
277
	clnt->cl_auth = auth;
278 279

out:
L
Linus Torvalds 已提交
280 281
	return auth;
}
282
EXPORT_SYMBOL_GPL(rpcauth_create);
L
Linus Torvalds 已提交
283 284

void
285
rpcauth_release(struct rpc_auth *auth)
L
Linus Torvalds 已提交
286
{
287
	if (!refcount_dec_and_test(&auth->au_count))
L
Linus Torvalds 已提交
288 289 290 291 292 293
		return;
	auth->au_ops->destroy(auth);
}

/* Serializes updates to the global cred_unused LRU and number_cred_unused. */
static DEFINE_SPINLOCK(rpc_credcache_lock);

294 295 296 297 298
/*
 * On success, the caller is responsible for freeing the reference
 * held by the hashtable
 */
static bool
299 300
rpcauth_unhash_cred_locked(struct rpc_cred *cred)
{
301 302
	if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
		return false;
303
	hlist_del_rcu(&cred->cr_hash);
304
	return true;
305 306
}

307
static bool
308 309 310
rpcauth_unhash_cred(struct rpc_cred *cred)
{
	spinlock_t *cache_lock;
311
	bool ret;
312

313 314
	if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
		return false;
315 316
	cache_lock = &cred->cr_auth->au_credcache->lock;
	spin_lock(cache_lock);
317
	ret = rpcauth_unhash_cred_locked(cred);
318
	spin_unlock(cache_lock);
319
	return ret;
320 321
}

L
Linus Torvalds 已提交
322 323 324 325
/*
 * Initialize RPC credential cache
 */
int
326
rpcauth_init_credcache(struct rpc_auth *auth)
L
Linus Torvalds 已提交
327 328
{
	struct rpc_cred_cache *new;
329
	unsigned int hashsize;
L
Linus Torvalds 已提交
330

331
	new = kmalloc(sizeof(*new), GFP_KERNEL);
L
Linus Torvalds 已提交
332
	if (!new)
333 334
		goto out_nocache;
	new->hashbits = auth_hashbits;
335
	hashsize = 1U << new->hashbits;
336 337 338
	new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
	if (!new->hashtable)
		goto out_nohashtbl;
339
	spin_lock_init(&new->lock);
L
Linus Torvalds 已提交
340 341
	auth->au_credcache = new;
	return 0;
342 343 344 345
out_nohashtbl:
	kfree(new);
out_nocache:
	return -ENOMEM;
L
Linus Torvalds 已提交
346
}
347
EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
L
Linus Torvalds 已提交
348

349 350 351 352 353 354 355 356 357 358 359 360 361
/*
 * Setup a credential key lifetime timeout notification
 */
int
rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred)
{
	if (!cred->cr_auth->au_ops->key_timeout)
		return 0;
	return cred->cr_auth->au_ops->key_timeout(auth, cred);
}
EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify);

bool
362
rpcauth_cred_key_to_expire(struct rpc_auth *auth, struct rpc_cred *cred)
363
{
364 365
	if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT)
		return false;
366 367 368 369 370 371
	if (!cred->cr_ops->crkey_to_expire)
		return false;
	return cred->cr_ops->crkey_to_expire(cred);
}
EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire);

372 373 374 375 376 377 378 379 380
char *
rpcauth_stringify_acceptor(struct rpc_cred *cred)
{
	if (!cred->cr_ops->crstringify_acceptor)
		return NULL;
	return cred->cr_ops->crstringify_acceptor(cred);
}
EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);

L
Linus Torvalds 已提交
381 382 383 384
/*
 * Destroy a list of credentials
 */
static inline
385
void rpcauth_destroy_credlist(struct list_head *head)
L
Linus Torvalds 已提交
386 387 388
{
	struct rpc_cred *cred;

389 390 391
	while (!list_empty(head)) {
		cred = list_entry(head->next, struct rpc_cred, cr_lru);
		list_del_init(&cred->cr_lru);
L
Linus Torvalds 已提交
392 393 394 395
		put_rpccred(cred);
	}
}

396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
/*
 * Append @cred to the unused LRU unless it is already on it.
 * Caller holds rpc_credcache_lock.
 */
static void
rpcauth_lru_add_locked(struct rpc_cred *cred)
{
	if (list_empty(&cred->cr_lru)) {
		number_cred_unused++;
		list_add_tail(&cred->cr_lru, &cred_unused);
	}
}

/*
 * Locking wrapper for rpcauth_lru_add_locked(). The unlocked list_empty()
 * test is just a fast path; the locked helper checks again.
 */
static void
rpcauth_lru_add(struct rpc_cred *cred)
{
	if (list_empty(&cred->cr_lru)) {
		spin_lock(&rpc_credcache_lock);
		rpcauth_lru_add_locked(cred);
		spin_unlock(&rpc_credcache_lock);
	}
}

/*
 * Take @cred off the unused LRU if it is on it.
 * Caller holds rpc_credcache_lock.
 */
static void
rpcauth_lru_remove_locked(struct rpc_cred *cred)
{
	if (!list_empty(&cred->cr_lru)) {
		number_cred_unused--;
		list_del_init(&cred->cr_lru);
	}
}

/*
 * Locking wrapper for rpcauth_lru_remove_locked(). The unlocked
 * list_empty() test is just a fast path; the locked helper checks again.
 */
static void
rpcauth_lru_remove(struct rpc_cred *cred)
{
	if (!list_empty(&cred->cr_lru)) {
		spin_lock(&rpc_credcache_lock);
		rpcauth_lru_remove_locked(cred);
		spin_unlock(&rpc_credcache_lock);
	}
}

L
Linus Torvalds 已提交
434 435 436 437 438
/*
 * Clear the RPC credential cache, and delete those credentials
 * that are not referenced.
 */
void
439
rpcauth_clear_credcache(struct rpc_cred_cache *cache)
L
Linus Torvalds 已提交
440
{
441 442
	LIST_HEAD(free);
	struct hlist_head *head;
L
Linus Torvalds 已提交
443
	struct rpc_cred	*cred;
444
	unsigned int hashsize = 1U << cache->hashbits;
L
Linus Torvalds 已提交
445 446 447
	int		i;

	spin_lock(&rpc_credcache_lock);
448
	spin_lock(&cache->lock);
449
	for (i = 0; i < hashsize; i++) {
450 451 452
		head = &cache->hashtable[i];
		while (!hlist_empty(head)) {
			cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
453
			rpcauth_unhash_cred_locked(cred);
454 455 456
			/* Note: We now hold a reference to cred */
			rpcauth_lru_remove_locked(cred);
			list_add_tail(&cred->cr_lru, &free);
L
Linus Torvalds 已提交
457 458
		}
	}
459
	spin_unlock(&cache->lock);
L
Linus Torvalds 已提交
460 461 462 463
	spin_unlock(&rpc_credcache_lock);
	rpcauth_destroy_credlist(&free);
}

464 465 466 467 468 469 470 471 472 473 474
/*
 * Destroy the RPC credential cache
 */
void
rpcauth_destroy_credcache(struct rpc_auth *auth)
{
	struct rpc_cred_cache *cache = auth->au_credcache;

	if (cache) {
		auth->au_credcache = NULL;
		rpcauth_clear_credcache(cache);
475
		kfree(cache->hashtable);
476 477 478
		kfree(cache);
	}
}
479
EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
480

481 482 483

/* Garbage-collection moratorium (60 seconds) used by rpcauth_prune_expired(). */
#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)

484 485 486
/*
 * Remove stale credentials. Avoid sleeping inside the loop.
 */
487
static long
488
rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
L
Linus Torvalds 已提交
489
{
490
	struct rpc_cred *cred, *next;
491
	unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
492
	long freed = 0;
493

494 495
	list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {

496 497
		if (nr_to_scan-- == 0)
			break;
498
		if (refcount_read(&cred->cr_count) > 1) {
499 500 501
			rpcauth_lru_remove_locked(cred);
			continue;
		}
502 503 504 505
		/*
		 * Enforce a 60 second garbage collection moratorium
		 * Note that the cred_unused list must be time-ordered.
		 */
506 507 508
		if (!time_in_range(cred->cr_expire, expired, jiffies))
			continue;
		if (!rpcauth_unhash_cred(cred))
509
			continue;
510

511 512 513
		rpcauth_lru_remove_locked(cred);
		freed++;
		list_add_tail(&cred->cr_lru, free);
L
Linus Torvalds 已提交
514
	}
515
	return freed ? freed : SHRINK_STOP;
L
Linus Torvalds 已提交
516 517
}

518 519 520 521 522 523 524 525 526 527 528 529 530 531
static unsigned long
rpcauth_cache_do_shrink(int nr_to_scan)
{
	LIST_HEAD(free);
	unsigned long freed;

	spin_lock(&rpc_credcache_lock);
	freed = rpcauth_prune_expired(&free, nr_to_scan);
	spin_unlock(&rpc_credcache_lock);
	rpcauth_destroy_credlist(&free);

	return freed;
}

L
Linus Torvalds 已提交
532
/*
533
 * Run memory cache shrinker.
L
Linus Torvalds 已提交
534
 */
535 536 537
static unsigned long
rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)

L
Linus Torvalds 已提交
538
{
539 540
	if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
		return SHRINK_STOP;
541

542
	/* nothing left, don't come back */
543
	if (list_empty(&cred_unused))
544 545
		return SHRINK_STOP;

546
	return rpcauth_cache_do_shrink(sc->nr_to_scan);
547 548 549 550 551 552
}

static unsigned long
rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)

{
553
	return number_cred_unused * sysctl_vfs_cache_pressure / 100;
L
Linus Torvalds 已提交
554 555
}

556 557 558 559 560 561 562 563 564 565 566 567 568 569 570
/*
 * When the unused LRU exceeds auth_max_cred_cachesize, ask the pruner to
 * scan the excess, capped at 100 entries per call.
 */
static void
rpcauth_cache_enforce_limit(void)
{
	unsigned long excess;

	if (number_cred_unused <= auth_max_cred_cachesize)
		return;
	excess = number_cred_unused - auth_max_cred_cachesize;
	rpcauth_cache_do_shrink(excess < 100 ? (unsigned int)excess : 100U);
}

L
Linus Torvalds 已提交
571 572 573 574 575
/*
 * Look up a process' credentials in the authentication cache.
 *
 * Searches @auth's cache for an entry matching @acred.
 * - With RPCAUTH_LOOKUP_RCU the search is lock-free. The result carries
 *   no extra reference and a miss returns ERR_PTR(-ECHILD) instead of
 *   creating anything.
 * - Otherwise a match is returned with a reference held. On a miss a new
 *   credential is built with ->crcreate() and inserted; the hashtable
 *   keeps a reference of its own. If another thread inserted a match
 *   first, the fresh credential is discarded.
 *
 * A credential still flagged RPCAUTH_CRED_NEW gets ->cr_init() run on it
 * unless @flags contains RPCAUTH_LOOKUP_NEW.
 *
 * Returns the credential or an ERR_PTR().
 */
struct rpc_cred *
rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred *acred,
		int flags, gfp_t gfp)
{
	LIST_HEAD(free);
	struct rpc_cred_cache *cache = auth->au_credcache;
	struct rpc_cred	*cred = NULL,
			*entry, *new;
	unsigned int nr;

	nr = auth->au_ops->hash_cred(acred, cache->hashbits);

	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		if (flags & RPCAUTH_LOOKUP_RCU) {
			if (test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags) ||
			    refcount_read(&entry->cr_count) == 0)
				continue;
			cred = entry;
			break;
		}
		/* get_rpccred() fails on an entry that is being torn down. */
		cred = get_rpccred(entry);
		if (cred)
			break;
	}
	rcu_read_unlock();

	if (cred != NULL)
		goto found;

	if (flags & RPCAUTH_LOOKUP_RCU)
		return ERR_PTR(-ECHILD);

	new = auth->au_ops->crcreate(auth, acred, flags, gfp);
	if (IS_ERR(new)) {
		cred = new;
		goto out;
	}

	/* Re-check under the lock: someone may have inserted a match meanwhile. */
	spin_lock(&cache->lock);
	hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		cred = get_rpccred(entry);
		if (cred)
			break;
	}
	if (cred == NULL) {
		cred = new;
		set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
		/* Extra reference owned by the hashtable. */
		refcount_inc(&cred->cr_count);
		hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
	} else {
		/* Lost the race: discard @new once the lock is dropped. */
		list_add_tail(&new->cr_lru, &free);
	}
	spin_unlock(&cache->lock);
	rpcauth_cache_enforce_limit();
found:
	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
	    cred->cr_ops->cr_init != NULL &&
	    !(flags & RPCAUTH_LOOKUP_NEW)) {
		int res = cred->cr_ops->cr_init(auth, cred);
		if (res < 0) {
			put_rpccred(cred);
			cred = ERR_PTR(res);
		}
	}
	rpcauth_destroy_credlist(&free);
out:
	return cred;
}
EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
L
Linus Torvalds 已提交
647 648

struct rpc_cred *
649
rpcauth_lookupcred(struct rpc_auth *auth, int flags)
L
Linus Torvalds 已提交
650
{
651
	struct auth_cred acred;
L
Linus Torvalds 已提交
652
	struct rpc_cred *ret;
653
	const struct cred *cred = current_cred();
L
Linus Torvalds 已提交
654

655
	dprintk("RPC:       looking up %s cred\n",
L
Linus Torvalds 已提交
656
		auth->au_ops->au_name);
657 658 659 660

	memset(&acred, 0, sizeof(acred));
	acred.uid = cred->fsuid;
	acred.gid = cred->fsgid;
661
	acred.cred = cred;
662
	ret = auth->au_ops->lookup_cred(auth, &acred, flags);
L
Linus Torvalds 已提交
663 664
	return ret;
}
665
EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
L
Linus Torvalds 已提交
666

667 668 669 670 671
void
rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
		  struct rpc_auth *auth, const struct rpc_credops *ops)
{
	INIT_HLIST_NODE(&cred->cr_hash);
672
	INIT_LIST_HEAD(&cred->cr_lru);
673
	refcount_set(&cred->cr_count, 1);
674 675 676
	cred->cr_auth = auth;
	cred->cr_ops = ops;
	cred->cr_expire = jiffies;
677
	cred->cr_cred = get_cred(acred->cred);
678 679
	cred->cr_uid = acred->uid;
}
680
EXPORT_SYMBOL_GPL(rpcauth_init_cred);
681

682
struct rpc_cred *
683
rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
684 685 686
{
	dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
			cred->cr_auth->au_ops->au_name, cred);
687
	return get_rpccred(cred);
688
}
689
EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
690

691
static struct rpc_cred *
692
rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
L
Linus Torvalds 已提交
693
{
694
	struct rpc_auth *auth = task->tk_client->cl_auth;
L
Linus Torvalds 已提交
695
	struct auth_cred acred = {
696 697
		.uid = GLOBAL_ROOT_UID,
		.gid = GLOBAL_ROOT_GID,
698
		.cred = get_task_cred(&init_task),
L
Linus Torvalds 已提交
699
	};
700
	struct rpc_cred *ret;
L
Linus Torvalds 已提交
701

702
	dprintk("RPC: %5u looking up %s cred\n",
703
		task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
704 705 706
	ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
	put_cred(acred.cred);
	return ret;
707 708
}

709
static struct rpc_cred *
710
rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
711 712 713 714 715
{
	struct rpc_auth *auth = task->tk_client->cl_auth;

	dprintk("RPC: %5u looking up %s cred\n",
		task->tk_pid, auth->au_ops->au_name);
716
	return rpcauth_lookupcred(auth, lookupflags);
L
Linus Torvalds 已提交
717 718
}

719
static int
720
rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
L
Linus Torvalds 已提交
721
{
722
	struct rpc_rqst *req = task->tk_rqstp;
723
	struct rpc_cred *new;
724 725 726 727
	int lookupflags = 0;

	if (flags & RPC_TASK_ASYNC)
		lookupflags |= RPCAUTH_LOOKUP_NEW;
728
	if (cred != NULL)
729
		new = cred->cr_ops->crbind(task, cred, lookupflags);
730
	else if (flags & RPC_TASK_ROOTCREDS)
731
		new = rpcauth_bind_root_cred(task, lookupflags);
732
	else
733 734 735
		new = rpcauth_bind_new_cred(task, lookupflags);
	if (IS_ERR(new))
		return PTR_ERR(new);
736
	put_rpccred(req->rq_cred);
737
	req->rq_cred = new;
738
	return 0;
L
Linus Torvalds 已提交
739 740 741 742 743
}

void
put_rpccred(struct rpc_cred *cred)
{
744 745
	if (cred == NULL)
		return;
746
	rcu_read_lock();
747
	if (refcount_dec_and_test(&cred->cr_count))
748
		goto destroy;
749
	if (refcount_read(&cred->cr_count) != 1 ||
750 751 752 753 754 755 756 757 758 759
	    !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
		goto out;
	if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
		cred->cr_expire = jiffies;
		rpcauth_lru_add(cred);
		/* Race breaker */
		if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)))
			rpcauth_lru_remove(cred);
	} else if (rpcauth_unhash_cred(cred)) {
		rpcauth_lru_remove(cred);
760
		if (refcount_dec_and_test(&cred->cr_count))
761
			goto destroy;
762
	}
763 764
out:
	rcu_read_unlock();
765
	return;
766 767 768
destroy:
	rcu_read_unlock();
	cred->cr_ops->crdestroy(cred);
L
Linus Torvalds 已提交
769
}
770
EXPORT_SYMBOL_GPL(put_rpccred);
L
Linus Torvalds 已提交
771

772 773
/*
 * Encode the request's credential at @p via the cred's ->crmarshal().
 * Returns whatever the hook returns.
 */
__be32 *
rpcauth_marshcred(struct rpc_task *task, __be32 *p)
{
	struct rpc_cred	*cred = task->tk_rqstp->rq_cred;
	const struct rpc_credops *ops = cred->cr_ops;

	dprintk("RPC: %5u marshaling %s cred %p\n",
		task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
	return ops->crmarshal(task, p);
}

783 784
/*
 * Validate the reply verifier at @p via the cred's ->crvalidate().
 * Returns whatever the hook returns.
 */
__be32 *
rpcauth_checkverf(struct rpc_task *task, __be32 *p)
{
	struct rpc_cred	*cred = task->tk_rqstp->rq_cred;
	const struct rpc_credops *ops = cred->cr_ops;

	dprintk("RPC: %5u validating %s cred %p\n",
		task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
	return ops->crvalidate(task, p);
}

794 795 796 797 798 799 800 801 802
/* Plain encoding: run @encode over rq_snd_buf, starting at @data. */
static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
				   __be32 *data, void *obj)
{
	struct xdr_stream stream;

	xdr_init_encode(&stream, &rqstp->rq_snd_buf, data);
	encode(rqstp, &stream, obj);
}

L
Linus Torvalds 已提交
803
int
804
rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
805
		__be32 *data, void *obj)
L
Linus Torvalds 已提交
806
{
807
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
L
Linus Torvalds 已提交
808

809
	dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
L
Linus Torvalds 已提交
810 811 812 813
			task->tk_pid, cred->cr_ops->cr_name, cred);
	if (cred->cr_ops->crwrap_req)
		return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
	/* By default, we encode the arguments normally. */
814 815
	rpcauth_wrap_req_encode(encode, rqstp, data, obj);
	return 0;
L
Linus Torvalds 已提交
816 817
}

818 819 820 821 822 823 824 825 826 827
/* Plain decoding: run @decode over rq_rcv_buf, starting at @data. */
static int
rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
			  __be32 *data, void *obj)
{
	struct xdr_stream stream;

	xdr_init_decode(&stream, &rqstp->rq_rcv_buf, data);
	return decode(rqstp, &stream, obj);
}

L
Linus Torvalds 已提交
828
int
829
rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
830
		__be32 *data, void *obj)
L
Linus Torvalds 已提交
831
{
832
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
L
Linus Torvalds 已提交
833

834
	dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
L
Linus Torvalds 已提交
835 836 837 838 839
			task->tk_pid, cred->cr_ops->cr_name, cred);
	if (cred->cr_ops->crunwrap_resp)
		return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
						   data, obj);
	/* By default, we decode the arguments normally. */
840
	return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
L
Linus Torvalds 已提交
841 842
}

843 844 845 846 847 848 849 850 851 852
/*
 * Ask the request's credential, via ->crneed_reencode, whether the message
 * must be re-encoded. False when no cred is bound or the hook is absent.
 */
bool
rpcauth_xmit_need_reencode(struct rpc_task *task)
{
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;

	if (cred && cred->cr_ops->crneed_reencode)
		return cred->cr_ops->crneed_reencode(task);
	return false;
}

L
Linus Torvalds 已提交
853 854 855
int
rpcauth_refreshcred(struct rpc_task *task)
{
856
	struct rpc_cred	*cred;
L
Linus Torvalds 已提交
857 858
	int err;

859 860 861 862 863 864
	cred = task->tk_rqstp->rq_cred;
	if (cred == NULL) {
		err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
		if (err < 0)
			goto out;
		cred = task->tk_rqstp->rq_cred;
J
Joe Perches 已提交
865
	}
866
	dprintk("RPC: %5u refreshing %s cred %p\n",
867
		task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
868

L
Linus Torvalds 已提交
869
	err = cred->cr_ops->crrefresh(task);
870
out:
L
Linus Torvalds 已提交
871 872 873 874 875 876 877 878
	if (err < 0)
		task->tk_status = err;
	return err;
}

void
rpcauth_invalcred(struct rpc_task *task)
{
879
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
880

881
	dprintk("RPC: %5u invalidating %s cred %p\n",
882
		task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
883 884
	if (cred)
		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
L
Linus Torvalds 已提交
885 886 887 888 889
}

int
rpcauth_uptodatecred(struct rpc_task *task)
{
890
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
891 892 893

	return cred == NULL ||
		test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
L
Linus Torvalds 已提交
894
}
895

896
static struct shrinker rpc_cred_shrinker = {
897 898
	.count_objects = rpcauth_cache_shrink_count,
	.scan_objects = rpcauth_cache_shrink_scan,
899 900
	.seeks = DEFAULT_SEEKS,
};
901

902
/*
 * Set up AUTH_UNIX and the generic auth, then register the credential
 * shrinker. On failure, everything already set up is torn down in
 * reverse order and the error is returned.
 */
int __init rpcauth_init_module(void)
{
	int err;

	err = rpc_init_authunix();
	if (err < 0)
		return err;

	err = rpc_init_generic_auth();
	if (err < 0)
		goto undo_authunix;

	err = register_shrinker(&rpc_cred_shrinker);
	if (err < 0)
		goto undo_generic;

	return 0;

undo_generic:
	rpc_destroy_generic_auth();
undo_authunix:
	rpc_destroy_authunix();
	return err;
}

924
void rpcauth_remove_module(void)
925
{
926 927
	rpc_destroy_authunix();
	rpc_destroy_generic_auth();
928
	unregister_shrinker(&rpc_cred_shrinker);
929
}