diff --git a/crypto/ccm.c b/crypto/ccm.c index 0a083342ec8cf3b17c4da15c515c8c3d7a519a0f..8104c564dd318ceea01a763adbece9f2fc297407 100644 --- a/crypto/ccm.c +++ b/crypto/ccm.c @@ -455,7 +455,6 @@ static void crypto_ccm_free(struct aead_instance *inst) static int crypto_ccm_create_common(struct crypto_template *tmpl, struct rtattr **tb, - const char *full_name, const char *ctr_name, const char *mac_name) { @@ -483,7 +482,8 @@ static int crypto_ccm_create_common(struct crypto_template *tmpl, mac = __crypto_hash_alg_common(mac_alg); err = -EINVAL; - if (mac->digestsize != 16) + if (strncmp(mac->base.cra_name, "cbcmac(", 7) != 0 || + mac->digestsize != 16) goto out_put_mac; inst = kzalloc(sizeof(*inst) + sizeof(*ictx), GFP_KERNEL); @@ -506,23 +506,27 @@ static int crypto_ccm_create_common(struct crypto_template *tmpl, ctr = crypto_spawn_skcipher_alg(&ictx->ctr); - /* Not a stream cipher? */ + /* The skcipher algorithm must be CTR mode, using 16-byte blocks. */ err = -EINVAL; - if (ctr->base.cra_blocksize != 1) + if (strncmp(ctr->base.cra_name, "ctr(", 4) != 0 || + crypto_skcipher_alg_ivsize(ctr) != 16 || + ctr->base.cra_blocksize != 1) goto err_drop_ctr; - /* We want the real thing! */ - if (crypto_skcipher_alg_ivsize(ctr) != 16) + /* ctr and cbcmac must use the same underlying block cipher. */ + if (strcmp(ctr->base.cra_name + 4, mac->base.cra_name + 7) != 0) goto err_drop_ctr; err = -ENAMETOOLONG; + if (snprintf(inst->alg.base.cra_name, CRYPTO_MAX_ALG_NAME, + "ccm(%s", ctr->base.cra_name + 4) >= CRYPTO_MAX_ALG_NAME) + goto err_drop_ctr; + if (snprintf(inst->alg.base.cra_driver_name, CRYPTO_MAX_ALG_NAME, "ccm_base(%s,%s)", ctr->base.cra_driver_name, mac->base.cra_driver_name) >= CRYPTO_MAX_ALG_NAME) goto err_drop_ctr; - memcpy(inst->alg.base.cra_name, full_name, CRYPTO_MAX_ALG_NAME); - inst->alg.base.cra_flags = ctr->base.cra_flags & CRYPTO_ALG_ASYNC; inst->alg.base.cra_priority = (mac->base.cra_priority + ctr->base.cra_priority) / 2; @@ -564,7 +568,6 @@ static int crypto_ccm_create(struct crypto_template *tmpl, struct rtattr **tb) const char *cipher_name; char ctr_name[CRYPTO_MAX_ALG_NAME]; char mac_name[CRYPTO_MAX_ALG_NAME]; - char full_name[CRYPTO_MAX_ALG_NAME]; cipher_name = crypto_attr_alg_name(tb[1]); if (IS_ERR(cipher_name)) @@ -578,12 +581,7 @@ static int crypto_ccm_create(struct crypto_template *tmpl, struct rtattr **tb) cipher_name) >= CRYPTO_MAX_ALG_NAME) return -ENAMETOOLONG; - if (snprintf(full_name, CRYPTO_MAX_ALG_NAME, "ccm(%s)", cipher_name) >= - CRYPTO_MAX_ALG_NAME) - return -ENAMETOOLONG; - - return crypto_ccm_create_common(tmpl, tb, full_name, ctr_name, - mac_name); + return crypto_ccm_create_common(tmpl, tb, ctr_name, mac_name); } static struct crypto_template crypto_ccm_tmpl = { @@ -596,23 +594,17 @@ static int crypto_ccm_base_create(struct crypto_template *tmpl, struct rtattr **tb) { const char *ctr_name; - const char *cipher_name; - char full_name[CRYPTO_MAX_ALG_NAME]; + const char *mac_name; ctr_name = crypto_attr_alg_name(tb[1]); if (IS_ERR(ctr_name)) return PTR_ERR(ctr_name); - cipher_name = crypto_attr_alg_name(tb[2]); - if (IS_ERR(cipher_name)) - return PTR_ERR(cipher_name); - - if (snprintf(full_name, CRYPTO_MAX_ALG_NAME, "ccm_base(%s,%s)", - ctr_name, cipher_name) >= CRYPTO_MAX_ALG_NAME) - return -ENAMETOOLONG; + mac_name = crypto_attr_alg_name(tb[2]); + if (IS_ERR(mac_name)) + return PTR_ERR(mac_name); - return crypto_ccm_create_common(tmpl, tb, full_name, ctr_name, - cipher_name); + return crypto_ccm_create_common(tmpl, tb, ctr_name, mac_name); } static struct crypto_template crypto_ccm_base_tmpl = {