auth.c 21.3 KB
Newer Older
L
Linus Torvalds 已提交
1 2 3 4 5 6 7 8 9 10
/*
 * linux/net/sunrpc/auth.c
 *
 * Generic RPC client authentication API.
 *
 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
 */

#include <linux/types.h>
#include <linux/sched.h>
11
#include <linux/cred.h>
L
Linus Torvalds 已提交
12 13 14
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/errno.h>
15
#include <linux/hash.h>
L
Linus Torvalds 已提交
16
#include <linux/sunrpc/clnt.h>
C
Chuck Lever 已提交
17
#include <linux/sunrpc/gss_api.h>
L
Linus Torvalds 已提交
18 19
#include <linux/spinlock.h>

J
Jeff Layton 已提交
20
#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
L
Linus Torvalds 已提交
21 22 23
# define RPCDBG_FACILITY	RPCDBG_AUTH
#endif

24 25 26 27 28 29 30 31 32
/* Default log2 of the number of hash buckets in a credential cache. */
#define RPC_CREDCACHE_DEFAULT_HASHBITS	(4)
/* Per-rpc_auth credential cache: a hash table of credentials and its lock. */
struct rpc_cred_cache {
	struct hlist_head	*hashtable;	/* array of (1 << hashbits) bucket heads */
	unsigned int		hashbits;	/* log2 of the number of buckets */
	spinlock_t		lock;		/* held around hash insertion and unhashing */
};

/* Bucket exponent copied into each newly created cache; set via auth_hashtable_size. */
static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;

33 34 35
/*
 * Table of registered authentication flavors, indexed by flavor number.
 * Only AUTH_NULL and AUTH_UNIX are built in; every other slot starts out
 * NULL and may be filled in later by rpcauth_register() (e.g. from a
 * loadable module).
 */
static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
	[RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
	[RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
	/* remaining entries are implicitly NULL */
};

39
/*
 * LRU list of cached credentials that currently have no users (see
 * put_rpccred()), linked through rpc_cred.cr_lru.  Updated under
 * rpc_credcache_lock.
 */
static LIST_HEAD(cred_unused);
/* Number of credentials currently linked on cred_unused. */
static unsigned long number_cred_unused;
41

42
#define MAX_HASHTABLE_BITS (14)
43
static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
44 45 46 47 48 49 50
{
	unsigned long num;
	unsigned int nbits;
	int ret;

	if (!val)
		goto out_inval;
D
Daniel Walter 已提交
51
	ret = kstrtoul(val, 0, &num);
52
	if (ret)
53
		goto out_inval;
54
	nbits = fls(num - 1);
55 56 57 58 59 60 61 62
	if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
		goto out_inval;
	*(unsigned int *)kp->arg = nbits;
	return 0;
out_inval:
	return -EINVAL;
}

63
static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
64 65 66 67 68 69 70 71 72
{
	unsigned int nbits;

	nbits = *(unsigned int *)kp->arg;
	return sprintf(buffer, "%u", 1U << nbits);
}

#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);

73
static const struct kernel_param_ops param_ops_hashtbl_sz = {
74 75 76 77
	.set = param_set_hashtbl_sz,
	.get = param_get_hashtbl_sz,
};

78 79 80
/* Expose auth_hashbits as "auth_hashtable_size" using the hashtbl_sz accessors above. */
module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");

81 82 83 84
/*
 * Upper bound on number_cred_unused before rpcauth_cache_enforce_limit()
 * starts pruning; ULONG_MAX by default.
 */
static unsigned long auth_max_cred_cachesize = ULONG_MAX;
module_param(auth_max_cred_cachesize, ulong, 0644);
MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");

L
Linus Torvalds 已提交
85 86
/*
 * Map a pseudoflavor onto an index usable with auth_flavors[]: any value
 * above RPC_AUTH_MAXFLAVOR is treated as RPC_AUTH_GSS, everything else is
 * returned unchanged.
 */
static u32
pseudoflavor_to_flavor(u32 flavor) {
	return flavor > RPC_AUTH_MAXFLAVOR ? RPC_AUTH_GSS : flavor;
}

int
93
rpcauth_register(const struct rpc_authops *ops)
L
Linus Torvalds 已提交
94
{
95
	const struct rpc_authops *old;
L
Linus Torvalds 已提交
96 97 98 99
	rpc_authflavor_t flavor;

	if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
		return -EINVAL;
100 101 102 103
	old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops);
	if (old == NULL || old == ops)
		return 0;
	return -EPERM;
L
Linus Torvalds 已提交
104
}
105
EXPORT_SYMBOL_GPL(rpcauth_register);
L
Linus Torvalds 已提交
106 107

