diff --git a/drivers/crypto/atmel-aes.c b/drivers/crypto/atmel-aes.c index 889830e5b6518d2be5df9591c64058008f323155..3f7b0966163cf34b1f5b9dc49959e485f2997b7e 100644 --- a/drivers/crypto/atmel-aes.c +++ b/drivers/crypto/atmel-aes.c @@ -111,6 +111,7 @@ struct atmel_aes_base_ctx { int keylen; u32 key[AES_KEYSIZE_256 / sizeof(u32)]; u16 block_size; + bool is_aead; }; struct atmel_aes_ctx { @@ -157,6 +158,7 @@ struct atmel_aes_authenc_ctx { struct atmel_aes_reqctx { unsigned long mode; + u32 lastc[AES_BLOCK_SIZE / sizeof(u32)]; }; #ifdef CONFIG_CRYPTO_DEV_ATMEL_AUTHENC @@ -498,12 +500,34 @@ static void atmel_aes_authenc_complete(struct atmel_aes_dev *dd, int err); static inline int atmel_aes_complete(struct atmel_aes_dev *dd, int err) { #ifdef CONFIG_CRYPTO_DEV_ATMEL_AUTHENC - atmel_aes_authenc_complete(dd, err); + if (dd->ctx->is_aead) + atmel_aes_authenc_complete(dd, err); #endif clk_disable(dd->iclk); dd->flags &= ~AES_FLAGS_BUSY; + if (!dd->ctx->is_aead) { + struct ablkcipher_request *req = + ablkcipher_request_cast(dd->areq); + struct atmel_aes_reqctx *rctx = ablkcipher_request_ctx(req); + struct crypto_ablkcipher *ablkcipher = + crypto_ablkcipher_reqtfm(req); + int ivsize = crypto_ablkcipher_ivsize(ablkcipher); + + if (rctx->mode & AES_FLAGS_ENCRYPT) { + scatterwalk_map_and_copy(req->info, req->dst, + req->nbytes - ivsize, ivsize, 0); + } else { + if (req->src == req->dst) { + memcpy(req->info, rctx->lastc, ivsize); + } else { + scatterwalk_map_and_copy(req->info, req->src, + req->nbytes - ivsize, ivsize, 0); + } + } + } + if (dd->is_async) dd->areq->complete(dd->areq, err); @@ -1072,11 +1096,11 @@ static int atmel_aes_ctr_start(struct atmel_aes_dev *dd) static int atmel_aes_crypt(struct ablkcipher_request *req, unsigned long mode) { - struct atmel_aes_base_ctx *ctx; + struct crypto_ablkcipher *ablkcipher = crypto_ablkcipher_reqtfm(req); + struct atmel_aes_base_ctx *ctx = crypto_ablkcipher_ctx(ablkcipher); struct atmel_aes_reqctx *rctx; struct atmel_aes_dev *dd; - ctx = crypto_ablkcipher_ctx(crypto_ablkcipher_reqtfm(req)); switch (mode & AES_FLAGS_OPMODE_MASK) { case AES_FLAGS_CFB8: ctx->block_size = CFB8_BLOCK_SIZE; @@ -1098,6 +1122,7 @@ static int atmel_aes_crypt(struct ablkcipher_request *req, unsigned long mode) ctx->block_size = AES_BLOCK_SIZE; break; } + ctx->is_aead = false; dd = atmel_aes_find_dev(ctx); if (!dd) @@ -1106,6 +1131,13 @@ static int atmel_aes_crypt(struct ablkcipher_request *req, unsigned long mode) rctx = ablkcipher_request_ctx(req); rctx->mode = mode; + if (!(mode & AES_FLAGS_ENCRYPT) && (req->src == req->dst)) { + int ivsize = crypto_ablkcipher_ivsize(ablkcipher); + + scatterwalk_map_and_copy(rctx->lastc, req->src, + (req->nbytes - ivsize), ivsize, 0); + } + return atmel_aes_handle_queue(dd, &req->base); } @@ -1740,6 +1772,7 @@ static int atmel_aes_gcm_crypt(struct aead_request *req, ctx = crypto_aead_ctx(crypto_aead_reqtfm(req)); ctx->block_size = AES_BLOCK_SIZE; + ctx->is_aead = true; dd = atmel_aes_find_dev(ctx); if (!dd) @@ -2224,6 +2257,7 @@ static int atmel_aes_authenc_crypt(struct aead_request *req, rctx->base.mode = mode; ctx->block_size = AES_BLOCK_SIZE; + ctx->is_aead = true; dd = atmel_aes_find_dev(ctx); if (!dd)