// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "paddle/fluid/framework/io/crypto/aes_cipher.h"

#include <cryptopp/aes.h>
#include <cryptopp/cryptlib.h>
#include <cryptopp/filters.h>
#include <cryptopp/gcm.h>
#include <cryptopp/modes.h>
#include <cryptopp/smartptr.h>

#include <fstream>
#include <iterator>
#include <set>
#include <string>

#include "paddle/fluid/framework/io/crypto/cipher_utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

namespace {

// Installs `key` on `cipher`, together with `*iv` when `iv` is non-null.
//
// The three-argument SetKeyWithIV overload reads exactly cipher->IVSize()
// bytes from the IV pointer, so an IV shorter than that would be read past
// its end. Such an IV is rejected with InvalidArgument instead.
// `key.at(0)` is kept so an empty key still raises std::out_of_range.
void SetCipherKey(CryptoPP::SimpleKeyingInterface* cipher,
                  const std::string& key, const std::string* iv) {
  const unsigned char* key_bytes =
      reinterpret_cast<const unsigned char*>(&(key.at(0)));
  if (iv == nullptr) {
    cipher->SetKey(key_bytes, key.size());
    return;
  }
  const size_t required_iv_bytes = cipher->IVSize();
  PADDLE_ENFORCE_GE(
      iv->size(), required_iv_bytes,
      paddle::platform::errors::InvalidArgument(
          "The IV has %d bytes, but the cipher requires at least %d bytes.",
          iv->size(), required_iv_bytes));
  cipher->SetKeyWithIV(key_bytes, key.size(),
                       reinterpret_cast<const unsigned char*>(iv->data()));
}

// Copies the leading `iv_size_bits / 8` bytes of `ciphertext` (the IV that the
// encrypt paths prepend) into `*iv`. Returns the offset at which the actual
// cipher payload starts. A ciphertext too short to even contain the IV is
// rejected with InvalidArgument.
size_t ExtractPrependedIV(const std::string& ciphertext, int iv_size_bits,
                          std::string* iv) {
  const size_t iv_bytes = static_cast<size_t>(iv_size_bits / 8);
  PADDLE_ENFORCE_GE(
      ciphertext.size(), iv_bytes,
      paddle::platform::errors::InvalidArgument(
          "The ciphertext has %d bytes, which is shorter than the %d-byte IV "
          "expected at its beginning. Invalid ciphertext input.",
          ciphertext.size(), iv_bytes));
  *iv = ciphertext.substr(0, iv_bytes);
  return iv_bytes;
}

}  // namespace

// Records the cipher name, the IV size and the tag size (both in bits, as
// divided by 8 below). Also records whether the cipher goes through the
// authenticated code paths.
void AESCipher::Init(const std::string& cipher_name, const int& iv_size,
                     const int& tag_size) {
  aes_cipher_name_ = cipher_name;
  iv_size_ = iv_size;
  tag_size_ = tag_size;
  const std::set<std::string> authenticated_ciphers{"AES_GCM_NoPadding"};
  // Assigned unconditionally so that re-initializing with a non-authenticated
  // cipher clears a flag left over from an earlier GCM Init.
  is_authenticated_cipher_ = authenticated_ciphers.count(cipher_name) > 0;
}

// Encrypts with an ECB/CBC/CTR cipher. When the mode needs an IV, a fresh one
// is generated and prepended to the returned ciphertext.
std::string AESCipher::EncryptInternal(const std::string& plaintext,
                                       const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::SymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter> m_filter;
  bool need_iv = false;
  BuildCipher(true, &need_iv, &m_cipher, &m_filter);
  if (need_iv) {
    iv_ = CipherUtils::GenKey(iv_size_);
  }
  SetCipherKey(m_cipher.get(), key, need_iv ? &iv_ : nullptr);

  std::string ciphertext;
  m_filter->Attach(new CryptoPP::StringSink(ciphertext));
  CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter));
  if (need_iv) {
    ciphertext.insert(0, iv_);
  }
  return ciphertext;
}