int
108
rpcauth_unregister(const struct rpc_authops *ops)
L
Linus Torvalds 已提交
109
{
110
	const struct rpc_authops *old;
L
Linus Torvalds 已提交
111 112 113 114
	rpc_authflavor_t flavor;

	if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
		return -EINVAL;
115 116 117 118 119

	old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL);
	if (old == ops || old == NULL)
		return 0;
	return -EPERM;
L
Linus Torvalds 已提交
120
}
121
EXPORT_SYMBOL_GPL(rpcauth_unregister);
L
Linus Torvalds 已提交
122

123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
/*
 * Look up the ops registered for @flavor and pin the owning module.
 *
 * If nothing is registered yet, ask for module "rpc-auth-<flavor>" to be
 * loaded and look again.  Returns NULL when @flavor is out of range, when
 * no ops are registered even after the module request, or when the module
 * reference cannot be taken.  A non-NULL result must be released with
 * rpcauth_put_authops().
 */
static const struct rpc_authops *
rpcauth_get_authops(rpc_authflavor_t flavor)
{
	const struct rpc_authops *ops;

	if (flavor >= RPC_AUTH_MAXFLAVOR)
		return NULL;

	rcu_read_lock();
	ops = rcu_dereference(auth_flavors[flavor]);
	if (ops == NULL) {
		/* The module request is issued outside the RCU read section. */
		rcu_read_unlock();
		request_module("rpc-auth-%u", flavor);
		rcu_read_lock();
		ops = rcu_dereference(auth_flavors[flavor]);
	}
	if (ops != NULL && !try_module_get(ops->owner))
		ops = NULL;
	rcu_read_unlock();
	return ops;
}

/* Drop the module reference on ops->owner taken by rpcauth_get_authops(). */
static void
rpcauth_put_authops(const struct rpc_authops *ops)
{
	module_put(ops->owner);
}

154 155 156 157 158 159 160 161 162 163 164 165
/**
 * rpcauth_get_pseudoflavor - check if security flavor is supported
 * @flavor: a security flavor
 * @info: a GSS mech OID, quality of protection, and service value
 *
 * Verifies that an appropriate kernel module is available or already loaded.
 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
 * not supported locally.  When the flavor provides no info2flavor method,
 * "flavor" itself is returned.
 */
rpc_authflavor_t
rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
{
	const struct rpc_authops *ops;
	rpc_authflavor_t result = flavor;

	ops = rpcauth_get_authops(flavor);
	if (ops == NULL)
		return RPC_AUTH_MAXFLAVOR;

	if (ops->info2flavor != NULL)
		result = ops->info2flavor(info);

	rpcauth_put_authops(ops);
	return result;
}
EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);

180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
/**
 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
 * @pseudoflavor: GSS pseudoflavor to match
 * @info: rpcsec_gss_info structure to fill in
 *
 * Returns zero and fills in "info" if pseudoflavor matches a
 * supported mechanism.  Returns -ENOENT if the flavor has no registered
 * ops or the ops provide no flavor2info method; otherwise the result of
 * flavor2info is passed through.
 */
int
rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
{
	const struct rpc_authops *ops;
	int result = -ENOENT;

	ops = rpcauth_get_authops(pseudoflavor_to_flavor(pseudoflavor));
	if (!ops)
		return result;

	if (ops->flavor2info)
		result = ops->flavor2info(pseudoflavor, info);

	rpcauth_put_authops(ops);
	return result;
}
EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);

C
Chuck Lever 已提交
208 209 210 211 212 213 214 215 216 217 218 219 220
/**
 * rpcauth_list_flavors - discover registered flavors and pseudoflavors
 * @array: array to fill in
 * @size: size of "array"
 *
 * Returns the number of array items filled in, or a negative errno:
 * -ENOMEM when "array" is too small, or whatever error a flavor's
 * list_pseudoflavors method reports.
 *
 * The returned array is not sorted by any policy.  Callers should not
 * rely on the order of the items in the returned array.
 */
