From 3a1cfa9d055fab357f46e653a8786f96336f6b47 Mon Sep 17 00:00:00 2001 From: yangqingyou Date: Thu, 28 May 2020 02:49:22 +0000 Subject: [PATCH] add crypto api for python, test=develop --- CMakeLists.txt | 2 +- cmake/configure.cmake | 4 + paddle/fluid/pybind/CMakeLists.txt | 5 + paddle/fluid/pybind/crypto.cc | 136 ++++++++++++++++++ paddle/fluid/pybind/crypto.h | 23 +++ paddle/fluid/pybind/pybind.cc | 7 + .../fluid/tests/unittests/CMakeLists.txt | 4 + .../fluid/tests/unittests/test_crypto.py | 49 +++++++ 8 files changed, 229 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/pybind/crypto.cc create mode 100644 paddle/fluid/pybind/crypto.h create mode 100644 python/paddle/fluid/tests/unittests/test_crypto.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 127df949dd7..d79f3458867 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,7 +88,7 @@ option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE} option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF) option(WITH_LITE "Compile Paddle Fluid with Lite Engine" OFF) option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON) -option(WITH_CRYPTO "Compile PaddlePaddle with paddle_crypto lib" ON) +option(WITH_CRYPTO "Compile PaddlePaddle with crypto support" ON) # PY_VERSION if(NOT PY_VERSION) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index b0ce1a4ea2d..00a593de744 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -154,3 +154,7 @@ endif(WITH_BRPC_RDMA) if(ON_INFER) add_definitions(-DPADDLE_ON_INFERENCE) endif(ON_INFER) + +if(WITH_CRYPTO) + add_definitions(-DPADDLE_WITH_CRYPTO) +endif(WITH_CRYPTO) \ No newline at end of file diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 70cd1b5d1af..6f47d312d2a 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -38,6 +38,11 @@ set(PYBIND_SRCS ir.cc inference_api.cc) +if (WITH_CRYPTO) + set(PYBIND_DEPS ${PYBIND_DEPS} paddle_crypto) + set(PYBIND_SRCS ${PYBIND_SRCS} crypto.cc) +endif (WITH_CRYPTO) + if (WITH_DISTRIBUTE) list(APPEND PYBIND_SRCS communicator_py.cc) endif() diff --git a/paddle/fluid/pybind/crypto.cc b/paddle/fluid/pybind/crypto.cc new file mode 100644 index 00000000000..8fbf395bf18 --- /dev/null +++ b/paddle/fluid/pybind/crypto.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/crypto.h" + +#include +#include + +#include "paddle/fluid/framework/io/crypto/aes_cipher.h" +#include "paddle/fluid/framework/io/crypto/cipher.h" +#include "paddle/fluid/framework/io/crypto/cipher_utils.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +using paddle::framework::AESCipher; +using paddle::framework::Cipher; +using paddle::framework::CipherFactory; +using paddle::framework::CipherUtils; + +namespace { + +class PyCipher : public Cipher { + public: + using Cipher::Cipher; + // encrypt string + std::string Encrypt(const std::string& plaintext, + const std::string& key) override { + PYBIND11_OVERLOAD_PURE_NAME(std::string, Cipher, "encrypt", Encrypt, + plaintext, key); + } + // decrypt string + std::string Decrypt(const std::string& ciphertext, + const std::string& key) override { + PYBIND11_OVERLOAD_PURE_NAME(std::string, Cipher, "decrypt", Decrypt, + ciphertext, key); + } + + // encrypt strings and read them to file, + void EncryptToFile(const std::string& plaintext, const std::string& key, + const std::string& filename) override { + PYBIND11_OVERLOAD_PURE_NAME(void, Cipher, "encrypt_to_file", EncryptToFile, + plaintext, key, filename); + } + // read from file and decrypt them + std::string DecryptFromFile(const std::string& key, + const std::string& filename) override { + PYBIND11_OVERLOAD_PURE_NAME(std::string, Cipher, "decrypt_from_file", + DecryptFromFile, key, filename); + } +}; + +void BindCipher(py::module* m) { + py::class_>(*m, "Cipher") + .def(py::init<>()) + .def("encrypt", + [](Cipher& c, const std::string& plaintext, const std::string& key) { + std::string ret = c.Encrypt(plaintext, key); + return py::bytes(ret); + }) + .def( + "decrypt", + [](Cipher& c, const std::string& ciphertext, const std::string& key) { + std::string ret = c.Decrypt(ciphertext, key); + return py::bytes(ret); + }) + .def("encrypt_to_file", + [](Cipher& c, const std::string& plaintext, const std::string& key, + const std::string& filename) { + c.EncryptToFile(plaintext, key, filename); + }) + .def("decrypt_from_file", + [](Cipher& c, const std::string& key, const std::string& filename) { + std::string ret = c.DecryptFromFile(key, filename); + return py::bytes(ret); + }); +} + +void BindAESCipher(py::module* m) { + py::class_>(*m, "AESCipher") + .def(py::init<>()); +} + +void BindCipherFactory(py::module* m) { + py::class_(*m, "CipherFactory") + .def(py::init<>()) + .def_static("create_cipher", + [](const std::string& config_file) { + return CipherFactory::CreateCipher(config_file); + }, + py::arg("config_file") = std::string()); +} + +void BindCipherUtils(py::module* m) { + py::class_(*m, "CipherUtils") + .def_static("gen_key", + [](int length) { + std::string ret = CipherUtils::GenKey(length); + return py::bytes(ret); + }) + .def_static("gen_key_to_file", + [](int length, const std::string& filename) { + std::string ret = + CipherUtils::GenKeyToFile(length, filename); + return py::bytes(ret); + }) + .def_static("read_key_from_file", [](const std::string& filename) { + std::string ret = CipherUtils::ReadKeyFromFile(filename); + return py::bytes(ret); + }); +} + +} // namespace + +void BindCrypto(py::module* m) { + BindCipher(m); + BindCipherFactory(m); + BindCipherUtils(m); + BindAESCipher(m); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/crypto.h b/paddle/fluid/pybind/crypto.h new file mode 100644 index 00000000000..d66aaad9193 --- /dev/null +++ b/paddle/fluid/pybind/crypto.h @@ -0,0 +1,23 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace paddle { +namespace pybind { +void BindCrypto(pybind11::module *m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index fc9a1c468a7..e6fbcfec017 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -91,6 +91,10 @@ limitations under the License. */ #include "paddle/fluid/pybind/communicator_py.h" #endif +#ifdef PADDLE_WITH_CRYPTO +#include "paddle/fluid/pybind/crypto.h" +#endif + #include "pybind11/stl.h" DECLARE_bool(use_mkldnn); @@ -2420,6 +2424,9 @@ All parameter, weight, gradient are variables in Paddle. BindNode(&m); BindInferenceApi(&m); BindDataset(&m); +#ifdef PADDLE_WITH_CRYPTO + BindCrypto(&m); +#endif #ifdef PADDLE_WITH_DISTRIBUTE BindCommunicator(&m); #endif diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1cbe12c60e6..5894f72285e 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -101,6 +101,10 @@ if(WITH_GPU OR NOT WITH_MKLML) LIST(REMOVE_ITEM TEST_OPS test_matmul_op_with_head) endif() +if(NOT WITH_CRYPTO) + LIST(REMOVE_ITEM TEST_OPS test_crypto) +endif() + function(py_test_modules TARGET_NAME) if(WITH_TESTING) set(options SERIAL) diff --git a/python/paddle/fluid/tests/unittests/test_crypto.py b/python/paddle/fluid/tests/unittests/test_crypto.py new file mode 100644 index 00000000000..d903f175575 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_crypto.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.core import CipherUtils +from paddle.fluid.core import CipherFactory +from paddle.fluid.core import Cipher + +import unittest + + +class CipherUtilsTestCase(unittest.TestCase): + def test_gen_key(self): + key1 = CipherUtils.gen_key(256) + key2 = CipherUtils.gen_key_to_file(256, "/tmp/paddle_aes_test.keyfile") + self.assertNotEquals(key1, key2) + key3 = CipherUtils.read_key_from_file("/tmp/paddle_aes_test.keyfile") + self.assertEqual(key2, key3) + self.assertEqual(len(key1), 32) + self.assertEqual(len(key2), 32) + + +class CipherTestCase(unittest.TestCase): + def test_aes_cipher(self): + plaintext = "hello world" + key = CipherUtils.gen_key(256) + cipher = CipherFactory.create_cipher() + + ciphertext = cipher.encrypt(plaintext, key) + cipher.encrypt_to_file(plaintext, key, "paddle_aes_test.ciphertext") + + plaintext1 = cipher.decrypt(ciphertext, key) + plaintext2 = cipher.decrypt_from_file(key, "paddle_aes_test.ciphertext") + self.assertEqual(plaintext, plaintext1) + self.assertEqual(plaintext1, plaintext2) + + +if __name__ == '__main__': + unittest.main() -- GitLab