diff --git a/paddle/fluid/framework/io/crypto/aes_cipher.cc b/paddle/fluid/framework/io/crypto/aes_cipher.cc index 95c89049e1c3df9495fe765765d66cf0720e6ccd..dce71692b6b07674fde15b29ac36ff7fe35aef1f 100644 --- a/paddle/fluid/framework/io/crypto/aes_cipher.cc +++ b/paddle/fluid/framework/io/crypto/aes_cipher.cc @@ -65,7 +65,7 @@ std::string AESCipher::EncryptInternal(const std::string& plaintext, iv_ = CipherUtils::GenKey(iv_size_); m_cipher.get()->SetKeyWithIV( key_char, key.size(), - reinterpret_cast(&(iv_.at(0)))); + reinterpret_cast(&(iv_.at(0))), iv_.size()); } else { m_cipher.get()->SetKey(key_char, key.size()); } @@ -74,7 +74,7 @@ std::string AESCipher::EncryptInternal(const std::string& plaintext, m_filter->Attach(new CryptoPP::StringSink(ciphertext)); CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter)); if (need_iv) { - ciphertext = iv_ + ciphertext; + ciphertext = iv_.append(ciphertext); } return ciphertext; @@ -95,7 +95,7 @@ std::string AESCipher::DecryptInternal(const std::string& ciphertext, ciphertext_beg = iv_size_ / 8; m_cipher.get()->SetKeyWithIV( key_char, key.size(), - reinterpret_cast(&(iv_.at(0)))); + reinterpret_cast(&(iv_.at(0))), iv_.size()); } else { m_cipher.get()->SetKey(key_char, key.size()); } @@ -120,7 +120,7 @@ std::string AESCipher::AuthenticatedEncryptInternal( iv_ = CipherUtils::GenKey(iv_size_); m_cipher.get()->SetKeyWithIV( key_char, key.size(), - reinterpret_cast(&(iv_.at(0)))); + reinterpret_cast(&(iv_.at(0))), iv_.size()); } else { m_cipher.get()->SetKey(key_char, key.size()); } @@ -129,7 +129,7 @@ std::string AESCipher::AuthenticatedEncryptInternal( m_filter->Attach(new CryptoPP::StringSink(ciphertext)); CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter)); if (need_iv) { - ciphertext = iv_ + ciphertext; + ciphertext = iv_.append(ciphertext); } return ciphertext; @@ -150,7 +150,7 @@ std::string AESCipher::AuthenticatedDecryptInternal( ciphertext_beg = iv_size_ / 8; m_cipher.get()->SetKeyWithIV( key_char, key.size(), - reinterpret_cast(&(iv_.at(0)))); + reinterpret_cast(&(iv_.at(0))), iv_.size()); } else { m_cipher.get()->SetKey(key_char, key.size()); } diff --git a/paddle/fluid/framework/io/crypto/aes_cipher_test.cc b/paddle/fluid/framework/io/crypto/aes_cipher_test.cc index e80781854f58961d4bbf947f88630a92b2d6f533..0702e4ab78364cc01d7d5a7d4bee6daf1ae67735 100644 --- a/paddle/fluid/framework/io/crypto/aes_cipher_test.cc +++ b/paddle/fluid/framework/io/crypto/aes_cipher_test.cc @@ -45,6 +45,7 @@ TEST_F(AESTest, security_string) { {"AES_CTR_NoPadding", "AES_CBC_PKCSPadding", "AES_ECB_PKCSPadding", "AES_GCM_NoPadding"}); const std::string plaintext("hello world."); + bool is_throw = false; for (auto& i : name_list) { AESTest::GenConfigFile(i); try { @@ -54,8 +55,10 @@ TEST_F(AESTest, security_string) { std::string plaintext1 = cipher->Decrypt(ciphertext, AESTest::key); EXPECT_EQ(plaintext, plaintext1); } catch (CryptoPP::Exception& e) { + is_throw = true; LOG(ERROR) << e.what(); } + EXPECT_FALSE(is_throw); } } @@ -64,6 +67,7 @@ TEST_F(AESTest, security_vector) { {"AES_CTR_NoPadding", "AES_CBC_PKCSPadding", "AES_ECB_PKCSPadding", "AES_GCM_NoPadding"}); std::vector input{1, 2, 3, 4}; + bool is_throw = false; for (auto& i : name_list) { AESTest::GenConfigFile(i); try { @@ -79,8 +83,10 @@ TEST_F(AESTest, security_vector) { EXPECT_EQ(i, output); } } catch (CryptoPP::Exception& e) { + is_throw = true; LOG(ERROR) << e.what(); } + EXPECT_FALSE(is_throw); } } @@ -90,6 +96,7 @@ TEST_F(AESTest, encrypt_to_file) { "AES_GCM_NoPadding"}); const std::string plaintext("hello world."); std::string filename("aes_test.ciphertext"); + bool is_throw = false; for (auto& i : name_list) { AESTest::GenConfigFile(i); try { @@ -98,8 +105,10 @@ TEST_F(AESTest, encrypt_to_file) { std::string plaintext1 = cipher->DecryptFromFile(AESTest::key, filename); EXPECT_EQ(plaintext, plaintext1); } catch (CryptoPP::Exception& e) { + is_throw = true; LOG(ERROR) << e.what(); } + EXPECT_FALSE(is_throw); } } diff --git a/paddle/fluid/framework/io/crypto/cipher.cc b/paddle/fluid/framework/io/crypto/cipher.cc index 763c282017e16e4dfc823a9784fe30a4aa875c92..eca175c020cb6f85eac2970aa9734c0a6850ebef 100644 --- a/paddle/fluid/framework/io/crypto/cipher.cc +++ b/paddle/fluid/framework/io/crypto/cipher.cc @@ -48,10 +48,10 @@ std::shared_ptr CipherFactory::CreateCipher( ret->Init(cipher_name, iv_size, tag_size); return ret; } else { - PADDLE_THROW( + PADDLE_THROW(paddle::platform::errors::InvalidArgument( "Invalid cipher name is specied. " "Please check you have specified valid cipher" - " name in CryptoProperties."); + " name in CryptoProperties.")); } return nullptr; } diff --git a/paddle/fluid/framework/io/crypto/cipher_utils.cc b/paddle/fluid/framework/io/crypto/cipher_utils.cc index 77fe9810ac251de79c60cdc351cf1bf63136b0e4..b8a2c5419b37e95ad1f1264a6dc9de7fb74a6d9e 100644 --- a/paddle/fluid/framework/io/crypto/cipher_utils.cc +++ b/paddle/fluid/framework/io/crypto/cipher_utils.cc @@ -111,7 +111,7 @@ bool CipherUtils::GetValue( return true; } -const int CipherUtils::AES_DEFAULT_IV_SIZE = 96; +const int CipherUtils::AES_DEFAULT_IV_SIZE = 128; const int CipherUtils::AES_DEFAULT_TAG_SIZE = 128; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/cipher_utils_test.cc b/paddle/fluid/framework/io/crypto/cipher_utils_test.cc index 2f7d7282b8b701d71b25976ecf88f1d00a60ff4b..eddb8ca699b8f0ee82b206f3fe6b2f9c852e0430 100644 --- a/paddle/fluid/framework/io/crypto/cipher_utils_test.cc +++ b/paddle/fluid/framework/io/crypto/cipher_utils_test.cc @@ -70,7 +70,7 @@ TEST(CipherUtils, gen_key) { EXPECT_NE(key, key1); std::string key2 = CipherUtils::ReadKeyFromFile(filename); EXPECT_EQ(key1, key2); - EXPECT_EQ(key.size(), 256 / 8); + EXPECT_EQ(static_cast(key.size()), 32); } } // namespace framework