int
rpcauth_list_flavors(rpc_authflavor_t *array, int size)
{
	const struct rpc_authops *ops;
	rpc_authflavor_t flavor, pseudos[4];
	int i, len, result = 0;

	rcu_read_lock();
	for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
		ops = rcu_dereference(auth_flavors[flavor]);
		/*
		 * NOTE(review): this capacity check runs before the empty-slot
		 * check, so an exactly-full array still yields -ENOMEM while
		 * unvisited slots remain -- confirm whether callers rely on it.
		 */
		if (result >= size) {
			result = -ENOMEM;
			break;
		}

		if (ops == NULL)
			continue;
		/* Flavors without pseudoflavors are reported as themselves. */
		if (ops->list_pseudoflavors == NULL) {
			array[result++] = ops->au_flavor;
			continue;
		}
		len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
		if (len < 0) {
			result = len;
			break;
		}
		for (i = 0; i < len; i++) {
			if (result >= size) {
				/*
				 * Leave both loops: continuing the outer loop
				 * with a negative "result" would defeat the
				 * capacity check and index array[] out of
				 * bounds.
				 */
				result = -ENOMEM;
				goto out;
			}
			array[result++] = pseudos[i];
		}
	}
out:
	rcu_read_unlock();

	dprintk("RPC:       %s returns %d\n", __func__, result);
	return result;
}
EXPORT_SYMBOL_GPL(rpcauth_list_flavors);

L
Linus Torvalds 已提交
259
struct rpc_auth *
260
rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
L
Linus Torvalds 已提交
261
{
262
	struct rpc_auth	*auth = ERR_PTR(-EINVAL);
263
	const struct rpc_authops *ops;
264
	u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor);
L
Linus Torvalds 已提交
265

266 267
	ops = rpcauth_get_authops(flavor);
	if (ops == NULL)
268 269
		goto out;

270
	auth = ops->create(args, clnt);
271 272

	rpcauth_put_authops(ops);
273 274
	if (IS_ERR(auth))
		return auth;
L
Linus Torvalds 已提交
275
	if (clnt->cl_auth)
276
		rpcauth_release(clnt->cl_auth);
L
Linus Torvalds 已提交
277
	clnt->cl_auth = auth;
278 279

out:
L
Linus Torvalds 已提交
280 281
	return auth;
}
282
EXPORT_SYMBOL_GPL(rpcauth_create);
L
Linus Torvalds 已提交
283 284

void
285
rpcauth_release(struct rpc_auth *auth)
L
Linus Torvalds 已提交
286 287 288 289 290 291 292 293
{
	if (!atomic_dec_and_test(&auth->au_count))
		return;
	auth->au_ops->destroy(auth);
}

/* Protects the cred_unused LRU list and number_cred_unused. */
static DEFINE_SPINLOCK(rpc_credcache_lock);

294 295 296 297 298
/*
 * On success, the caller is responsible for freeing the reference
 * held by the hashtable
 */
static bool
299 300
rpcauth_unhash_cred_locked(struct rpc_cred *cred)
{
301 302
	if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
		return false;
303
	hlist_del_rcu(&cred->cr_hash);
304
	return true;
305 306
}

307
static bool
308 309 310
rpcauth_unhash_cred(struct rpc_cred *cred)
{
	spinlock_t *cache_lock;
311
	bool ret;
312

313 314
	if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
		return false;
315 316
	cache_lock = &cred->cr_auth->au_credcache->lock;
	spin_lock(cache_lock);
317
	ret = rpcauth_unhash_cred_locked(cred);
318
	spin_unlock(cache_lock);
319
	return ret;
320 321
}

L
Linus Torvalds 已提交
322 323 324 325
/*
 * Initialize RPC credential cache
 */
int
326
rpcauth_init_credcache(struct rpc_auth *auth)
L
Linus Torvalds 已提交
327 328
{
	struct rpc_cred_cache *new;
329
	unsigned int hashsize;
L
Linus Torvalds 已提交
330

331
	new = kmalloc(sizeof(*new), GFP_KERNEL);
L
Linus Torvalds 已提交
332
	if (!new)
333 334
		goto out_nocache;
	new->hashbits = auth_hashbits;
335
	hashsize = 1U << new->hashbits;
336 337 338
	new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
	if (!new->hashtable)
		goto out_nohashtbl;
339
	spin_lock_init(&new->lock);
L
Linus Torvalds 已提交
340 341
	auth->au_credcache = new;
	return 0;
342 343 344 345
out_nohashtbl:
	kfree(new);
out_nocache:
	return -ENOMEM;
L
Linus Torvalds 已提交
346
}
347
EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
L
Linus Torvalds 已提交
348

