From 5fad21c9eb59f4c8d6981531b38a9d01145d4fb7 Mon Sep 17 00:00:00 2001 From: yangqingyou Date: Thu, 21 May 2020 13:09:00 +0000 Subject: [PATCH] add crypto helper for paddle, test=develop --- cmake/external/cryptopp.cmake | 84 ++++++ cmake/third_party.cmake | 3 +- paddle/fluid/framework/io/CMakeLists.txt | 1 + .../fluid/framework/io/crypto/CMakeLists.txt | 14 + .../fluid/framework/io/crypto/aes_cipher.cc | 281 ++++++++++++++++++ paddle/fluid/framework/io/crypto/aes_cipher.h | 93 ++++++ .../framework/io/crypto/aes_cipher_test.cc | 107 +++++++ paddle/fluid/framework/io/crypto/cipher.cc | 60 ++++ paddle/fluid/framework/io/crypto/cipher.h | 51 ++++ .../fluid/framework/io/crypto/cipher_utils.cc | 117 ++++++++ .../fluid/framework/io/crypto/cipher_utils.h | 63 ++++ .../framework/io/crypto/cipher_utils_test.cc | 77 +++++ 12 files changed, 950 insertions(+), 1 deletion(-) create mode 100644 cmake/external/cryptopp.cmake create mode 100644 paddle/fluid/framework/io/crypto/CMakeLists.txt create mode 100644 paddle/fluid/framework/io/crypto/aes_cipher.cc create mode 100644 paddle/fluid/framework/io/crypto/aes_cipher.h create mode 100644 paddle/fluid/framework/io/crypto/aes_cipher_test.cc create mode 100644 paddle/fluid/framework/io/crypto/cipher.cc create mode 100644 paddle/fluid/framework/io/crypto/cipher.h create mode 100644 paddle/fluid/framework/io/crypto/cipher_utils.cc create mode 100644 paddle/fluid/framework/io/crypto/cipher_utils.h create mode 100644 paddle/fluid/framework/io/crypto/cipher_utils_test.cc diff --git a/cmake/external/cryptopp.cmake b/cmake/external/cryptopp.cmake new file mode 100644 index 00000000000..4ec63de4150 --- /dev/null +++ b/cmake/external/cryptopp.cmake @@ -0,0 +1,84 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +SET(CRYPTOPP_PREFIX_DIR ${THIRD_PARTY_PATH}/cryptopp) +SET(CRYPTOPP_INSTALL_DIR ${THIRD_PARTY_PATH}/install/cryptopp) +SET(CRYPTOPP_INCLUDE_DIR "${CRYPTOPP_INSTALL_DIR}/include" CACHE PATH "cryptopp include directory." FORCE) +SET(CRYPTOPP_REPOSITORY https://github.com/weidai11/cryptopp.git) +SET(CRYPTOPP_TAG CRYPTOPP_8_2_0) + +IF(WIN32) + SET(CRYPTOPP_LIBRARIES "${CRYPTOPP_INSTALL_DIR}/lib/cryptopp-static.lib" CACHE FILEPATH "cryptopp library." FORCE) + SET(CRYPTOPP_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd") + set(CompilerFlags + CMAKE_CXX_FLAGS + CMAKE_CXX_FLAGS_DEBUG + CMAKE_CXX_FLAGS_RELEASE + CMAKE_C_FLAGS + CMAKE_C_FLAGS_DEBUG + CMAKE_C_FLAGS_RELEASE + ) + foreach(CompilerFlag ${CompilerFlags}) + string(REPLACE "/MD" "/MT" ${CompilerFlag} "${${CompilerFlag}}") + endforeach() +ELSE(WIN32) + SET(CRYPTOPP_LIBRARIES "${CRYPTOPP_INSTALL_DIR}/lib/libcryptopp.a" CACHE FILEPATH "cryptopp library." FORCE) + SET(CRYPTOPP_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) +ENDIF(WIN32) + +set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS} + -DBUILD_SHARED=ON + -DBUILD_STATIC=ON + -DBUILD_TESTING=OFF + -DCMAKE_INSTALL_LIBDIR=${CRYPTOPP_INSTALL_DIR}/lib + -DCMAKE_INSTALL_PREFIX=${CRYPTOPP_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + -DCMAKE_CXX_FLAGS=${CRYPTOPP_CMAKE_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} +) + +INCLUDE_DIRECTORIES(${CRYPTOPP_INCLUDE_DIR}) + +cache_third_party(extern_cryptopp + REPOSITORY ${CRYPTOPP_REPOSITORY} + TAG ${CRYPTOPP_TAG} + DIR CRYPTOPP_SOURCE_DIR) + +ExternalProject_Add( + extern_cryptopp + ${EXTERNAL_PROJECT_LOG_ARGS} + ${SHALLOW_CLONE} + "${CRYPTOPP_DOWNLOAD_CMD}" + PREFIX ${CRYPTOPP_PREFIX_DIR} + SOURCE_DIR ${CRYPTOPP_SOURCE_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + PATCH_COMMAND + COMMAND ${CMAKE_COMMAND} -E remove_directory "/cmake/" + COMMAND git clone -b ${CRYPTOPP_TAG} https://github.com/noloader/cryptopp-cmake "/cmake" + COMMAND ${CMAKE_COMMAND} -E copy_directory "/cmake/" "/" + INSTALL_DIR ${CRYPTOPP_INSTALL_DIR} + CMAKE_ARGS ${CRYPTOPP_CMAKE_ARGS} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${CRYPTOPP_INSTALL_DIR} + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} +) + +ADD_LIBRARY(cryptopp STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET cryptopp PROPERTY IMPORTED_LOCATION ${CRYPTOPP_LIBRARIES}) +ADD_DEPENDENCIES(cryptopp extern_cryptopp) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 837babea020..55737265d05 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -205,9 +205,10 @@ include(external/threadpool)# download threadpool include(external/dlpack) # download dlpack include(external/xxhash) # download, build, install xxhash include(external/warpctc) # download, build, install warpctc +include(external/cryptopp) # download, build, install cryptopp list(APPEND third_party_deps extern_eigen3 extern_gflags extern_glog extern_boost extern_xxhash) -list(APPEND third_party_deps extern_zlib extern_dlpack extern_warpctc extern_threadpool) +list(APPEND third_party_deps extern_zlib extern_dlpack extern_warpctc extern_threadpool extern_cryptopp) # download file set(CUDAERROR_URL "http://paddlepaddledeps.bj.bcebos.com/cudaErrorMessage.tar.gz" CACHE STRING "" FORCE) diff --git a/paddle/fluid/framework/io/CMakeLists.txt b/paddle/fluid/framework/io/CMakeLists.txt index 0f0e562c0f5..2c62489dadf 100644 --- a/paddle/fluid/framework/io/CMakeLists.txt +++ b/paddle/fluid/framework/io/CMakeLists.txt @@ -2,3 +2,4 @@ cc_library(fs SRCS fs.cc DEPS string_helper glog boost) cc_library(shell SRCS shell.cc DEPS string_helper glog timer enforce) cc_test(test_fs SRCS test_fs.cc DEPS fs shell) +add_subdirectory(crypto) diff --git a/paddle/fluid/framework/io/crypto/CMakeLists.txt b/paddle/fluid/framework/io/crypto/CMakeLists.txt new file mode 100644 index 00000000000..175e124638f --- /dev/null +++ b/paddle/fluid/framework/io/crypto/CMakeLists.txt @@ -0,0 +1,14 @@ +#cc_library(cipher_factory SRCS cipher_factory.cc DEPS aes_gcm_crypto) +#cc_library(crypto_helper SRCS crypto_helper.cc DEPS cryptopp) +#cc_library(aes_gcm_crypto SRCS aes_gcm_crypto.cc DEPS cryptopp glog) + +cc_library(paddle_ciphers SRCS cipher_utils.cc cipher.cc aes_cipher.cc DEPS cryptopp) + +#cc_test(crypt_fstream_test SRCS crypt_fstream_test.cc DEPS cryptopp) +cc_test(aes_cipher_test SRCS aes_cipher_test.cc DEPS paddle_ciphers) + +cc_test(cipher_utils_test SRCS cipher_utils_test.cc DEPS paddle_ciphers) + +set(CMAKE_BUILD_TYPE "Debug") +#add_executable(cryptopp_test cryptopp_test.cc) +#target_link_libraries(cryptopp_test cryptopp crypto_helper) diff --git a/paddle/fluid/framework/io/crypto/aes_cipher.cc b/paddle/fluid/framework/io/crypto/aes_cipher.cc new file mode 100644 index 00000000000..95c89049e1c --- /dev/null +++ b/paddle/fluid/framework/io/crypto/aes_cipher.cc @@ -0,0 +1,281 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/io/crypto/aes_cipher.h" + +#include +#include +#include +#include + +#include +#include + +#include "paddle/fluid/framework/io/crypto/cipher_utils.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { + +void AESCipher::Init(const std::string& cipher_name, const int& iv_size, + const int& tag_size) { + aes_cipher_name_ = cipher_name; + iv_size_ = iv_size; + tag_size_ = tag_size; + std::set authented_cipher_set{"AES_GCM_NoPadding"}; + if (authented_cipher_set.find(cipher_name) != authented_cipher_set.end()) { + is_authenticated_cipher_ = true; + } +} + +std::string AESCipher::EncryptInternal(const std::string& plaintext, + const std::string& key) { + CryptoPP::member_ptr m_cipher; + CryptoPP::member_ptr m_filter; + bool need_iv = false; + std::string iv = ""; + const unsigned char* key_char = + reinterpret_cast(&(key.at(0))); + BuildCipher(true, &need_iv, &m_cipher, &m_filter); + if (need_iv) { + iv_ = CipherUtils::GenKey(iv_size_); + m_cipher.get()->SetKeyWithIV( + key_char, key.size(), + reinterpret_cast(&(iv_.at(0)))); + } else { + m_cipher.get()->SetKey(key_char, key.size()); + } + + std::string ciphertext; + m_filter->Attach(new CryptoPP::StringSink(ciphertext)); + CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter)); + if (need_iv) { + ciphertext = iv_ + ciphertext; + } + + return ciphertext; +} + +std::string AESCipher::DecryptInternal(const std::string& ciphertext, + const std::string& key) { + CryptoPP::member_ptr m_cipher; + CryptoPP::member_ptr m_filter; + bool need_iv = false; + std::string iv = ""; + const unsigned char* key_char = + reinterpret_cast(&(key.at(0))); + BuildCipher(false, &need_iv, &m_cipher, &m_filter); + int ciphertext_beg = 0; + if (need_iv) { + iv_ = ciphertext.substr(0, iv_size_ / 8); + ciphertext_beg = iv_size_ / 8; + m_cipher.get()->SetKeyWithIV( + key_char, key.size(), + reinterpret_cast(&(iv_.at(0)))); + } else { + m_cipher.get()->SetKey(key_char, key.size()); + } + std::string plaintext; + m_filter->Attach(new CryptoPP::StringSink(plaintext)); + CryptoPP::StringSource(ciphertext.substr(ciphertext_beg), true, + new CryptoPP::Redirector(*m_filter)); + + return plaintext; +} + +std::string AESCipher::AuthenticatedEncryptInternal( + const std::string& plaintext, const std::string& key) { + CryptoPP::member_ptr m_cipher; + CryptoPP::member_ptr m_filter; + bool need_iv = false; + std::string iv = ""; + const unsigned char* key_char = + reinterpret_cast(&(key.at(0))); + BuildAuthEncCipher(&need_iv, &m_cipher, &m_filter); + if (need_iv) { + iv_ = CipherUtils::GenKey(iv_size_); + m_cipher.get()->SetKeyWithIV( + key_char, key.size(), + reinterpret_cast(&(iv_.at(0)))); + } else { + m_cipher.get()->SetKey(key_char, key.size()); + } + + std::string ciphertext; + m_filter->Attach(new CryptoPP::StringSink(ciphertext)); + CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter)); + if (need_iv) { + ciphertext = iv_ + ciphertext; + } + + return ciphertext; +} + +std::string AESCipher::AuthenticatedDecryptInternal( + const std::string& ciphertext, const std::string& key) { + CryptoPP::member_ptr m_cipher; + CryptoPP::member_ptr m_filter; + bool need_iv = false; + std::string iv = ""; + const unsigned char* key_char = + reinterpret_cast(&(key.at(0))); + BuildAuthDecCipher(&need_iv, &m_cipher, &m_filter); + int ciphertext_beg = 0; + if (need_iv) { + iv_ = ciphertext.substr(0, iv_size_ / 8); + ciphertext_beg = iv_size_ / 8; + m_cipher.get()->SetKeyWithIV( + key_char, key.size(), + reinterpret_cast(&(iv_.at(0)))); + } else { + m_cipher.get()->SetKey(key_char, key.size()); + } + std::string plaintext; + m_filter->Attach(new CryptoPP::StringSink(plaintext)); + CryptoPP::StringSource(ciphertext.substr(ciphertext_beg), true, + new CryptoPP::Redirector(*m_filter)); + PADDLE_ENFORCE_EQ( + m_filter->GetLastResult(), true, + paddle::platform::errors::InvalidArgument("Integrity check failed. " + "Invalid ciphertext input.")); + return plaintext; +} + +void AESCipher::BuildCipher( + bool for_encrypt, bool* need_iv, + CryptoPP::member_ptr* m_cipher, + CryptoPP::member_ptr* m_filter) { + if (aes_cipher_name_ == "AES_ECB_PKCSPadding" && for_encrypt) { + m_cipher->reset(new CryptoPP::ECB_Mode::Encryption); + m_filter->reset(new CryptoPP::StreamTransformationFilter( + *(*m_cipher).get(), NULL, + CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING)); + } else if (aes_cipher_name_ == "AES_ECB_PKCSPadding" && !for_encrypt) { + m_cipher->reset(new CryptoPP::ECB_Mode::Decryption); + m_filter->reset(new CryptoPP::StreamTransformationFilter( + *(*m_cipher).get(), NULL, + CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING)); + } else if (aes_cipher_name_ == "AES_CBC_PKCSPadding" && for_encrypt) { + m_cipher->reset(new CryptoPP::CBC_Mode::Encryption); + *need_iv = true; + m_filter->reset(new CryptoPP::StreamTransformationFilter( + *(*m_cipher).get(), NULL, + CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING)); + } else if (aes_cipher_name_ == "AES_CBC_PKCSPadding" && !for_encrypt) { + m_cipher->reset(new CryptoPP::CBC_Mode::Decryption); + *need_iv = true; + m_filter->reset(new CryptoPP::StreamTransformationFilter( + *(*m_cipher).get(), NULL, + CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING)); + } else if (aes_cipher_name_ == "AES_CTR_NoPadding" && for_encrypt) { + m_cipher->reset(new CryptoPP::CTR_Mode::Encryption); + *need_iv = true; + m_filter->reset(new CryptoPP::StreamTransformationFilter( + *(*m_cipher).get(), NULL, CryptoPP::BlockPaddingSchemeDef::NO_PADDING)); + } else if (aes_cipher_name_ == "AES_CTR_NoPadding" && !for_encrypt) { + m_cipher->reset(new CryptoPP::CTR_Mode::Decryption); + *need_iv = true; + m_filter->reset(new CryptoPP::StreamTransformationFilter( + *(*m_cipher).get(), NULL, CryptoPP::BlockPaddingSchemeDef::NO_PADDING)); + } else { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Create cipher error. " + "Cipher name %s is error, or has not been implemented.", + aes_cipher_name_)); + } +} + +void AESCipher::BuildAuthEncCipher( + bool* need_iv, + CryptoPP::member_ptr* m_cipher, + CryptoPP::member_ptr* m_filter) { + if (aes_cipher_name_ == "AES_GCM_NoPadding") { + m_cipher->reset(new CryptoPP::GCM::Encryption); + *need_iv = true; + m_filter->reset(new CryptoPP::AuthenticatedEncryptionFilter( + *(*m_cipher).get(), NULL, false, tag_size_ / 8, + CryptoPP::DEFAULT_CHANNEL, + CryptoPP::BlockPaddingSchemeDef::NO_PADDING)); + } else { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Create cipher error. " + "Cipher name %s is error, or has not been implemented.", + aes_cipher_name_)); + } +} + +void AESCipher::BuildAuthDecCipher( + bool* need_iv, + CryptoPP::member_ptr* m_cipher, + CryptoPP::member_ptr* m_filter) { + if (aes_cipher_name_ == "AES_GCM_NoPadding") { + m_cipher->reset(new CryptoPP::GCM::Decryption); + *need_iv = true; + m_filter->reset(new CryptoPP::AuthenticatedDecryptionFilter( + *(*m_cipher).get(), NULL, + CryptoPP::AuthenticatedDecryptionFilter::DEFAULT_FLAGS, tag_size_ / 8, + CryptoPP::BlockPaddingSchemeDef::NO_PADDING)); + } else { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Create cipher error. " + "Cipher name %s is error, or has not been implemented.", + aes_cipher_name_)); + } +} + +std::string AESCipher::Encrypt(const std::string& plaintext, + const std::string& key) { + return is_authenticated_cipher_ ? AuthenticatedEncryptInternal(plaintext, key) + : EncryptInternal(plaintext, key); +} + +std::string AESCipher::Decrypt(const std::string& ciphertext, + const std::string& key) { + return is_authenticated_cipher_ + ? AuthenticatedDecryptInternal(ciphertext, key) + : DecryptInternal(ciphertext, key); +} + +void AESCipher::EncryptToFile(const std::string& plaintext, + const std::string& key, + const std::string& filename) { + std::ofstream fout(filename); + fout << this->Encrypt(plaintext, key); + fout.close(); +} + +std::string AESCipher::DecryptFromFile(const std::string& key, + const std::string& filename) { + std::ifstream fin(filename); + std::string ciphertext{std::istreambuf_iterator(fin), + std::istreambuf_iterator()}; + fin.close(); + return Decrypt(ciphertext, key); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/aes_cipher.h b/paddle/fluid/framework/io/crypto/aes_cipher.h new file mode 100644 index 00000000000..0c1f1fd243c --- /dev/null +++ b/paddle/fluid/framework/io/crypto/aes_cipher.h @@ -0,0 +1,93 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include + +#include "paddle/fluid/framework/io/crypto/cipher.h" + +namespace paddle { +namespace framework { + +class AESCipher : public Cipher { + public: + AESCipher() = default; + ~AESCipher() {} + + std::string Encrypt(const std::string& input, + const std::string& key) override; + std::string Decrypt(const std::string& input, + const std::string& key) override; + + void EncryptToFile(const std::string& input, const std::string& key, + const std::string& filename) override; + std::string DecryptFromFile(const std::string& key, + const std::string& filename) override; + + void Init(const std::string& cipher_name, const int& iv_size, + const int& tag_size); + + private: + std::string EncryptInternal(const std::string& plaintext, + const std::string& key); + std::string DecryptInternal(const std::string& ciphertext, + const std::string& key); + + std::string AuthenticatedEncryptInternal(const std::string& plaintext, + const std::string& key); + std::string AuthenticatedDecryptInternal(const std::string& ciphertext, + const std::string& key); + + void BuildCipher( + bool for_encrypt, bool* need_iv, + CryptoPP::member_ptr* m_cipher, + CryptoPP::member_ptr* m_filter); + + void BuildAuthEncCipher( + bool* need_iv, + CryptoPP::member_ptr* m_cipher, + CryptoPP::member_ptr* m_filter); + + void BuildAuthDecCipher( + bool* need_iv, + CryptoPP::member_ptr* m_cipher, + CryptoPP::member_ptr* m_filter); + + std::string aes_cipher_name_; + int iv_size_; + int tag_size_; + std::string iv_; + bool is_authenticated_cipher_{false}; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/aes_cipher_test.cc b/paddle/fluid/framework/io/crypto/aes_cipher_test.cc new file mode 100644 index 00000000000..e80781854f5 --- /dev/null +++ b/paddle/fluid/framework/io/crypto/aes_cipher_test.cc @@ -0,0 +1,107 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/io/crypto/aes_cipher.h" + +#include +#include + +#include +#include + +#include "paddle/fluid/framework/io/crypto/cipher.h" +#include "paddle/fluid/framework/io/crypto/cipher_utils.h" + +namespace paddle { +namespace framework { + +class AESTest : public ::testing::Test { + public: + std::string key; + + void SetUp() override { key = CipherUtils::GenKey(256); } + static void GenConfigFile(const std::string& cipher_name); +}; + +void AESTest::GenConfigFile(const std::string& cipher_name) { + std::ofstream fout("aes_test.conf"); + fout << "cipher_name : " << cipher_name << std::endl; + fout.close(); +} + +TEST_F(AESTest, security_string) { + std::vector name_list( + {"AES_CTR_NoPadding", "AES_CBC_PKCSPadding", "AES_ECB_PKCSPadding", + "AES_GCM_NoPadding"}); + const std::string plaintext("hello world."); + for (auto& i : name_list) { + AESTest::GenConfigFile(i); + try { + auto cipher = CipherFactory::CreateCipher("aes_test.conf"); + std::string ciphertext = cipher->Encrypt(plaintext, AESTest::key); + + std::string plaintext1 = cipher->Decrypt(ciphertext, AESTest::key); + EXPECT_EQ(plaintext, plaintext1); + } catch (CryptoPP::Exception& e) { + LOG(ERROR) << e.what(); + } + } +} + +TEST_F(AESTest, security_vector) { + std::vector name_list( + {"AES_CTR_NoPadding", "AES_CBC_PKCSPadding", "AES_ECB_PKCSPadding", + "AES_GCM_NoPadding"}); + std::vector input{1, 2, 3, 4}; + for (auto& i : name_list) { + AESTest::GenConfigFile(i); + try { + auto cipher = CipherFactory::CreateCipher("aes_test.conf"); + for (auto& i : input) { + std::string ciphertext = + cipher->Encrypt(std::to_string(i), AESTest::key); + + std::string plaintext = cipher->Decrypt(ciphertext, AESTest::key); + + int output = std::stoi(plaintext); + + EXPECT_EQ(i, output); + } + } catch (CryptoPP::Exception& e) { + LOG(ERROR) << e.what(); + } + } +} + +TEST_F(AESTest, encrypt_to_file) { + std::vector name_list( + {"AES_CTR_NoPadding", "AES_CBC_PKCSPadding", "AES_ECB_PKCSPadding", + "AES_GCM_NoPadding"}); + const std::string plaintext("hello world."); + std::string filename("aes_test.ciphertext"); + for (auto& i : name_list) { + AESTest::GenConfigFile(i); + try { + auto cipher = CipherFactory::CreateCipher("aes_test.conf"); + cipher->EncryptToFile(plaintext, AESTest::key, filename); + std::string plaintext1 = cipher->DecryptFromFile(AESTest::key, filename); + EXPECT_EQ(plaintext, plaintext1); + } catch (CryptoPP::Exception& e) { + LOG(ERROR) << e.what(); + } + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/cipher.cc b/paddle/fluid/framework/io/crypto/cipher.cc new file mode 100644 index 00000000000..763c282017e --- /dev/null +++ b/paddle/fluid/framework/io/crypto/cipher.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/io/crypto/cipher.h" +#include "paddle/fluid/framework/io/crypto/aes_cipher.h" +#include "paddle/fluid/framework/io/crypto/cipher_utils.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { + +std::shared_ptr CipherFactory::CreateCipher( + const std::string& config_file) { + std::string cipher_name; + int iv_size; + int tag_size; + std::unordered_map config; + if (!config_file.empty()) { + config = CipherUtils::LoadConfig(config_file); + CipherUtils::GetValue(config, "cipher_name", &cipher_name); + } else { + // set default cipher name + cipher_name = "AES_CTR_NoPadding"; + } + if (cipher_name.find("AES") != cipher_name.npos) { + auto ret = std::make_shared(); + // if not set iv_size, set default value + if (config_file.empty() || + !CipherUtils::GetValue(config, "iv_size", &iv_size)) { + iv_size = CipherUtils::AES_DEFAULT_IV_SIZE; + } + // if not set tag_size, set default value + if (config_file.empty() || + !CipherUtils::GetValue(config, "tag_size", &tag_size)) { + tag_size = CipherUtils::AES_DEFAULT_IV_SIZE; + } + ret->Init(cipher_name, iv_size, tag_size); + return ret; + } else { + PADDLE_THROW( + "Invalid cipher name is specied. " + "Please check you have specified valid cipher" + " name in CryptoProperties."); + } + return nullptr; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/cipher.h b/paddle/fluid/framework/io/crypto/cipher.h new file mode 100644 index 00000000000..9072cb1180d --- /dev/null +++ b/paddle/fluid/framework/io/crypto/cipher.h @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace framework { + +class Cipher { + public: + Cipher() = default; + virtual ~Cipher() {} + // encrypt string + virtual std::string Encrypt(const std::string& plaintext, + const std::string& key) = 0; + // decrypt string + virtual std::string Decrypt(const std::string& ciphertext, + const std::string& key) = 0; + + // encrypt strings and read them to file, + virtual void EncryptToFile(const std::string& plaintext, + const std::string& key, + const std::string& filename) = 0; + // read from file and decrypt them + virtual std::string DecryptFromFile(const std::string& key, + const std::string& filename) = 0; +}; + +class CipherFactory { + public: + CipherFactory() = default; + static std::shared_ptr CreateCipher(const std::string& config_file); +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/cipher_utils.cc b/paddle/fluid/framework/io/crypto/cipher_utils.cc new file mode 100644 index 00000000000..77fe9810ac2 --- /dev/null +++ b/paddle/fluid/framework/io/crypto/cipher_utils.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/io/crypto/cipher_utils.h" + +#include + +#include +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { + +std::string CipherUtils::GenKey(int length) { + CryptoPP::AutoSeededRandomPool prng; + int bit_length = length / 8; + std::string rng; + rng.resize(bit_length); + // CryptoPP::byte key[length]; + prng.GenerateBlock(reinterpret_cast(&(rng.at(0))), + rng.size()); + return rng; +} + +std::string CipherUtils::GenKeyToFile(int length, const std::string& filename) { + CryptoPP::AutoSeededRandomPool prng; + std::string rng; + int bit_length = length / 8; + rng.resize(bit_length); + // CryptoPP::byte key[length]; + prng.GenerateBlock(reinterpret_cast(&(rng.at(0))), + rng.size()); + std::ofstream fout(filename); + PADDLE_ENFORCE_EQ(fout.is_open(), true, + paddle::platform::errors::Unavailable( + "Failed to open file : %s, " + "make sure input filename is available.", + filename)); + fout.write(rng.c_str(), rng.size()); + fout.close(); + return rng; +} + +std::string CipherUtils::ReadKeyFromFile(const std::string& filename) { + std::ifstream fin(filename, std::ios::binary); + std::string ret{std::istreambuf_iterator(fin), + std::istreambuf_iterator()}; + fin.close(); + return ret; +} + +std::unordered_map CipherUtils::LoadConfig( + const std::string& config_file) { + std::ifstream fin(config_file); + PADDLE_ENFORCE_EQ(fin.is_open(), true, + paddle::platform::errors::Unavailable( + "Failed to open file : %s, " + "make sure input filename is available.", + config_file)); + std::unordered_map ret; + char c; + std::string line; + std::istringstream iss; + while (std::getline(fin, line)) { + if (line.at(0) == '#') { + continue; + } + iss.clear(); + iss.str(line); + std::string key; + std::string value; + if (!(iss >> key >> c >> value) && (c == ':')) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Parse config file error, " + "check the format of configure in file %s.", + config_file)); + } + ret.insert({key, value}); + } + return ret; +} + +template <> +bool CipherUtils::GetValue( + const std::unordered_map& config, + const std::string& key, bool* output) { + auto itr = config.find(key); + if (itr == config.end()) { + return false; + } + std::istringstream iss(itr->second); + *output = false; + iss >> *output; + if (iss.fail()) { + iss.clear(); + iss >> std::boolalpha >> *output; + } + return true; +} + +const int CipherUtils::AES_DEFAULT_IV_SIZE = 96; +const int CipherUtils::AES_DEFAULT_TAG_SIZE = 128; +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/cipher_utils.h b/paddle/fluid/framework/io/crypto/cipher_utils.h new file mode 100644 index 00000000000..0533275798f --- /dev/null +++ b/paddle/fluid/framework/io/crypto/cipher_utils.h @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { + +class CipherUtils { + public: + CipherUtils() = default; + static std::string GenKey(int length); + static std::string GenKeyToFile(int length, const std::string& filename); + static std::string ReadKeyFromFile(const std::string& filename); + + static std::unordered_map LoadConfig( + const std::string& config_file); + + template + static bool GetValue( + const std::unordered_map& config, + const std::string& key, val_type* output); + + static const int AES_DEFAULT_IV_SIZE; + static const int AES_DEFAULT_TAG_SIZE; +}; + +template <> +bool CipherUtils::GetValue( + const std::unordered_map& config, + const std::string& key, bool* output); + +template +bool CipherUtils::GetValue( + const std::unordered_map& config, + const std::string& key, val_type* output) { + auto itr = config.find(key); + if (itr == config.end()) { + return false; + } + std::istringstream iss(itr->second); + iss >> *output; + return true; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/cipher_utils_test.cc b/paddle/fluid/framework/io/crypto/cipher_utils_test.cc new file mode 100644 index 00000000000..2f7d7282b8b --- /dev/null +++ b/paddle/fluid/framework/io/crypto/cipher_utils_test.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +#include "paddle/fluid/framework/io/crypto/cipher_utils.h" + +namespace paddle { +namespace framework { + +TEST(CipherUtils, load_config) { + std::string filename("cryptotest_config_file.conf"); + + std::ofstream fout(filename, std::ios::out); + fout << "# anotation test line:" + " must have two space along ':'." + << std::endl; + std::vector key_value; + key_value.emplace_back("key_str : ciphername"); + key_value.emplace_back("key_int : 1"); + key_value.emplace_back("key_bool : true"); + key_value.emplace_back("key_bool1 : false"); + key_value.emplace_back("key_bool2 : 0"); + for (auto& i : key_value) { + fout << i << std::endl; + } + fout.close(); + + auto config = CipherUtils::LoadConfig(filename); + + std::string out_str; + EXPECT_TRUE(CipherUtils::GetValue(config, "key_str", &out_str)); + EXPECT_EQ(out_str, std::string("ciphername")); + + int out_int; + EXPECT_TRUE(CipherUtils::GetValue(config, "key_int", &out_int)); + EXPECT_EQ(out_int, 1); + + bool out_bool; + EXPECT_TRUE(CipherUtils::GetValue(config, "key_bool", &out_bool)); + EXPECT_EQ(out_bool, true); + + bool out_bool1; + EXPECT_TRUE(CipherUtils::GetValue(config, "key_bool1", &out_bool1)); + EXPECT_EQ(out_bool1, false); + + bool out_bool2; + EXPECT_TRUE(CipherUtils::GetValue(config, "key_bool2", &out_bool2)); + EXPECT_EQ(out_bool2, false); +} + +TEST(CipherUtils, gen_key) { + std::string filename("test_keyfile"); + std::string key = CipherUtils::GenKey(256); + std::string key1 = CipherUtils::GenKeyToFile(256, filename); + EXPECT_NE(key, key1); + std::string key2 = CipherUtils::ReadKeyFromFile(filename); + EXPECT_EQ(key1, key2); + EXPECT_EQ(key.size(), 256 / 8); +} + +} // namespace framework +} // namespace paddle -- GitLab