// Inverse of EncryptInternal. When the mode needs an IV, it is read from the
// first iv_size_ / 8 bytes of `ciphertext`.
std::string AESCipher::DecryptInternal(const std::string& ciphertext,
                                       const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::SymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter> m_filter;
  bool need_iv = false;
  BuildCipher(false, &need_iv, &m_cipher, &m_filter);
  size_t payload_begin = 0;
  if (need_iv) {
    payload_begin = ExtractPrependedIV(ciphertext, iv_size_, &iv_);
  }
  SetCipherKey(m_cipher.get(), key, need_iv ? &iv_ : nullptr);

  std::string plaintext;
  m_filter->Attach(new CryptoPP::StringSink(plaintext));
  CryptoPP::StringSource(ciphertext.substr(payload_begin), true,
                         new CryptoPP::Redirector(*m_filter));
  return plaintext;
}

// Authenticated (GCM) counterpart of EncryptInternal. The IV is prepended to
// the output in the same way.
std::string AESCipher::AuthenticatedEncryptInternal(
    const std::string& plaintext, const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::AuthenticatedEncryptionFilter> m_filter;
  bool need_iv = false;
  BuildAuthEncCipher(&need_iv, &m_cipher, &m_filter);
  if (need_iv) {
    iv_ = CipherUtils::GenKey(iv_size_);
  }
  SetCipherKey(m_cipher.get(), key, need_iv ? &iv_ : nullptr);

  std::string ciphertext;
  m_filter->Attach(new CryptoPP::StringSink(ciphertext));
  CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter));
  if (need_iv) {
    ciphertext.insert(0, iv_);
  }
  return ciphertext;
}

// Authenticated (GCM) counterpart of DecryptInternal. Raises InvalidArgument
// when the filter reports a failed integrity check.
std::string AESCipher::AuthenticatedDecryptInternal(
    const std::string& ciphertext, const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::AuthenticatedDecryptionFilter> m_filter;
  bool need_iv = false;
  BuildAuthDecCipher(&need_iv, &m_cipher, &m_filter);
  size_t payload_begin = 0;
  if (need_iv) {
    payload_begin = ExtractPrependedIV(ciphertext, iv_size_, &iv_);
  }
  SetCipherKey(m_cipher.get(), key, need_iv ? &iv_ : nullptr);

  std::string plaintext;
  m_filter->Attach(new CryptoPP::StringSink(plaintext));
  CryptoPP::StringSource(ciphertext.substr(payload_begin), true,
                         new CryptoPP::Redirector(*m_filter));
  PADDLE_ENFORCE_EQ(
      m_filter->GetLastResult(), true,
      paddle::platform::errors::InvalidArgument("Integrity check failed. "
                                                "Invalid ciphertext input."));
  return plaintext;
}

// Creates the AES block-mode cipher named by aes_cipher_name_ and the
// StreamTransformationFilter that drives it. Sets *need_iv for CBC and CTR.
void AESCipher::BuildCipher(
    bool for_encrypt, bool* need_iv,
    CryptoPP::member_ptr<CryptoPP::SymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter>* m_filter) {
  auto padding = CryptoPP::BlockPaddingSchemeDef::NO_PADDING;
  if (aes_cipher_name_ == "AES_ECB_PKCSPadding") {
    if (for_encrypt) {
      m_cipher->reset(new CryptoPP::ECB_Mode<CryptoPP::AES>::Encryption);
    } else {
      m_cipher->reset(new CryptoPP::ECB_Mode<CryptoPP::AES>::Decryption);
    }
    padding = CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING;
  } else if (aes_cipher_name_ == "AES_CBC_PKCSPadding") {
    if (for_encrypt) {
      m_cipher->reset(new CryptoPP::CBC_Mode<CryptoPP::AES>::Encryption);
    } else {
      m_cipher->reset(new CryptoPP::CBC_Mode<CryptoPP::AES>::Decryption);
    }
    *need_iv = true;
    padding = CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING;
  } else if (aes_cipher_name_ == "AES_CTR_NoPadding") {
    if (for_encrypt) {
      m_cipher->reset(new CryptoPP::CTR_Mode<CryptoPP::AES>::Encryption);
    } else {
      m_cipher->reset(new CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption);
    }
    *need_iv = true;
    padding = CryptoPP::BlockPaddingSchemeDef::NO_PADDING;
  } else {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }
  m_filter->reset(new CryptoPP::StreamTransformationFilter(
      *m_cipher->get(), nullptr, padding));
}