349 350 351 352 353 354 355 356 357 358 359 360 361
/*
 * Setup a credential key lifetime timeout notification
 */
int
rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred)
{
	if (!cred->cr_auth->au_ops->key_timeout)
		return 0;
	return cred->cr_auth->au_ops->key_timeout(auth, cred);
}
EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify);

bool
362
rpcauth_cred_key_to_expire(struct rpc_auth *auth, struct rpc_cred *cred)
363
{
364 365
	if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT)
		return false;
366 367 368 369 370 371
	if (!cred->cr_ops->crkey_to_expire)
		return false;
	return cred->cr_ops->crkey_to_expire(cred);
}
EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire);

372 373 374 375 376 377 378 379 380
char *
rpcauth_stringify_acceptor(struct rpc_cred *cred)
{
	if (!cred->cr_ops->crstringify_acceptor)
		return NULL;
	return cred->cr_ops->crstringify_acceptor(cred);
}
EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);

L
Linus Torvalds 已提交
381 382 383 384
/*
 * Destroy a list of credentials
 */
static inline
385
void rpcauth_destroy_credlist(struct list_head *head)
L
Linus Torvalds 已提交
386 387 388
{
	struct rpc_cred *cred;

389 390 391
	while (!list_empty(head)) {
		cred = list_entry(head->next, struct rpc_cred, cr_lru);
		list_del_init(&cred->cr_lru);
L
Linus Torvalds 已提交
392 393 394 395
		put_rpccred(cred);
	}
}

396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
/*
 * Append @cred to the tail of the cred_unused list and bump the counter,
 * unless it is already linked on a list through cr_lru.
 */
static void
rpcauth_lru_add_locked(struct rpc_cred *cred)
{
	if (list_empty(&cred->cr_lru)) {
		number_cred_unused++;
		list_add_tail(&cred->cr_lru, &cred_unused);
	}
}

/*
 * Add @cred to the unused list under rpc_credcache_lock.  An unlocked
 * check skips the lock when the credential is already on a list.
 */
static void
rpcauth_lru_add(struct rpc_cred *cred)
{
	if (list_empty(&cred->cr_lru)) {
		spin_lock(&rpc_credcache_lock);
		rpcauth_lru_add_locked(cred);
		spin_unlock(&rpc_credcache_lock);
	}
}

/*
 * Unlink @cred from the list it is on via cr_lru and decrement the
 * unused counter; does nothing if cr_lru is not linked anywhere.
 */
static void
rpcauth_lru_remove_locked(struct rpc_cred *cred)
{
	if (!list_empty(&cred->cr_lru)) {
		number_cred_unused--;
		list_del_init(&cred->cr_lru);
	}
}

/*
 * Remove @cred from the unused list under rpc_credcache_lock.  An
 * unlocked check skips the lock when the credential is not on any list.
 */
static void
rpcauth_lru_remove(struct rpc_cred *cred)
{
	if (!list_empty(&cred->cr_lru)) {
		spin_lock(&rpc_credcache_lock);
		rpcauth_lru_remove_locked(cred);
		spin_unlock(&rpc_credcache_lock);
	}
}

L
Linus Torvalds 已提交
434 435 436 437 438
/*
 * Clear the RPC credential cache, and delete those credentials
 * that are not referenced.
 */
void
439
rpcauth_clear_credcache(struct rpc_cred_cache *cache)
L
Linus Torvalds 已提交
440
{
441 442
	LIST_HEAD(free);
	struct hlist_head *head;
L
Linus Torvalds 已提交
443
	struct rpc_cred	*cred;
444
	unsigned int hashsize = 1U << cache->hashbits;
L
Linus Torvalds 已提交
445 446 447
	int		i;

	spin_lock(&rpc_credcache_lock);
448
	spin_lock(&cache->lock);
449
	for (i = 0; i < hashsize; i++) {
450 451 452
		head = &cache->hashtable[i];
		while (!hlist_empty(head)) {
			cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
453
			rpcauth_unhash_cred_locked(cred);
454 455 456
			/* Note: We now hold a reference to cred */
			rpcauth_lru_remove_locked(cred);
			list_add_tail(&cred->cr_lru, &free);
L
Linus Torvalds 已提交
457 458
		}
	}
459
	spin_unlock(&cache->lock);
L
Linus Torvalds 已提交
460 461 462 463
	spin_unlock(&rpc_credcache_lock);
	rpcauth_destroy_credlist(&free);
}