// Creates the GCM encryptor and its AuthenticatedEncryptionFilter. The tag is
// tag_size_ / 8 bytes long.
void AESCipher::BuildAuthEncCipher(
    bool* need_iv,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedEncryptionFilter>* m_filter) {
  if (aes_cipher_name_ == "AES_GCM_NoPadding") {
    m_cipher->reset(new CryptoPP::GCM<CryptoPP::AES>::Encryption);
    *need_iv = true;
    m_filter->reset(new CryptoPP::AuthenticatedEncryptionFilter(
        *m_cipher->get(), nullptr, false, tag_size_ / 8,
        CryptoPP::DEFAULT_CHANNEL,
        CryptoPP::BlockPaddingSchemeDef::NO_PADDING));
  } else {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }
}

// Creates the GCM decryptor and its AuthenticatedDecryptionFilter. The tag is
// tag_size_ / 8 bytes long.
void AESCipher::BuildAuthDecCipher(
    bool* need_iv,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedDecryptionFilter>* m_filter) {
  if (aes_cipher_name_ == "AES_GCM_NoPadding") {
    m_cipher->reset(new CryptoPP::GCM<CryptoPP::AES>::Decryption);
    *need_iv = true;
    m_filter->reset(new CryptoPP::AuthenticatedDecryptionFilter(
        *m_cipher->get(), nullptr,
        CryptoPP::AuthenticatedDecryptionFilter::DEFAULT_FLAGS, tag_size_ / 8,
        CryptoPP::BlockPaddingSchemeDef::NO_PADDING));
  } else {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }
}

// Dispatches to the authenticated or plain encryption path chosen in Init.
std::string AESCipher::Encrypt(const std::string& plaintext,
                               const std::string& key) {
  return is_authenticated_cipher_ ? AuthenticatedEncryptInternal(plaintext, key)
                                  : EncryptInternal(plaintext, key);
}

// Dispatches to the authenticated or plain decryption path chosen in Init.
std::string AESCipher::Decrypt(const std::string& ciphertext,
                               const std::string& key) {
  return is_authenticated_cipher_
             ? AuthenticatedDecryptInternal(ciphertext, key)
             : DecryptInternal(ciphertext, key);
}

// Encrypts `plaintext` and writes the result to `filename`, truncating it.
// Binary mode is used because the ciphertext is arbitrary bytes and must not
// undergo newline translation.
void AESCipher::EncryptToFile(const std::string& plaintext,
                              const std::string& key,
                              const std::string& filename) {
  std::ofstream fout(filename, std::ios::out | std::ios::binary);
  PADDLE_ENFORCE_EQ(fout.is_open(), true,
                    paddle::platform::errors::Unavailable(
                        "Failed to open file %s for writing.", filename));
  fout << this->Encrypt(plaintext, key);
  fout.close();
  PADDLE_ENFORCE_EQ(fout.fail(), false,
                    paddle::platform::errors::Unavailable(
                        "Failed to write encrypted data to file %s.",
                        filename));
}

// Reads the whole of `filename` (in binary mode) and decrypts it.
std::string AESCipher::DecryptFromFile(const std::string& key,
                                       const std::string& filename) {
  std::ifstream fin(filename, std::ios::in | std::ios::binary);
  PADDLE_ENFORCE_EQ(fin.is_open(), true,
                    paddle::platform::errors::Unavailable(
                        "Failed to open file %s for reading.", filename));
  std::string ciphertext{std::istreambuf_iterator<char>(fin),
                         std::istreambuf_iterator<char>()};
  fin.close();
  return Decrypt(ciphertext, key);
}

}  // namespace framework
}  // namespace paddle