464 465 466 467 468 469 470 471 472 473 474
/*
 * Destroy the RPC credential cache
 */
void
rpcauth_destroy_credcache(struct rpc_auth *auth)
{
	struct rpc_cred_cache *cache = auth->au_credcache;

	if (cache) {
		auth->au_credcache = NULL;
		rpcauth_clear_credcache(cache);
475
		kfree(cache->hashtable);
476 477 478
		kfree(cache);
	}
}
479
EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
480

481 482 483

#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)

484 485 486
/*
 * Remove stale credentials. Avoid sleeping inside the loop.
 */
487
static long
488
rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
L
Linus Torvalds 已提交
489
{
490
	struct rpc_cred *cred, *next;
491
	unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
492
	long freed = 0;
493

494 495
	list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {

496 497
		if (nr_to_scan-- == 0)
			break;
498
		if (refcount_read(&cred->cr_count) > 1) {
499 500 501
			rpcauth_lru_remove_locked(cred);
			continue;
		}
502 503 504 505
		/*
		 * Enforce a 60 second garbage collection moratorium
		 * Note that the cred_unused list must be time-ordered.
		 */
506 507 508
		if (!time_in_range(cred->cr_expire, expired, jiffies))
			continue;
		if (!rpcauth_unhash_cred(cred))
509
			continue;
510

511 512 513
		rpcauth_lru_remove_locked(cred);
		freed++;
		list_add_tail(&cred->cr_lru, free);
L
Linus Torvalds 已提交
514
	}
515
	return freed ? freed : SHRINK_STOP;
L
Linus Torvalds 已提交
516 517
}

518 519 520 521 522 523 524 525 526 527 528 529 530 531
static unsigned long
rpcauth_cache_do_shrink(int nr_to_scan)
{
	LIST_HEAD(free);
	unsigned long freed;

	spin_lock(&rpc_credcache_lock);
	freed = rpcauth_prune_expired(&free, nr_to_scan);
	spin_unlock(&rpc_credcache_lock);
	rpcauth_destroy_credlist(&free);

	return freed;
}

L
Linus Torvalds 已提交
532
/*
533
 * Run memory cache shrinker.
L
Linus Torvalds 已提交
534
 */
535 536 537
static unsigned long
rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)

L
Linus Torvalds 已提交
538
{
539 540
	if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
		return SHRINK_STOP;
541

542
	/* nothing left, don't come back */
543
	if (list_empty(&cred_unused))
544 545
		return SHRINK_STOP;

546
	return rpcauth_cache_do_shrink(sc->nr_to_scan);
547 548 549 550 551 552
}

static unsigned long
rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)

{
553
	return number_cred_unused * sysctl_vfs_cache_pressure / 100;
L
Linus Torvalds 已提交
554 555
}

556 557 558 559 560 561 562 563 564 565 566 567 568 569 570
/*
 * If the unused list has grown beyond auth_max_cred_cachesize, prune the
 * excess, scanning at most 100 entries per call.
 */
static void
rpcauth_cache_enforce_limit(void)
{
	unsigned long excess;

	if (number_cred_unused <= auth_max_cred_cachesize)
		return;

	excess = number_cred_unused - auth_max_cred_cachesize;
	rpcauth_cache_do_shrink(excess < 100 ? excess : 100);
}

L
Linus Torvalds 已提交
571 572 573 574 575
/*
 * Look up a process' credentials in the authentication cache
 *
 * First searches the hash chain under RCU.  With RPCAUTH_LOOKUP_RCU set,
 * a matching entry is returned without taking a reference, and
 * ERR_PTR(-ECHILD) is returned when no usable entry exists.  Otherwise a
 * reference is taken on a match, or a new credential is created and
 * inserted (re-checking the chain under the cache lock in case another
 * thread inserted a match in the meantime).
 *
 * Unless RPCAUTH_LOOKUP_NEW is given, a credential still flagged
 * RPCAUTH_CRED_NEW is initialised through cr_init before being returned.
 * Returns the credential or an ERR_PTR.
 */
struct rpc_cred *
rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
		int flags, gfp_t gfp)
{
	LIST_HEAD(free);
	struct rpc_cred_cache *cache = auth->au_credcache;
	struct hlist_head *bucket;
	struct rpc_cred *cred = NULL;
	struct rpc_cred *entry, *new;
	unsigned int nr;

	nr = auth->au_ops->hash_cred(acred, cache->hashbits);
	bucket = &cache->hashtable[nr];

	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, bucket, cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		if (flags & RPCAUTH_LOOKUP_RCU) {
			/* No reference is taken in RCU mode. */
			if (test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags) ||
			    refcount_read(&entry->cr_count) == 0)
				continue;
			cred = entry;
			break;
		}
		cred = get_rpccred(entry);
		if (cred)
			break;
	}
	rcu_read_unlock();

	if (cred != NULL)
		goto found;

	if (flags & RPCAUTH_LOOKUP_RCU)
		return ERR_PTR(-ECHILD);

	new = auth->au_ops->crcreate(auth, acred, flags, gfp);
	if (IS_ERR(new))
		return new;

	spin_lock(&cache->lock);
	hlist_for_each_entry(entry, bucket, cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		cred = get_rpccred(entry);
		if (cred)
			break;
	}
	if (cred != NULL) {
		/* Lost the race: queue our new credential for destruction. */
		list_add_tail(&new->cr_lru, &free);
	} else {
		cred = new;
		set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
		/* Extra reference held on behalf of the hash table. */
		refcount_inc(&cred->cr_count);
		hlist_add_head_rcu(&cred->cr_hash, bucket);
	}
	spin_unlock(&cache->lock);
	rpcauth_cache_enforce_limit();
found:
	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
	    cred->cr_ops->cr_init != NULL &&
	    !(flags & RPCAUTH_LOOKUP_NEW)) {
		int res = cred->cr_ops->cr_init(auth, cred);

		if (res < 0) {
			put_rpccred(cred);
			cred = ERR_PTR(res);
		}
	}
	rpcauth_destroy_credlist(&free);
	return cred;
}
EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
L
Linus Torvalds 已提交
647 648

struct rpc_cred *
649
rpcauth_lookupcred(struct rpc_auth *auth, int flags)
L
Linus Torvalds 已提交
650
{
651
	struct auth_cred acred;
L
Linus Torvalds 已提交
652
	struct rpc_cred *ret;
653
	const struct cred *cred = current_cred();
L
Linus Torvalds 已提交
654

655
	dprintk("RPC:       looking up %s cred\n",
L
Linus Torvalds 已提交
656
		auth->au_ops->au_name);
657 658 659 660

	memset(&acred, 0, sizeof(acred));
	acred.uid = cred->fsuid;
	acred.gid = cred->fsgid;
661
	acred.group_info = cred->group_info;
662
	ret = auth->au_ops->lookup_cred(auth, &acred, flags);
L
Linus Torvalds 已提交
663 664
	return ret;
}
665
EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
L
Linus Torvalds 已提交
666

667 668 669 670 671
/*
 * Initialise the generic part of a freshly allocated credential: empty
 * hash and LRU linkage, a reference count of one, the owning auth and
 * ops, an expiry stamp of "now", and the uid taken from @acred.
 */
void
rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
		  struct rpc_auth *auth, const struct rpc_credops *ops)
{
	cred->cr_auth = auth;
	cred->cr_ops = ops;
	cred->cr_uid = acred->uid;
	cred->cr_expire = jiffies;
	refcount_set(&cred->cr_count, 1);
	INIT_HLIST_NODE(&cred->cr_hash);
	INIT_LIST_HEAD(&cred->cr_lru);
}
EXPORT_SYMBOL_GPL(rpcauth_init_cred);
680

681
struct rpc_cred *
682
rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
683 684 685
{
	dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
			cred->cr_auth->au_ops->au_name, cred);
686
	return get_rpccred(cred);
687
}
688
EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
689

690
static struct rpc_cred *
691
rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
L
Linus Torvalds 已提交
692
{
693
	struct rpc_auth *auth = task->tk_client->cl_auth;
L
Linus Torvalds 已提交
694
	struct auth_cred acred = {
695 696
		.uid = GLOBAL_ROOT_UID,
		.gid = GLOBAL_ROOT_GID,
L
Linus Torvalds 已提交
697 698
	};

699
	dprintk("RPC: %5u looking up %s cred\n",
700
		task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
701
	return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
702 703
}

704
static struct rpc_cred *
705
rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
706 707 708 709 710
{
	struct rpc_auth *auth = task->tk_client->cl_auth;

	dprintk("RPC: %5u looking up %s cred\n",
		task->tk_pid, auth->au_ops->au_name);
711
	return rpcauth_lookupcred(auth, lookupflags);
L
Linus Torvalds 已提交
712 713
}

714
static int
715
rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
L
Linus Torvalds 已提交
716
{
717
	struct rpc_rqst *req = task->tk_rqstp;
718
	struct rpc_cred *new;
719 720 721 722
	int lookupflags = 0;

	if (flags & RPC_TASK_ASYNC)
		lookupflags |= RPCAUTH_LOOKUP_NEW;
723
	if (cred != NULL)
724
		new = cred->cr_ops->crbind(task, cred, lookupflags);
725
	else if (flags & RPC_TASK_ROOTCREDS)
726
		new = rpcauth_bind_root_cred(task, lookupflags);
727
	else
728 729 730
		new = rpcauth_bind_new_cred(task, lookupflags);
	if (IS_ERR(new))
		return PTR_ERR(new);
731
	put_rpccred(req->rq_cred);
732
	req->rq_cred = new;
733
	return 0;
L
Linus Torvalds 已提交
734 735 736 737 738
}

void
put_rpccred(struct rpc_cred *cred)
{
739 740
	if (cred == NULL)
		return;
741
	rcu_read_lock();
742
	if (refcount_dec_and_test(&cred->cr_count))
743
		goto destroy;
744
	if (refcount_read(&cred->cr_count) != 1 ||
745 746 747 748 749 750 751 752 753 754
	    !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
		goto out;
	if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
		cred->cr_expire = jiffies;
		rpcauth_lru_add(cred);
		/* Race breaker */
		if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)))
			rpcauth_lru_remove(cred);
	} else if (rpcauth_unhash_cred(cred)) {
		rpcauth_lru_remove(cred);
755
		if (refcount_dec_and_test(&cred->cr_count))
756
			goto destroy;
757
	}
758 759
out:
	rcu_read_unlock();
760
	return;
761 762 763
destroy:
	rcu_read_unlock();
	cred->cr_ops->crdestroy(cred);
L
Linus Torvalds 已提交
764
}
765
EXPORT_SYMBOL_GPL(put_rpccred);
L
Linus Torvalds 已提交
766

767 768
/*
 * Marshal the request's credential at @p by delegating to the
 * credential's crmarshal method; its return value is passed through.
 */
__be32 *
rpcauth_marshcred(struct rpc_task *task, __be32 *p)
{
	struct rpc_rqst *req = task->tk_rqstp;
	struct rpc_cred *cred = req->rq_cred;

	dprintk("RPC: %5u marshaling %s cred %p\n",
		task->tk_pid, cred->cr_auth->au_ops->au_name, cred);

	return cred->cr_ops->crmarshal(task, p);
}

778 779
/*
 * Validate the verifier at @p by delegating to the request credential's
 * crvalidate method; its return value is passed through.
 */
__be32 *
rpcauth_checkverf(struct rpc_task *task, __be32 *p)
{
	struct rpc_rqst *req = task->tk_rqstp;
	struct rpc_cred *cred = req->rq_cred;

	dprintk("RPC: %5u validating %s cred %p\n",
		task->tk_pid, cred->cr_auth->au_ops->au_name, cred);

	return cred->cr_ops->crvalidate(task, p);
}

789 790 791 792 793 794 795 796 797
/*
 * Default argument encoding: set up an XDR stream over the request's
 * send buffer starting at @data and run @encode on @obj.
 */
static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
				   __be32 *data, void *obj)
{
	struct xdr_stream stream;

	xdr_init_encode(&stream, &rqstp->rq_snd_buf, data);
	encode(rqstp, &stream, obj);
}

L
Linus Torvalds 已提交
798
int
799
rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
800
		__be32 *data, void *obj)
L
Linus Torvalds 已提交
801
{
802
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
L
Linus Torvalds 已提交
803

804
	dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
L
Linus Torvalds 已提交
805 806 807 808
			task->tk_pid, cred->cr_ops->cr_name, cred);
	if (cred->cr_ops->crwrap_req)
		return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
	/* By default, we encode the arguments normally. */
809 810
	rpcauth_wrap_req_encode(encode, rqstp, data, obj);
	return 0;
L
Linus Torvalds 已提交
811 812
}

813 814 815 816 817 818 819 820 821 822
/*
 * Default result decoding: set up an XDR stream over the request's
 * receive buffer starting at @data and return what @decode reports.
 */
static int
rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
			  __be32 *data, void *obj)
{
	struct xdr_stream stream;

	xdr_init_decode(&stream, &rqstp->rq_rcv_buf, data);
	return decode(rqstp, &stream, obj);
}

L
Linus Torvalds 已提交
823
int
824
rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
825
		__be32 *data, void *obj)
L
Linus Torvalds 已提交
826
{
827
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
L
Linus Torvalds 已提交
828

829
	dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
L
Linus Torvalds 已提交
830 831 832 833 834
			task->tk_pid, cred->cr_ops->cr_name, cred);
	if (cred->cr_ops->crunwrap_resp)
		return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
						   data, obj);
	/* By default, we decode the arguments normally. */
835
	return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
L
Linus Torvalds 已提交
836 837
}

838 839 840 841 842 843 844 845 846 847
/*
 * Ask the request's credential whether the request must be re-encoded
 * before transmission.  Returns false when there is no credential or it
 * has no crneed_reencode method.
 */
bool
rpcauth_xmit_need_reencode(struct rpc_task *task)
{
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;

	if (cred && cred->cr_ops->crneed_reencode)
		return cred->cr_ops->crneed_reencode(task);
	return false;
}

L
Linus Torvalds 已提交
848 849 850
int
rpcauth_refreshcred(struct rpc_task *task)
{
851
	struct rpc_cred	*cred;
L
Linus Torvalds 已提交
852 853
	int err;

854 855 856 857 858 859
	cred = task->tk_rqstp->rq_cred;
	if (cred == NULL) {
		err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
		if (err < 0)
			goto out;
		cred = task->tk_rqstp->rq_cred;
J
Joe Perches 已提交
860
	}
861
	dprintk("RPC: %5u refreshing %s cred %p\n",
862
		task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
863

L
Linus Torvalds 已提交
864
	err = cred->cr_ops->crrefresh(task);
865
out:
L
Linus Torvalds 已提交
866 867 868 869 870 871 872 873
	if (err < 0)
		task->tk_status = err;
	return err;
}

void
rpcauth_invalcred(struct rpc_task *task)
{
874
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
875

876
	dprintk("RPC: %5u invalidating %s cred %p\n",
877
		task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
878 879
	if (cred)
		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
L
Linus Torvalds 已提交
880 881 882 883 884
}

int
rpcauth_uptodatecred(struct rpc_task *task)
{
885
	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
886 887 888

	return cred == NULL ||
		test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
L
Linus Torvalds 已提交
889
}
890

891
static struct shrinker rpc_cred_shrinker = {
892 893
	.count_objects = rpcauth_cache_shrink_count,
	.scan_objects = rpcauth_cache_shrink_scan,
894 895
	.seeks = DEFAULT_SEEKS,
};
896

897
/*
 * Module initialisation: set up AUTH_UNIX and the generic auth, then
 * register the credential shrinker.  On failure, everything that was
 * already set up is torn down again and the error is returned.
 */
int __init rpcauth_init_module(void)
{
	int err = rpc_init_authunix();

	if (err < 0)
		return err;

	err = rpc_init_generic_auth();
	if (err < 0)
		goto undo_authunix;

	err = register_shrinker(&rpc_cred_shrinker);
	if (err < 0)
		goto undo_generic;

	return 0;

undo_generic:
	rpc_destroy_generic_auth();
undo_authunix:
	rpc_destroy_authunix();
	return err;
}

919
void rpcauth_remove_module(void)
920
{
921 922
	rpc_destroy_authunix();
	rpc_destroy_generic_auth();
923
	unregister_shrinker(&rpc_cred_shrinker);
924
}