diff --git a/CMakeLists.txt b/CMakeLists.txt index 6df43c8ac458b0283cf819f905ddc4d0d1978e95..539390338f0345be990a76e6b1d161799a045825 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15) project(PaddleEncrypted) -add_compile_options(-msse4.2 -maes -fPIC -DPADDLE_WITH_MKLDNN) +add_compile_options(-msse4.2 -fPIC -DPADDLE_WITH_MKLDNN -O2) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_CXX_STANDARD 11) @@ -34,8 +34,8 @@ execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_ve RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE) if (NOT ret) - if (NOT ${paddle_version} STREQUAL "1.8.0") - message(FATAL_ERROR "Paddle installation of 1.8.0 is required but ${paddle_version} is found") + if (NOT ${paddle_version} STRGREATER_EQUAL "1.8.0") + message(FATAL_ERROR "Paddle installation of >= 1.8.0 is required but ${paddle_version} is found") endif() else() message(FATAL_ERROR "Could not get paddle version.") @@ -57,6 +57,10 @@ option(WITH_TESTING "Compile with unit testing" ON) option(WITH_PSI "Compile with psi lib" ON) +option(USE_AES_NI "Compile with AES NI" ON) + +option(USE_OPENMP "Compile with OpenMP" ON) + ########################### the project build part ############################### message(STATUS "Using paddlepaddle installation of ${paddle_version}") message(STATUS "paddlepaddle include directory: ${PADDLE_INCLUDE}") @@ -70,6 +74,15 @@ include_directories(.) include_directories(${PADDLE_INCLUDE}) include_directories(${PADDLE_INCLUDE}/third_party) +if (USE_AES_NI) + add_compile_definitions(USE_AES_NI) + add_compile_options(-maes) +endif (USE_AES_NI) + +if (USE_OPENMP) + add_compile_options(-fopenmp) + find_package(OpenMP REQUIRED) +endif(USE_OPENMP) add_subdirectory(core/privc3) add_subdirectory(core/paddlefl_mpc/mpc_protocol) diff --git a/core/paddlefl_mpc/data_utils/CMakeLists.txt b/core/paddlefl_mpc/data_utils/CMakeLists.txt index 15922e3bd164630255ce7fdb66d74ef669926882..d857ea0911617e4598759256b80bd5a014874da7 100644 --- a/core/paddlefl_mpc/data_utils/CMakeLists.txt +++ b/core/paddlefl_mpc/data_utils/CMakeLists.txt @@ -1,9 +1,7 @@ -add_compile_options(-msse4.2 -maes) - set(PYBIND_SRCS "./data_utils.cc" ) - + if (NOT PYTHON_INCLUDE_DIRS) find_package(PythonLibs REQUIRED) endif() diff --git a/core/paddlefl_mpc/data_utils/data_utils.cc b/core/paddlefl_mpc/data_utils/data_utils.cc index 87f6eada3b99fc64dffff1ed2b962a5af940dd42..2beec523124aa259f99838c2b318d67b0a02eb6a 100644 --- a/core/paddlefl_mpc/data_utils/data_utils.cc +++ b/core/paddlefl_mpc/data_utils/data_utils.cc @@ -1,16 +1,16 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include #include @@ -21,8 +21,8 @@ #include #include -#include "core/paddlefl_mpc/mpc_protocol/aby3_operators.h" #include "core/privc3/fixedpoint_util.h" +#include "core/paddlefl_mpc/mpc_protocol/aby3_operators.h" #include "core/psi/psi_api.h" namespace py = pybind11; @@ -30,68 +30,75 @@ namespace py = pybind11; namespace aby3 { // split plaintext into three shares. -template py::array_t share(double input) { - size_t share_num = 3; - auto shares = py::array_t(share_num); - py::buffer_info shares_buf = shares.request(); - T *shares_buf_ptr = (T *)shares_buf.ptr; - T *ret_ptr[share_num]; - for (size_t i = 0; i < share_num; ++i) { - ret_ptr[i] = &shares_buf_ptr[i]; - } - - FixedPointUtil::share(input, ret_ptr); - - return shares; +template +py::array_t share(double input) { + size_t share_num = 3; + auto shares = py::array_t(share_num); + py::buffer_info shares_buf = shares.request(); + T* shares_buf_ptr = (T*)shares_buf.ptr; + T* ret_ptr[share_num]; + for (size_t i = 0; i < share_num; ++i) { + ret_ptr[i] = &shares_buf_ptr[i]; + } + + FixedPointUtil::share(input, ret_ptr); + + return shares; } // combine three shares to reveal plaintext. -template double reveal(py::array_t shares) { - size_t share_num = 3; - py::buffer_info shares_buf = shares.request(); - T *shares_buf_ptr = (T *)shares_buf.ptr; - T *ret[share_num]; +template +double reveal(py::array_t shares) { + size_t share_num = 3; + py::buffer_info shares_buf = shares.request(); + T *shares_buf_ptr = (T *) shares_buf.ptr; + T *ret[share_num]; - for (size_t idx = 0; idx < share_num; ++idx) { - ret[idx] = &shares_buf_ptr[idx]; - } + for (size_t idx = 0; idx < share_num; ++idx) { + ret[idx] = &shares_buf_ptr[idx]; + } - double result = FixedPointUtil::reveal(ret); + double result = FixedPointUtil::reveal(ret); - return result; + return result; } // call psi_send -int send_psi(int port, const std::set &input) { - std::atomic prog(0); - return psi::psi_send(port, input, &prog); +int send_psi(int port, const std::set& input) { + std::atomic prog(0); + return psi::psi_send(port, input, &prog); } // call psi_recv -std::vector recv_psi(const std::string &remote_ip, int port, - const std::set &input) { - std::vector output; - std::atomic prog(0); - int ret = psi::psi_recv(remote_ip, port, input, &output, &prog); - if (ret != 0) { - output.clear(); +std::vector recv_psi(const std::string &remote_ip, + int port, + const std::set& input) { + std::vector output; + std::atomic prog(0); + int ret = psi::psi_recv(remote_ip, port, input, &output, &prog); + if (ret != 0) { + output.clear(); + return output; + } return output; - } - return output; } -PYBIND11_MODULE(mpc_data_utils, m) { - // optional module docstring - m.doc() = "pybind11 paddle-mpc plugin: data_utils (share, reveal, psi)"; +PYBIND11_MODULE(mpc_data_utils, m) +{ + // optional module docstring + m.doc() = "pybind11 paddle-mpc plugin: data_utils (share, reveal, psi)"; + + m.def("share", &share, + "split plaintext into three shares."); + m.def("reveal", &reveal, + "combine three shares to reveal plaintext."); - m.def("share", &share, - "split plaintext into three shares."); - m.def("reveal", 
&reveal, - "combine three shares to reveal plaintext."); + m.def("send_psi", &send_psi, "Send input in two party PSI."); + m.def("recv_psi", &recv_psi, "Send input and return PSI result as output in two party PSI."); - m.def("send_psi", &send_psi, "Send input in two party PSI."); - m.def("recv_psi", &recv_psi, - "Send input and return PSI result as output in two party PSI."); + m.attr("mpc_one_share") = (1 << paddle::mpc::ABY3_SCALING_FACTOR) / 3; } -} // namespace aby3 +} // namespace aby3 + + diff --git a/core/paddlefl_mpc/mpc_protocol/CMakeLists.txt b/core/paddlefl_mpc/mpc_protocol/CMakeLists.txt index 714f99df3e99c332f852e68c0884d9420edb6ee0..b56e623a4e16f15b41e0e78a6df2054ce297ec03 100644 --- a/core/paddlefl_mpc/mpc_protocol/CMakeLists.txt +++ b/core/paddlefl_mpc/mpc_protocol/CMakeLists.txt @@ -1,5 +1,3 @@ -add_compile_options(-msse4.2 -maes) - set(PROTO_SRCS "./aby3_protocol.cc" "./mesh_network.cc" @@ -17,3 +15,5 @@ target_link_libraries(mpc_protocol fluid_framework gloo hiredis privc3) cc_test(mesh_network_test SRCS mesh_network_test.cc DEPS mpc_protocol) cc_test(mpc_protocol_test SRCS mpc_protocol_test.cc DEPS mpc_protocol) cc_test(mpc_instance_test SRCS mpc_instance_test.cc DEPS mpc_protocol) + + diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h index 9981b2a0087d6e7c2915f88d6cb9ad32b9e234a0..369b02c03a0829900d9aea722e93a7bdf12a744b 100644 --- a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h +++ b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h @@ -1,16 +1,16 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ // Description: implementations of each virtual op according to ABY3 protocol @@ -21,9 +21,9 @@ #include "context_holder.h" #include "mpc_operators.h" #include "paddle/fluid/framework/tensor.h" -#include "core/privc3/boolean_tensor.h" #include "core/privc3/circuit_context.h" #include "core/privc3/fixedpoint_tensor.h" +#include "core/privc3/boolean_tensor.h" #include "core/privc3/paddle_tensor.h" namespace paddle { @@ -32,259 +32,344 @@ namespace mpc { using paddle::framework::Tensor; using aby3::CircuitContext; // TODO: decide scaling factor -const size_t ABY3_SCALING_FACTOR = 16; +const size_t ABY3_SCALING_FACTOR = FIXED_POINTER_SCALING_FACTOR; using FixedTensor = aby3::FixedPointTensor; using BoolTensor = aby3::BooleanTensor; using PaddleTensor = aby3::PaddleTensor; class Aby3OperatorsImpl : public MpcOperators { public: - void add(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - auto lhs_tuple = from_tensor(lhs); - auto rhs_tuple = from_tensor(rhs); - auto out_tuple = from_tensor(out); + void add(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + + auto lhs_tuple = from_tensor(lhs); + auto rhs_tuple = from_tensor(rhs); + auto out_tuple = from_tensor(out); + + auto lhs_ = std::get<0>(lhs_tuple).get(); + auto rhs_ = std::get<0>(rhs_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); + + lhs_->add(rhs_, out_); + + } + + // TODO: override + void sub(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + + auto lhs_tuple = from_tensor(lhs); + auto rhs_tuple = from_tensor(rhs); + auto out_tuple = from_tensor(out); + + auto lhs_ = std::get<0>(lhs_tuple).get(); + auto rhs_ = std::get<0>(rhs_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); + + lhs_->sub(rhs_, out_); + } + + void neg(const Tensor *op, Tensor *out) override { + + auto op_tuple = from_tensor(op); + auto out_tuple = from_tensor(out); + + auto op_ = std::get<0>(op_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); + + op_->negative(out_); + } - auto lhs_ = std::get<0>(lhs_tuple).get(); - auto rhs_ = std::get<0>(rhs_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + void sum(const Tensor *op, Tensor *out) override { - lhs_->add(rhs_, out_); - } + auto op_tuple = from_tensor(op); + auto out_tuple = from_tensor(out); - // TODO: override - void sub(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + auto op_ = std::get<0>(op_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - auto lhs_tuple = from_tensor(lhs); - auto rhs_tuple = from_tensor(rhs); - auto out_tuple = from_tensor(out); + op_->sum(out_); + } - auto lhs_ = std::get<0>(lhs_tuple).get(); - auto rhs_ = std::get<0>(rhs_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + void mul(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - lhs_->sub(rhs_, out_); - } + auto lhs_tuple = from_tensor(lhs); + auto rhs_tuple = from_tensor(rhs); + auto out_tuple = from_tensor(out); - void neg(const Tensor *op, Tensor *out) override { + auto lhs_ = std::get<0>(lhs_tuple).get(); + auto rhs_ = std::get<0>(rhs_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - auto op_tuple = from_tensor(op); - auto out_tuple = from_tensor(out); + lhs_->mul(rhs_, out_); + } - auto op_ = std::get<0>(op_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + void matmul(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - op_->negative(out_); - } + auto lhs_tuple = from_tensor(lhs); + auto rhs_tuple = from_tensor(rhs); + auto out_tuple = from_tensor(out); - void sum(const Tensor *op, 
Tensor *out) override { + auto lhs_ = std::get<0>(lhs_tuple).get(); + auto rhs_ = std::get<0>(rhs_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - auto op_tuple = from_tensor(op); - auto out_tuple = from_tensor(out); + lhs_->mat_mul(rhs_, out_); + } - auto op_ = std::get<0>(op_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + void scale(const Tensor *lhs, const double factor, Tensor *out) override { + auto lhs_tuple = from_tensor(lhs); + auto out_tuple = from_tensor(out); - op_->sum(out_); - } + auto lhs_ = std::get<0>(lhs_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - void mul(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + PaddleTensor scale_tensor(ContextHolder::device_ctx()); + scale_tensor.from_float_point_scalar(factor, lhs_->shape(), ABY3_SCALING_FACTOR); - auto lhs_tuple = from_tensor(lhs); - auto rhs_tuple = from_tensor(rhs); - auto out_tuple = from_tensor(out); + lhs_->mul(&scale_tensor, out_); + } - auto lhs_ = std::get<0>(lhs_tuple).get(); - auto rhs_ = std::get<0>(rhs_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + void relu(const Tensor *op, Tensor *out) override { + auto op_tuple = from_tensor(op); + auto out_tuple = from_tensor(out); - lhs_->mul(rhs_, out_); - } + auto op_ = std::get<0>(op_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - void matmul(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + op_->relu(out_); + } - auto lhs_tuple = from_tensor(lhs); - auto rhs_tuple = from_tensor(rhs); - auto out_tuple = from_tensor(out); + void relu_with_derivative(const Tensor *op, Tensor *out, Tensor *derivative) override { + auto op_tuple = from_tensor(op); + auto out_tuple = from_tensor(out); + auto der_tuple = from_tensor(derivative); - auto lhs_ = std::get<0>(lhs_tuple).get(); - auto rhs_ = std::get<0>(rhs_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + auto op_ = std::get<0>(op_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); + auto der_ = std::get<0>(der_tuple).get(); - lhs_->mat_mul(rhs_, out_); - } + op_->relu_with_derivative(out_, der_); + } - void scale(const Tensor *lhs, const double factor, Tensor *out) override { - auto lhs_tuple = from_tensor(lhs); - auto out_tuple = from_tensor(out); + void sigmoid(const Tensor *op, Tensor *out) override { + auto op_tuple = from_tensor(op); + auto out_tuple = from_tensor(out); - auto lhs_ = std::get<0>(lhs_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + auto op_ = std::get<0>(op_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - PaddleTensor scale_tensor(ContextHolder::device_ctx()); - scale_tensor.from_float_point_scalar(factor, lhs_->shape(), - ABY3_SCALING_FACTOR); + op_->sigmoid(out_); + } - lhs_->mul(&scale_tensor, out_); - } + void sigmoid_enhanced(const Tensor *op, Tensor *out) override { + auto op_tuple = from_tensor(op); + auto out_tuple = from_tensor(out); - void relu(const Tensor *op, Tensor *out) override { - auto op_tuple = from_tensor(op); - auto out_tuple = from_tensor(out); + auto op_ = std::get<0>(op_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - auto op_ = std::get<0>(op_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + op_->sigmoid_enhanced(out_); + } - op_->relu(out_); - } + void sigmoid_chebyshev(const Tensor *op, Tensor *out) override { + auto op_tuple = from_tensor(op); + auto out_tuple = from_tensor(out); - void sigmoid(const Tensor *op, Tensor *out) override { - auto op_tuple = from_tensor(op); - auto out_tuple = from_tensor(out); + auto op_ = 
std::get<0>(op_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - auto op_ = std::get<0>(op_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + op_->sigmoid_chebyshev(out_); + } - op_->sigmoid(out_); - } + void softmax(const Tensor *op, Tensor *out, bool use_relu, bool use_long_div) override { + auto op_tuple = from_tensor(op); + auto out_tuple = from_tensor(out); - void softmax(const Tensor *op, Tensor *out) override { - auto op_tuple = from_tensor(op); - auto out_tuple = from_tensor(out); + auto op_ = std::get<0>(op_tuple).get(); + auto out_ = std::get<0>(out_tuple).get(); - auto op_ = std::get<0>(op_tuple).get(); - auto out_ = std::get<0>(out_tuple).get(); + op_->softmax(out_, use_relu, use_long_div); + } - op_->softmax(out_); - } + void gt(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - void gt(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + auto lhs_tuple = from_tensor(lhs); - auto lhs_tuple = from_tensor(lhs); + auto lhs_ = std::get<0>(lhs_tuple).get(); - auto lhs_ = std::get<0>(lhs_tuple).get(); + PaddleTensor rhs_(ContextHolder::device_ctx()); + rhs_.from_float_point_type(*rhs, ABY3_SCALING_FACTOR); - PaddleTensor rhs_(ContextHolder::device_ctx()); - rhs_.from_float_point_type(*rhs, ABY3_SCALING_FACTOR); + PaddleTensor out_(ContextHolder::device_ctx(), *out); - PaddleTensor out_(ContextHolder::device_ctx(), *out); + auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); + auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); - auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); - auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); + BoolTensor bool_out(tmp0.get(), tmp1.get()); - BoolTensor bool_out(tmp0.get(), tmp1.get()); + lhs_->gt(&rhs_, &bool_out); - lhs_->gt(&rhs_, &bool_out); + bool_out.reveal(&out_); + } - bool_out.reveal(&out_); - } + void geq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + lt(lhs, rhs, out); + std::transform(out->data(), out->data() + out->numel(), + out->data(), [](int64_t b) { return 1 - b; }); + } - void geq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - lt(lhs, rhs, out); - std::transform(out->data(), out->data() + out->numel(), - out->data(), [](int64_t b) { return 1 - b; }); - } + void lt(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - void lt(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + auto lhs_tuple = from_tensor(lhs); - auto lhs_tuple = from_tensor(lhs); + auto lhs_ = std::get<0>(lhs_tuple).get(); - auto lhs_ = std::get<0>(lhs_tuple).get(); + PaddleTensor rhs_(ContextHolder::device_ctx(), *rhs); + rhs_.from_float_point_type(*rhs, ABY3_SCALING_FACTOR); - PaddleTensor rhs_(ContextHolder::device_ctx(), *rhs); - rhs_.from_float_point_type(*rhs, ABY3_SCALING_FACTOR); + PaddleTensor out_(ContextHolder::device_ctx(), *out); - PaddleTensor out_(ContextHolder::device_ctx(), *out); + auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); + auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); - auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); - auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); + BoolTensor bool_out(tmp0.get(), tmp1.get()); - BoolTensor bool_out(tmp0.get(), tmp1.get()); + lhs_->lt(&rhs_, &bool_out); - lhs_->lt(&rhs_, &bool_out); + bool_out.reveal(&out_); + } - bool_out.reveal(&out_); - } + void leq(const Tensor *lhs, const Tensor *rhs, Tensor *out) 
override { + gt(lhs, rhs, out); + std::transform(out->data(), out->data() + out->numel(), + out->data(), [](int64_t b) { return 1 - b; }); + } - void leq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - gt(lhs, rhs, out); - std::transform(out->data(), out->data() + out->numel(), - out->data(), [](int64_t b) { return 1 - b; }); - } + void eq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - void eq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + auto lhs_tuple = from_tensor(lhs); - auto lhs_tuple = from_tensor(lhs); + auto lhs_ = std::get<0>(lhs_tuple).get(); - auto lhs_ = std::get<0>(lhs_tuple).get(); + PaddleTensor rhs_(ContextHolder::device_ctx(), *rhs); + rhs_.from_float_point_type(*rhs, ABY3_SCALING_FACTOR); - PaddleTensor rhs_(ContextHolder::device_ctx(), *rhs); - rhs_.from_float_point_type(*rhs, ABY3_SCALING_FACTOR); + PaddleTensor out_(ContextHolder::device_ctx(), *out); - PaddleTensor out_(ContextHolder::device_ctx(), *out); + auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); + auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); - auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); - auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape()); + BoolTensor bool_out(tmp0.get(), tmp1.get()); - BoolTensor bool_out(tmp0.get(), tmp1.get()); + lhs_->eq(&rhs_, &bool_out); - lhs_->eq(&rhs_, &bool_out); + bool_out.reveal(&out_); + } - bool_out.reveal(&out_); - } + void neq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { + eq(lhs, rhs, out); + std::transform(out->data(), out->data() + out->numel(), + out->data(), [](int64_t b) { return 1 - b; }); + } - void neq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override { - eq(lhs, rhs, out); - std::transform(out->data(), out->data() + out->numel(), - out->data(), [](int64_t b) { return 1 - b; }); - } + void relu_grad(const Tensor *y, const Tensor *dy, + Tensor *dx, float point = 0.0f) override { - void relu_grad(const Tensor *y, const Tensor *dy, Tensor *dx, - float point = 0.0f) override { + auto y_tuple = from_tensor(y); - auto y_tuple = from_tensor(y); + auto y_ = std::get<0>(y_tuple).get(); - auto y_ = std::get<0>(y_tuple).get(); + PaddleTensor point_(ContextHolder::device_ctx()); - PaddleTensor point_(ContextHolder::device_ctx()); + point_.from_float_point_scalar(point, y_->shape(), ABY3_SCALING_FACTOR); - point_.from_float_point_scalar(point, y_->shape(), - ABY3_SCALING_FACTOR); + auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(y_->shape()); + auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(y_->shape()); - auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(y_->shape()); - auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(y_->shape()); + BoolTensor bool_out(tmp0.get(), tmp1.get()); - BoolTensor bool_out(tmp0.get(), tmp1.get()); + y_->gt(&point_, &bool_out); - y_->gt(&point_, &bool_out); + auto out_tuple = from_tensor(dx); + auto out_ = std::get<0>(out_tuple).get(); - auto out_tuple = from_tensor(dx); - auto out_ = std::get<0>(out_tuple).get(); + auto dy_tuple = from_tensor(dy); + auto dy_ = std::get<0>(dy_tuple).get(); - auto dy_tuple = from_tensor(dy); - auto dy_ = std::get<0>(dy_tuple).get(); + bool_out.mul(dy_, out_); + } - bool_out.mul(dy_, out_); - } + void arith_bool_mul(const Tensor* op_a, const Tensor* op_b, Tensor* out) override { + + auto a_tuple = from_tensor(op_a); + auto a_ = std::get<0>(a_tuple).get(); + + auto b_tuple = 
from_tensor(op_b); + auto b_ = std::get<0>(b_tuple).get(); + + auto out_tuple = from_tensor(out); + auto out_ = std::get<0>(out_tuple).get(); + + b_->mul(a_, out_); + } + + void max_pooling(const Tensor* in, Tensor* out, Tensor* pos_info) override { + + auto a_tuple = from_tensor(in); + auto a_ = std::get<0>(a_tuple).get(); + + auto b_tuple = from_tensor(pos_info); + auto b_ = std::get<0>(b_tuple).get(); + + auto out_tuple = from_tensor(out); + auto out_ = std::get<0>(out_tuple).get(); + + a_->max_pooling(out_, b_); + } + + void inverse_square_root(const Tensor* in, Tensor* out) override { + auto x_tuple = from_tensor(in); + auto x_ = std::get<0>(x_tuple).get(); + + auto y_tuple = from_tensor(out); + auto y_ = std::get<0>(y_tuple).get(); + + x_->inverse_square_root(y_); + } private: - std::tuple, std::shared_ptr, - std::shared_ptr> - from_tensor(const Tensor *t) { + template + std::tuple< + std::shared_ptr, + std::shared_ptr, + std::shared_ptr > from_tensor(const Tensor* t) { + + PADDLE_ENFORCE_EQ(t->dims()[0], 2); + + auto pt0 = std::make_shared(ContextHolder::device_ctx(), t->Slice(0, 1)); + auto pt1 = std::make_shared(ContextHolder::device_ctx(), t->Slice(1, 2)); + + // remove leading 1 in shape + auto shape = pt0->shape(); + shape.erase(shape.begin()); + pt0->reshape(shape); + pt1->reshape(shape); + + aby3::TensorAdapter* pt_array[2] = {pt0.get(), pt1.get()}; - PADDLE_ENFORCE_EQ(t->dims()[0], 2); + auto ft = std::make_shared(pt_array); - auto pt0 = std::make_shared(ContextHolder::device_ctx(), - t->Slice(0, 1)); - auto pt1 = std::make_shared(ContextHolder::device_ctx(), - t->Slice(1, 2)); + return std::make_tuple(ft, pt0, pt1); + } - aby3::TensorAdapter *pt_array[2] = {pt0.get(), pt1.get()}; + std::tuple< + std::shared_ptr, + std::shared_ptr, + std::shared_ptr > from_tensor(const Tensor* t) { - auto ft = std::make_shared(pt_array); + return from_tensor(t); + } - return std::make_tuple(ft, pt0, pt1); - } }; } // mpc diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_operators.h b/core/paddlefl_mpc/mpc_protocol/mpc_operators.h index 8fc6977512ea38b3b21d7b301fc91d1595c30c43..3064e88508988a588d132cee811a2d27d9971af6 100644 --- a/core/paddlefl_mpc/mpc_protocol/mpc_operators.h +++ b/core/paddlefl_mpc/mpc_protocol/mpc_operators.h @@ -1,16 +1,16 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ // Description: // abstract mpc operation interface @@ -24,43 +24,67 @@ namespace mpc { using paddle::framework::Tensor; +// TODO: decide scaling factor +const size_t FIXED_POINTER_SCALING_FACTOR = 16; + class MpcOperators { public: - virtual void add(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void add(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + + virtual void sub(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + + virtual void neg(const Tensor *op, Tensor *out) = 0; - virtual void sub(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void sum(const Tensor *op, Tensor *out) = 0; - virtual void neg(const Tensor *op, Tensor *out) = 0; + virtual void mul(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; - virtual void sum(const Tensor *op, Tensor *out) = 0; + virtual void matmul(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; - virtual void mul(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void scale(const Tensor *lhs, const double factor, Tensor *out) = 0; - virtual void matmul(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void relu(const Tensor *op, Tensor *out) = 0; - virtual void scale(const Tensor *lhs, const double factor, Tensor *out) = 0; + virtual void relu_with_derivative(const Tensor *op, Tensor *out, + Tensor *derivative) = 0; - virtual void relu(const Tensor *op, Tensor *out) = 0; + virtual void sigmoid(const Tensor *op, Tensor *out) = 0; - virtual void sigmoid(const Tensor *op, Tensor *out) = 0; + virtual void sigmoid_enhanced(const Tensor *op, Tensor *out) = 0; - virtual void softmax(const Tensor *op, Tensor *out) = 0; + virtual void sigmoid_chebyshev(const Tensor *op, Tensor *out) = 0; - virtual void gt(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void softmax(const Tensor *op, Tensor *out, bool use_relu, bool use_long_div) = 0; - virtual void geq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void gt(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; - virtual void lt(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void geq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; - virtual void leq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void lt(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; - virtual void eq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void leq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; - virtual void neq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + virtual void eq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; - virtual void relu_grad(const Tensor *y, const Tensor *dy, Tensor *dx, - const float point) = 0; + virtual void neq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0; + + virtual void relu_grad(const Tensor *y, const Tensor *dy, Tensor *dx, const float point) = 0; + + // arithmetic tensor mult boolean tensor, element-wisely + // see [ABY3, sec 5.4.1] + // for aby3 only + // example (in plaintext): + // [1, 2, 3, 4] * [0, 0, 1, 0] = [0, 0, 3, 0] + virtual void arith_bool_mul(const Tensor* op_a, const Tensor* op_b, Tensor* out) {} + + // max pooling in which shape of filter is nx1 + // pos_info keeps which element is max in a col, for backward grad + // for filter in other shape, reshape input first + virtual void max_pooling(const Tensor* in, Tensor* out, Tensor* pos_info) {} + + virtual void inverse_square_root(const Tensor* in, Tensor* out) = 0; }; } // mpc } // paddle + diff 
--git a/core/paddlefl_mpc/operators/CMakeLists.txt b/core/paddlefl_mpc/operators/CMakeLists.txt index 38b110682c8b9d02b4c6e58012c1b98b50a9822f..89a2873630748bb95fb4f1ac8eb24a028d067c7d 100644 --- a/core/paddlefl_mpc/operators/CMakeLists.txt +++ b/core/paddlefl_mpc/operators/CMakeLists.txt @@ -1,7 +1,6 @@ -add_compile_options(-msse4.2 -maes) - aux_source_directory(. DIR_SRCS) -add_library(mpc_ops_o OBJECT ${DIR_SRCS}) +aux_source_directory(./math MATH_SRCS) +add_library(mpc_ops_o OBJECT ${DIR_SRCS} ${MATH_SRCS}) add_dependencies(mpc_ops_o fluid_framework gloo) add_library(mpc_ops STATIC $) diff --git a/core/paddlefl_mpc/operators/conv_op.cc b/core/paddlefl_mpc/operators/conv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f9c829b9e588dfe97b71d93cb8cb2be555c484f --- /dev/null +++ b/core/paddlefl_mpc/operators/conv_op.cc @@ -0,0 +1,361 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "./conv_op.h" + +#include +#include +#include + +namespace paddle { +namespace operators { + +std::vector ConvOp::ComputeOutputShape( + framework::InferShapeContext* ctx) const { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv"); + OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv"); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + std::string padding_algorithm = + ctx->Attrs().Get("padding_algorithm"); + int groups = ctx->Attrs().Get("groups"); + std::vector dilations = ctx->Attrs().Get>("dilations"); + const std::string data_format = ctx->Attrs().Get("data_format"); + + // MKL-DNN Kernels are using NCHW order of dims description + // so we ignore data_format consideration for MKL-DNN kernel + const bool channel_last = (this->IsMKLDNNType() == false) && + (data_format == "NHWC" || data_format == "NDHWC"); + + PADDLE_ENFORCE_EQ( + // 1 for share dim + in_dims.size() == 4 + 1 || in_dims.size() == 5 + 1, true, + platform::errors::InvalidArgument( + "The input of Op(Conv) should be a 4-D or 5-D Tensor. But " + "received: input's dimension is %u, input's shape is [%s].", + in_dims.size(), in_dims)); + + PADDLE_ENFORCE_EQ( + in_dims.size(), filter_dims.size(), + platform::errors::InvalidArgument( + "The input's dimension and filter's dimension of " + "Op(Conv) should be equal. But received: the input's shape is [%s], " + "the input's dimension is %d; the filter's shape is [%s], " + "the filter's dimension is %d.", + in_dims, in_dims.size(), filter_dims, filter_dims.size())); + + int in_sub_stride_size = in_dims.size() - strides.size(); + PADDLE_ENFORCE_EQ( + in_dims.size(), strides.size() + 2U + 1, + platform::errors::InvalidArgument( + "The difference of input's dimension and Attr(strides)'s " + "length must be euqal to 2 for Op(Conv). 
" + "But received: input's dimension is %d, input's shape is [%s]; " + "Attr(stride)'s length is %d, Attr(stride) is [%s]; " + "difference of input's dimention and Attr(strides)'s length = %u.", + in_dims.size(), in_dims, strides.size(), + framework::make_ddim(strides), in_sub_stride_size)); + + const auto input_channels = + channel_last ? in_dims[in_dims.size() - 1] : in_dims[1 + 1]; + + PADDLE_ENFORCE_EQ( + input_channels, filter_dims[1 + 1] * groups, + platform::errors::InvalidArgument( + "The number of input's channels should be equal to filter's channels " + "* groups for Op(Conv). But received: the input's channels is %d, " + "the input's shape is [%s]; the filter's channels is %d, the " + "filter's shape is [%s]; the groups is %d, the data_format is %s. " + "The error may come from wrong data_format setting.", + input_channels, in_dims, filter_dims[1 + 1], filter_dims, groups, + data_format)); + PADDLE_ENFORCE_EQ( + filter_dims[0 + 1] % groups, 0, + platform::errors::InvalidArgument( + "The number of output's channels (filter's first dimension) of " + "Op(Conv) should be divided by groups. But received: " + "the output channels is %d, the filter's shape is [%s], " + "the groups is %d.", + filter_dims[0 + 1], filter_dims, groups)); + + framework::DDim in_data_dims; + if (channel_last) { + in_data_dims = framework::slice_ddim(in_dims, 1 + 1, in_dims.size() - 1); + } else { + in_data_dims = framework::slice_ddim(in_dims, 2 + 1, in_dims.size()); + } + + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2 + 1, filter_dims.size()); + + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + std::vector output_shape({in_dims[0], in_dims[1]}); + if (!channel_last) { + output_shape.push_back(filter_dims[0 + 1]); + } + for (int i = 0; i < in_data_dims.size(); ++i) { + if ((!ctx->IsRuntime()) && + (in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) { + output_shape.push_back(-1); + } else { + output_shape.push_back( + ConvOutputSize(in_data_dims[i], filter_data_dims[i], dilations[i], + paddings[2 * i], paddings[2 * i + 1], strides[i])); + } + } + if (channel_last) { + output_shape.push_back(filter_dims[1]); + } + + return output_shape; +} + +framework::OpKernelType ConvOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; + framework::LibraryType library{framework::LibraryType::kPlain}; + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); + std::string data_format = + "AnyLayout"; // todo enable data layout when it's ready + framework::DataLayout layout = framework::StringToDataLayout(data_format); + + if (input_data_type != framework::proto::VarType::INT8 && + input_data_type != framework::proto::VarType::UINT8) { + auto filter_data_type = ctx.Input("Filter")->type(); + PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, + platform::errors::InvalidArgument( + "input and filter data type should be consistent")); + } + if (input_data_type == framework::proto::VarType::FP16) { + PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN, + platform::errors::InvalidArgument( + "float16 can only be used when CUDNN is used")); + } + + auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library, customized_type_value); + return type; +} + 
+framework::OpKernelType ConvOp::GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); +} + +void Conv2DOpMaker::Make() { + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddInput("Input", + "(Tensor) The input tensor of convolution operator. " + "The format of input tensor is NCHW or NHWC, where N is batch size, " + "C is the " + "number of channels, H is the height of the feature, " + "and W is the width of the feature."); + AddInput("Filter", + "(Tensor) The filter tensor of convolution operator. " + "The format of the filter tensor is MCHW, where M is the number of " + "output image channels, C is the number of input image channels, " + "H is the height of the filter, and W is the width of the filter. " + "If the groups attribute is greater than 1, C equals the number of " + "input image channels divided by the groups."); + AddInput("Bias", + "(Tensor) Bias to be added to each output of filter application." + "The format of output tensor is X (one-dimensional) of size equal" + "to the number of output channels. Only used with MKL-DNN.") + .AsDispensable(); + AddOutput("Output", + "(Tensor) The output tensor of convolution operator. " + "It has same data fromat and data type as the Input."); + AddAttr>("strides", + "(vector default:{1, 1}), the " + "strides(h_stride, w_stride) of " + "convolution operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", + "(vector default:{0, 0}), the " + "paddings(pad_height_top, pad_height_bottom, " + "pad_width_left, pad_wifth_right) of " + "convolution operator.") + .SetDefault({0, 0}); + AddAttr( + "padding_algorithm", + "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\"," + "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. " + "Set to \"SAME\" or \"VALID\" for algorithm of padding. ") + .SetDefault("EXPLICIT"); + AddAttr( + "groups", + "(int default:1), the groups number of the convolution operator. " + "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: " + "when group=2, the first half of the filters is only connected to the " + "first half of the input channels, while the second half of the filters " + "is only connected to the second half of the input channels.") + .SetDefault(1); + AddAttr>("dilations", + "(vector default:{1, 1}), the " + "dilations(h_dilation, w_dilation) of " + "convolution operator.") + .SetDefault({1, 1}); + AddAttr("use_quantizer", + "(bool, default false) " + "Set to true for operators that should be quantized and use " + "int8 kernel. " + "Only used on CPU.") + .SetDefault(false); + AddAttr("Scale_in", + "Scale_in to be used for int8 input data." + "Only used with MKL-DNN INT8.") + .SetDefault(1.0f); + AddAttr("Scale_out", + "Scale_out to be used for int8 output data." + "Only used with MKL-DNN INT8.") + .SetDefault(1.0f); + AddAttr("Scale_in_eltwise", + "Scale_in_eltwise to be used for int8 eltwise input data." + "Only used with MKL-DNN INT8.") + .SetDefault(1.0f); + AddAttr>("Scale_weights", + "Scale_weights to be used for int8 weights data." 
+ "Only used with MKL-DNN INT8.") + .SetDefault({1.0f}); + AddAttr("force_fp32_output", + "(bool, default false) Force INT8 kernel output FP32, only " + "used in MKL-DNN INT8") + .SetDefault(false); + AddAttr( + "data_format", + "(string, default NCHW) Only used in " + "An optional string from: \"NHWC\", \"NCHW\". " + "Defaults to \"NHWC\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault("NCHW"); + // TODO(dzhwinter): need to registered layout transform function + AddAttr("exhaustive_search", + "(bool, default false) cuDNN has many algorithm to calculation " + "convolution, whether enable exhaustive search " + "for cuDNN convolution or not, default is False.") + .SetDefault(false); + + AddComment(R"DOC( +Convolution Operator. + +The convolution operation calculates the output based on the input, filter +and strides, paddings, dilations, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. +Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch +size, C is the number of channels, H is the height of the feature, and W is +the width of the feature. +Filters(Input) is MCHW format format. Where M is the number of output image channels, C is +the number of input image channels, H is the height of the filter, and W +is the width of the filter. +Parameters(strides, paddings, dilations) are two elements. These two elements represent +height and width, respectively. +The input(X) size and output(Out) size may be different. + +Example: + Input: + Input shape: $(N, C_{in}, H_{in}, W_{in})$ + Filter shape: $(C_{out}, C_{in}, H_f, W_f)$ + Output: + Output shape: $(N, C_{out}, H_{out}, W_{out})$ + Where +$$ + H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\ + W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1 +$$ +)DOC"); + Apply(); +} + +void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } +} + +framework::OpKernelType ConvOpGrad::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; + framework::LibraryType library_{framework::LibraryType::kPlain}; + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + std::string data_format = "AnyLayout"; + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + + auto type = framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_, customized_type_value); + return type; +} + +framework::OpKernelType ConvOpGrad::GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); +} + +template +class Conv2DGradMaker : public framework::SingleGradOpMaker { +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + 
"_grad"); + op->SetInput("Input", this->Input("Input")); + op->SetInput("Filter", this->Input("Filter")); + op->SetInput("Bias", this->Input("Bias")); + op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); + + op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); + op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter")); + op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(mpc_conv2d, ops::ConvOp, ops::Conv2DOpMaker, + ops::ConvOpInferVarType, + ops::Conv2DGradMaker, + ops::Conv2DGradMaker); + +REGISTER_OPERATOR(mpc_conv2d_grad, ops::ConvOpGrad); + +REGISTER_OP_CPU_KERNEL( + mpc_conv2d, ops::GemmConvKernel); +REGISTER_OP_CPU_KERNEL( + mpc_conv2d_grad, + ops::GemmConvGradKernel); diff --git a/core/paddlefl_mpc/operators/conv_op.h b/core/paddlefl_mpc/operators/conv_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8338ce9093d0d8b561009b1b54c767dc13a10ba7 --- /dev/null +++ b/core/paddlefl_mpc/operators/conv_op.h @@ -0,0 +1,1048 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "./math/im2col.h" +#include "./math/vol2col.h" +#include "./math/math_function.h" +#include "mpc_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +constexpr int kConvMKLDNNFP32 = 1; +constexpr int kConvMKLDNNINT8 = 2; +constexpr int MaxKeyLength = 256; + +// Base convolution operator definations for other conv +// like operators to reuse the implementation. +inline int ConvOutputSize(int input_size, int filter_size, int dilation, + int padding, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + 2 * padding - dkernel) / stride + 1; + PADDLE_ENFORCE_GT( + output_size, 0, + platform::errors::InvalidArgument( + "The output's size is expected to be greater than 0. " + "But recieved: output's size is %d. The output's size is computed by " + "((input_size + 2 * padding - (dilation * (filter_size - 1) + 1)) / " + "stride + 1), where input_size is %d, padding is %d, " + "filter_size is %d, dilation is %d, stride is %d.", + output_size, input_size, padding, filter_size, dilation, stride)); + + return output_size; +} + +inline int ConvOutputSize(int input_size, int filter_size, int dilation, + int padding_1, int padding_2, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1; + PADDLE_ENFORCE_GT( + output_size, 0, + platform::errors::InvalidArgument( + "The output's size is expected to be greater than 0. " + "But recieved: output's size is %d. 
The output's size is computed by " + "((input_size + padding_1 + padding_2 - (dilation * (filter_size - " + "1) + 1)) / stride + 1), where input_size is %d, padding is " + "(%d, %d), filter_size is %d, dilation is %d, stride is %d.", + output_size, input_size, padding_1, padding_2, filter_size, dilation, + stride)); + + return output_size; +} + +template +inline void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilation, + const std::string& padding_algorithm, + const framework::DDim data_dims, + const std::vector& strides, + const std::vector& ksize) { + // set padding size == data_dims.size() * 2 + auto data_shape = framework::vectorize(data_dims); + if (static_cast(paddings->size()) == data_dims.size()) { + for (int i = 0; i < data_dims.size(); ++i) { + T copy_pad = *(paddings->begin() + 2 * i); + paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); + } + } else { + PADDLE_ENFORCE_EQ( + data_dims.size() * 2, paddings->size(), + platform::errors::InvalidArgument( + "Attribute padding's size should be the same or twice as the " + "input's dimension. " + "But recieved: padding's size is %d, padding is [%s]; input's " + "dimension is %d, input's shape is [%s].", + paddings->size(), framework::make_ddim(*paddings), data_dims.size(), + data_dims)); + } + + // when padding_algorithm is "VALID" or "SAME" + if (padding_algorithm == "SAME") { + for (int i = 0; i < data_dims.size(); ++i) { + T out_size = (data_dims[i] + strides[i] - 1) / strides[i]; + T pad_sum = + std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], + static_cast(0)); + T pad_0 = pad_sum / 2; + T pad_1 = pad_sum - pad_0; + *(paddings->begin() + i * 2) = pad_0; + *(paddings->begin() + i * 2 + 1) = pad_1; + + // dilation + *(dilation->begin() + i) = 1; + } + + } else if (padding_algorithm == "VALID") { + for (auto it = paddings->begin(); it != paddings->end(); it++) { + *it = 0; + } + } +} + +inline bool IsExpand(const std::vector& filter_dim, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + // extra 1 for share dim + filter_1 = filter_1 && (static_cast(filter_dim[j + 2 + 1]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); + } + if (paddings.size() != strides.size()) { + for (size_t j = 0; j < paddings.size(); ++j) { + padding_0 = padding_0 && (paddings[j] == 0); + } + } + return !(filter_1 && strides_1 && padding_0 && dilation_1); +} + +template +inline void ResizeToChannelFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + // extra 1 for leading share dim S + int dim = input->dims().size() - 2 - 1; + if (dim == 3) { + // input + transformed_input->Resize(input->dims()); + + // SNDHWC -> NCSDHW + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[0] = input->dims()[1]; + in_dims_vec[1] = input->dims()[5]; + in_dims_vec[2] = input->dims()[0]; + in_dims_vec[3] = input->dims()[2]; + in_dims_vec[4] = input->dims()[3]; + in_dims_vec[5] = input->dims()[4]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + + } else if (dim == 2) { + // input + transformed_input->Resize(input->dims()); + + // SNHWC -> NCSHW + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[0] = 
input->dims()[1]; + in_dims_vec[1] = input->dims()[4]; + in_dims_vec[2] = input->dims()[0]; + in_dims_vec[3] = input->dims()[2]; + in_dims_vec[4] = input->dims()[3]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } +} + +template +inline void ResizeToChannelLast(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + // extra 1 for leading share dim S + int dim = input->dims().size() - 2 - 1; + if (dim == 3) { + // input + transformed_input->Resize(input->dims()); + + // NCSDHW -> SNDHWC + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[0] = input->dims()[2]; + in_dims_vec[1] = input->dims()[0]; + in_dims_vec[2] = input->dims()[3]; + in_dims_vec[3] = input->dims()[4]; + in_dims_vec[4] = input->dims()[5]; + in_dims_vec[5] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + + } else if (dim == 2) { + // input + transformed_input->Resize(input->dims()); + + // NCSHW -> SNHWC + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[0] = input->dims()[2]; + in_dims_vec[1] = input->dims()[0]; + in_dims_vec[2] = input->dims()[3]; + in_dims_vec[3] = input->dims()[4]; + in_dims_vec[4] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } +} + +template +inline void ResizeToShareLast(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + transformed_input->Resize(input->dims()); + + // SNC.. -> NCS.. + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[0] = input->dims()[1]; + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[0]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); +} + +template +inline void ResizeToShareFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + transformed_input->Resize(input->dims()); + + // NCS.. -> SNC.. 
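// (illustrative example, assuming 2 shares: an NCSHW tensor with dims
//  (N, C, S, H, W) = (8, 3, 2, 28, 28) is resized here to SNCHW dims
//  (S, N, C, H, W) = (2, 8, 3, 28, 28); only the three leading entries are
//  permuted, the trailing spatial dims keep their positions)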
+ auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[0] = input->dims()[2]; + in_dims_vec[1] = input->dims()[0]; + in_dims_vec[2] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); +} + +template +inline void TransToChannelFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + // extra 1 for leading share dim + // swap share and batch_size + int dim = input->dims().size() - 2 - 1; + if (dim == 3) { + auto& dev_ctx = context.template device_context(); + std::vector axis{1, 5, 0, 2, 3, 4}; + math::Transpose trans6; + trans6(dev_ctx, *input, transformed_input, axis); + + } else if (dim == 2) { + auto& dev_ctx = context.template device_context(); + std::vector axis{1, 4, 0, 2, 3}; + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + } +} + +template +inline void TransToChannelLast(const framework::ExecutionContext& context, + const Tensor* input, Tensor* transformed_input) { + // extra 1 for leading share dim + // swap share and batch_size + int dim = input->dims().size() - 2 - 1; + if (dim == 3) { + auto& dev_ctx = context.template device_context(); + std::vector axis{2, 0, 3, 4, 5, 1}; + math::Transpose trans6; + trans6(dev_ctx, *input, transformed_input, axis); + + } else if (dim == 2) { + auto& dev_ctx = context.template device_context(); + std::vector axis{2, 0, 3, 4, 1}; + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + } +} + +template +inline void TransToShareFirst(const framework::ExecutionContext& context, + const Tensor* input, Tensor* transformed_input) { + int dim = input->dims().size(); + + PADDLE_ENFORCE_GT( + dim, 4, + platform::errors::InvalidArgument( + "The input's dim is expected to be greater than 4.")); + + std::vector axis(dim); + for (size_t i = 3; i < dim; ++i) { + axis[i] = i; + } + // share + axis[0] = 2; + // N + axis[1] = 0; + // C + axis[2] = 1; + + auto& dev_ctx = context.template device_context(); + + switch(dim) { + + case 5: + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + break; + + case 6: + math::Transpose trans6; + trans6(dev_ctx, *input, transformed_input, axis); + break; + + default: + PADDLE_ENFORCE_LT( + dim, 7, platform::errors::InvalidArgument( + "The input's dim greater than 6 not supported yet. ")); + } +} + +template +inline void TransToShareLast(const framework::ExecutionContext& context, + const Tensor* input, Tensor* transformed_input) { + int dim = input->dims().size(); + + PADDLE_ENFORCE_GT( + dim, 4, + platform::errors::InvalidArgument( + "The input's dim is expected to be greater than 4.")); + + std::vector axis(dim); + for (size_t i = 3; i < dim; ++i) { + axis[i] = i; + } + // SNC -> NCS + axis[0] = 1; + axis[1] = 2; + axis[2] = 0; + + auto& dev_ctx = context.template device_context(); + + switch(dim) { + + case 5: + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + break; + + case 6: + math::Transpose trans6; + trans6(dev_ctx, *input, transformed_input, axis); + break; + + default: + PADDLE_ENFORCE_LT( + dim, 7, platform::errors::InvalidArgument( + "The input's dim greater than 6 not supported yet. 
")); + } +} +template +inline void TransToBatchFirst(const framework::ExecutionContext& context, + const Tensor* input, Tensor* transformed_input) { + int dim = input->dims().size(); + + PADDLE_ENFORCE_GT( + dim, 4, + platform::errors::InvalidArgument( + "The input's dim is expected to be greater than 4.")); + + std::vector axis(dim); + for (size_t i = 3; i < dim; ++i) { + axis[i] = i; + } + // N + axis[0] = 1; + // C + axis[1] = 2; + // share + axis[2] = 0; + + auto& dev_ctx = context.template device_context(); + + switch(dim) { + + case 5: + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + break; + + case 6: + math::Transpose trans6; + trans6(dev_ctx, *input, transformed_input, axis); + break; + + default: + PADDLE_ENFORCE_LT( + dim, 7, platform::errors::InvalidArgument( + "The input's dim greater than 6 not supported yet. ")); + } +} + +template +inline void ResizeToSwapedLeadingDims(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + transformed_input->Resize(input->dims()); + + // NS.. -> SN.. + // or CS.. -> SC.. + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[0] = input->dims()[1]; + in_dims_vec[1] = input->dims()[0]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); +} + +template +void TransToSwapedLeadingDims(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* output){ + output->Resize(input->dims()); + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[0] = input->dims()[1]; + in_dims_vec[1] = input->dims()[0]; + output->Resize(framework::make_ddim(in_dims_vec)); + output->mutable_data(context.GetPlace()); + + const int dim = input->dims().size(); + + std::vector axis(dim); + for (size_t i = 0; i < dim; ++i) { + axis[i] = i; + } + axis[0] = 1; + axis[1] = 0; + + auto& dev_ctx = context.template device_context(); + + switch(dim) { + + case 3: + math::Transpose trans3; + trans3(dev_ctx, *input, output, axis); + break; + + case 4: + math::Transpose trans4; + trans4(dev_ctx, *input, output, axis); + break; + + case 5: + math::Transpose trans5; + trans5(dev_ctx, *input, output, axis); + break; + + case 6: + math::Transpose trans6; + trans6(dev_ctx, *input, output, axis); + break; + + default: + PADDLE_ENFORCE_GT( + dim, 2, platform::errors::InvalidArgument( + "The input's dim less than 3 not supported yet. ")); + PADDLE_ENFORCE_LT( + dim, 7, platform::errors::InvalidArgument( + "The input's dim greater than 6 not supported yet. 
")); + } + return; +} + +template +void SharesToCols(const framework::ExecutionContext& context, + const Tensor* input, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + Tensor* col, Func data2col) { + // // input: CSHW or CSDHW, S for share dim + + framework::DDim in_plain_dim = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + framework::DDim col_plain_dim = + framework::slice_ddim(col->dims(), 1, col->dims().size()); + + auto& dev_ctx = context.template device_context(); + + const int share_size = input->dims()[0]; + for (size_t i = 0; i < share_size; ++i) { + Tensor share = input->Slice(i, i + 1).Resize(in_plain_dim); + Tensor col_share = col->Slice(i, i + 1).Resize(col_plain_dim); + data2col(dev_ctx, share, dilations, strides, paddings, &col_share); + } +} + +template +Tensor SwapedLeadingDims(const framework::ExecutionContext& context, + const Tensor* input) { + Tensor output(input->type()); + + ResizeToSwapedLeadingDims(context, input, + &output); + TransToSwapedLeadingDims(context, input, + &output); + return output; +} + +template +Tensor TransposeMpcMat(const framework::ExecutionContext& context, + const Tensor* input) { + Tensor output(input->type()); + + auto in_dims_vec = framework::vectorize(input->dims()); + + PADDLE_ENFORCE_EQ( + in_dims_vec.size(), 3, platform::errors::InvalidArgument( + "The input's dim should be 3. ")); + in_dims_vec[0] = input->dims()[0]; + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[1]; + output.Resize(framework::make_ddim(in_dims_vec)); + output.mutable_data(context.GetPlace()); + + std::vector axis(3); + axis[0] = 0; + axis[1] = 2; + axis[2] = 1; + + auto& dev_ctx = context.template device_context(); + + math::Transpose trans3; + trans3(dev_ctx, *input, &output, axis); + + return output; +} + +// Define Op classes in .h file so that other conv +// operator implementations can reuse the code. 
+class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() final; + + protected: + virtual void Apply() {} +}; + +class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{ + {"Input", /*->*/ "Output"}}; + return m; + } +}; + +class ConvOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + std::vector output_shape = ComputeOutputShape(ctx); + + OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv"); + ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); + ctx->ShareLoD("Input", "Output"); + } + + protected: + std::vector ComputeOutputShape( + framework::InferShapeContext* ctx) const; + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override; +}; + +class ConvOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override; +}; + +// TODO: add conv double grad + +template +class GemmConvKernel : public MpcOpKernel { + public: + void ComputeImpl(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. 
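ComputeOutputShape is only declared here, but the spatial sizes it and the col buffers in the kernels rely on follow the usual convolution arithmetic, out = (in + pad_before + pad_after - (dilation * (k - 1) + 1)) / stride + 1, which is the same expression the Col2Im functors later enforce. A small sketch of that formula, assuming paddings are already explicit per-side values (as they are after the UpdatePaddingAndDilation call in the kernels); the helper name ConvOutSize is illustrative:

#include <iostream>

// out = (in + pad_before + pad_after - (dilation * (k - 1) + 1)) / stride + 1
int ConvOutSize(int in, int k, int stride, int pad_before, int pad_after,
                int dilation) {
  return (in + pad_before + pad_after - (dilation * (k - 1) + 1)) / stride + 1;
}

int main() {
  // e.g. a 32x32 feature map with a 3x3 kernel, stride 1, padding 1,
  // dilation 1 keeps its spatial size: prints "32 32".
  std::cout << ConvOutSize(32, 3, 1, 1, 1, 1) << ' '
            << ConvOutSize(32, 3, 1, 1, 1, 1) << '\n';
  return 0;
}

In the kernel below, the filter is copied by value for the reason just noted in the comment, so the in-place Resize calls never touch the variable held in the Scope.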
+ Tensor filter = *context.Input("Filter"); + Tensor* output = context.Output("Output"); + output->mutable_data(context.GetPlace()); + + const int groups = context.Attr("groups"); + const std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + const std::string data_format = context.Attr("data_format"); + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + Tensor transformed_input(input->type()); + Tensor transformed_output(output->type()); + + if (channel_last) { + ResizeToChannelFirst(context, input, + &transformed_input); + TransToChannelFirst(context, input, &transformed_input); + + ResizeToChannelFirst(context, output, + &transformed_output); + + } else { + ResizeToShareLast(context, input, + &transformed_input); + TransToShareLast(context, input, &transformed_input); + + ResizeToShareLast(context, output, + &transformed_output); + } + + // update padding and dilation + auto trans_in_dims = transformed_input.dims(); + auto filter_dims = filter.dims(); + + // extra 1 for share dim + framework::DDim in_data_dims = + framework::slice_ddim(trans_in_dims, 2 + 1, trans_in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2 + 1, filter_dims.size()); + + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + auto& dev_ctx = context.template device_context(); + + const int batch_size = static_cast(transformed_input.dims()[0]); + + // filter_shape_vec: + // {k_share, k_o, k_i, k_h, k_w} or {k_share, k_o, k_i, k_d, k_h, k_w} + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + + // output_shape_vec: + // {o_n, o_c, o_share, o_h, o_w} or {o_n, o_c, o_share, o_d, o_h, o_w} + std::vector output_shape_vec( + framework::vectorize(transformed_output.dims())); + + // use col_shape in the im2col calculation + // col_shape_vec: + // {i_s, i_c/g, k_h, k_w, o_h, o_w} or {i_s, i_c/g, k_d, k_h, k_w, + // o_d, o_h, o_w} + size_t data_dim = filter_shape_vec.size() - 2 - 1; + + std::vector col_shape_vec(2 + 2 * data_dim); + col_shape_vec[0] = trans_in_dims[2]; + col_shape_vec[1] = trans_in_dims[1] / groups; + + std::vector col_matrix_shape_vec(3); + col_matrix_shape_vec[0] = col_shape_vec[0]; + col_matrix_shape_vec[1] = col_shape_vec[1]; + col_matrix_shape_vec[2] = 1; + // use col_matrix_shape in the gemm calculation + // size: + // (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d * o_h * + // o_w) + + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 2] = filter_shape_vec[j + 3]; + col_shape_vec[j + 2 + data_dim] = output_shape_vec[j + 3]; + col_matrix_shape_vec[1] *= filter_shape_vec[j + 3]; + col_matrix_shape_vec[2] *= output_shape_vec[j + 3]; + } + + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + + framework::DDim col_matrix_shape(framework::make_ddim(col_matrix_shape_vec)); + + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); + + Tensor col; + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. 
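The comment above describes col_matrix as a two-dimensional (per share) view of col's data; concretely, each per-share GEMM multiplies a filter slice of shape (o_c/g, i_c/g * k_h * k_w) by col_matrix of shape (i_c/g * k_h * k_w, o_h * o_w), giving an output slice of shape (o_c/g, o_h * o_w). A small worked example of that shape bookkeeping, with sizes that are assumed for illustration and not taken from the patch:

#include <iostream>

int main() {
  // Assumed example sizes: C/g = 3 input channels per group, 3x3 kernel,
  // 30x30 output, O/g = 8 output channels per group.
  const int ic_per_group = 3, kh = 3, kw = 3, oh = 30, ow = 30;
  const int oc_per_group = 8;

  const int K = ic_per_group * kh * kw;  // rows of col_matrix       = 27
  const int P = oh * ow;                 // cols of col_matrix       = 900
  const int M = oc_per_group;            // rows of the filter slice = 8

  // Per share: (M x K) * (K x P) = (M x P)
  std::cout << "filter_slice: " << M << " x " << K << '\n'
            << "col_matrix:   " << K << " x " << P << '\n'
            << "out_slice:    " << M << " x " << P << '\n';
  return 0;
}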
+ Tensor col_matrix; + if (is_expand) { + col = context.AllocateTmpTensor(col_shape, dev_ctx); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + // with share dim + framework::DDim in_matrix_shape = framework::slice_ddim( + transformed_input.dims(), 1, transformed_input.dims().size()); + + // SOIHW or SOIDHW + framework::DDim filter_matrix_shape = {filter.dims()[0], filter.dims()[1], + filter.numel() / (filter.dims()[0] * filter.dims()[1]) }; + filter.Resize(filter_matrix_shape); + + // OSIHW or OSIDHW + Tensor filter_ = SwapedLeadingDims(context, &filter); + + // CS(H * W) or CS(D * H * W) + framework::DDim output_matrix_shape = { + transformed_output.dims()[1], + transformed_output.dims()[2], + transformed_output.numel() / + (transformed_output.dims()[0] + * transformed_output.dims()[1] + * transformed_output.dims()[2])}; + + // convolution operator: im2col(or vol2col) + gemm + int in_step = static_cast(transformed_input.dims()[1]) / groups; + int out_step = static_cast(transformed_output.dims()[1]) / groups; + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = + transformed_input.Slice(i, i + 1).Resize(in_matrix_shape); + Tensor out_batch = + transformed_output.Slice(i, i + 1).Resize(output_matrix_shape); + + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + Tensor in_slice_ = SwapedLeadingDims(context, &in_slice); + + if (!is_expand) { + col.ShareDataWith(in_slice_); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + SharesToCols(context, &in_slice_, dilations, strides, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &col, im2col); + } else if (data_dim == 3U) { + SharesToCols(context, &in_slice_, dilations, strides, paddings, &col, vol2col); + } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter_.Slice(g * out_step, (g + 1) * out_step); + Tensor out_slice_ = SwapedLeadingDims(context, &out_slice); + Tensor filter_slice_ = SwapedLeadingDims(context, &filter_slice); + + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( + &filter_slice_, &col_matrix, &out_slice_); + + TransToSwapedLeadingDims(context, &out_slice_, + &out_slice); + } + } + if (channel_last) { + TransToChannelLast(context, &transformed_output, + output); + } else { + TransToShareFirst(context, &transformed_output, + output); + } + } +}; + +template +class GemmConvGradKernel : public MpcOpKernel { + public: + void ComputeImpl(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + // The filter and filter_grad will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. 
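The forward loop above computes out = W * col per share and group; the grad kernel below uses the two companion products, col_grad = W^T * out_grad (which col2im/col2vol then maps back onto the input gradient) and W_grad = out_grad * col^T, matching the matmul calls on filter_slice_t and col_mat_t further down. A tiny plain-float sketch of the three products follows; MatMul, Transpose, and the toy sizes are illustrative only:

#include <iostream>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// Naive matrix product: (M x K) * (K x P) -> (M x P).
Mat MatMul(const Mat& a, const Mat& b) {
  const size_t M = a.size(), K = b.size(), P = b[0].size();
  Mat c(M, std::vector<float>(P, 0.f));
  for (size_t m = 0; m < M; ++m)
    for (size_t k = 0; k < K; ++k)
      for (size_t p = 0; p < P; ++p)
        c[m][p] += a[m][k] * b[k][p];
  return c;
}

Mat Transpose(const Mat& a) {
  Mat t(a[0].size(), std::vector<float>(a.size()));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t j = 0; j < a[0].size(); ++j)
      t[j][i] = a[i][j];
  return t;
}

int main() {
  // Toy sizes: M = O/g = 2, K = (C/g)*kh*kw = 3, P = oh*ow = 4.
  Mat w = {{1, 0, 2}, {0, 1, 1}};                           // (M x K)
  Mat col = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};  // (K x P)

  Mat out = MatMul(w, col);                            // forward:  (M x P)
  Mat out_grad = out;                                  // pretend upstream grad
  Mat col_grad = MatMul(Transpose(w), out_grad);       // (K x P), then col2im
  Mat w_grad = MatMul(out_grad, Transpose(col));       // (M x K)

  std::cout << out.size() << 'x' << out[0].size() << ' '            // 2x4
            << col_grad.size() << 'x' << col_grad[0].size() << ' '  // 3x4
            << w_grad.size() << 'x' << w_grad[0].size() << '\n';    // 2x3
  return 0;
}

As the comment above notes, filter and filter_grad are taken by value in the grad kernel so that these reshapes never modify the variables held in the Scope.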
+ Tensor filter = *context.Input("Filter"); + + if (!input_grad && !filter_grad) return; + + int groups = context.Attr("groups"); + const std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + const std::string data_format = context.Attr("data_format"); + + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + Tensor transformed_input(input->type()); + Tensor transformed_output_grad(output_grad->type()); + + if (channel_last) { + ResizeToChannelFirst(context, input, + &transformed_input); + TransToChannelFirst(context, input, &transformed_input); + + ResizeToChannelFirst(context, output_grad, + &transformed_output_grad); + TransToChannelFirst(context, output_grad, + &transformed_output_grad); + } else { + ResizeToShareLast(context, input, + &transformed_input); + TransToShareLast(context, input, &transformed_input); + ResizeToShareLast(context, output_grad, + &transformed_output_grad); + TransToShareLast(context, output_grad, &transformed_output_grad); + } + + // update padding and dilation + auto in_dims = transformed_input.dims(); + auto filter_dims = filter.dims(); + // extra 1 for share dim + framework::DDim in_data_dims = + framework::slice_ddim(in_dims, 2 + 1, in_dims.size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter_dims, 2 + 1, filter_dims.size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + + const int batch_size = static_cast(transformed_input.dims()[0]); + + auto& dev_ctx = context.template device_context(); + + // filter_shape_vec: {k_share, k_o, k_i, k_h, k_w} or {k_share, k_o, k_i, k_d, k_h, k_w} + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + // output_shape_vec: {o_n, o_c, o_share, o_h, o_w} or {o_n, o_c, o_share, o_d, o_h, o_w} + std::vector output_shape_vec( + framework::vectorize(transformed_output_grad.dims())); + + // use col_shape in the im2col calculation + // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d, + // o_h, o_w} + size_t data_dim = filter_shape_vec.size() - 2 - 1; + std::vector col_shape_vec(2 + 2 * data_dim); + col_shape_vec[0] = in_dims[2]; + col_shape_vec[1] = in_dims[1] / groups; + + std::vector col_matrix_shape_vec(3); + col_matrix_shape_vec[0] = col_shape_vec[0]; + col_matrix_shape_vec[1] = col_shape_vec[1]; + col_matrix_shape_vec[2] = 1; + // use col_matrix_shape in the gemm calculation + // size: + // (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d * o_h * + // o_w) + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 2] = filter_shape_vec[j + 3]; + col_shape_vec[j + 2 + data_dim] = output_shape_vec[j + 3]; + col_matrix_shape_vec[1] *= filter_shape_vec[j + 3]; + col_matrix_shape_vec[2] *= output_shape_vec[j + 3]; + } + + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + framework::DDim col_matrix_shape(framework::make_ddim(col_matrix_shape_vec)); + + // with share dim + framework::DDim input_shape = framework::slice_ddim( + transformed_input.dims(), 1, transformed_input.dims().size()); + + // SOIHW or SOIDHW + framework::DDim filter_matrix_shape = {filter.dims()[0], filter.dims()[1], + filter.numel() / (filter.dims()[0] * filter.dims()[1]) }; + + // OSIHW or OSIDHW + framework::DDim filter_matrix_shape_ = 
{filter.dims()[1], filter.dims()[0], + filter.numel() / (filter.dims()[0] * filter.dims()[1]) }; + filter.Resize(filter_matrix_shape); + + Tensor filter_ = SwapedLeadingDims(context, &filter); + + // CS(H * W) or CS(D * H * W) + framework::DDim output_matrix_shape = { + transformed_output_grad.dims()[1], + transformed_output_grad.dims()[2], + transformed_output_grad.numel() / + (transformed_output_grad.dims()[0] + * transformed_output_grad.dims()[1] + * transformed_output_grad.dims()[2])}; + // convolution backward input operator: gemm + col2im(or col2vol) + // convolution backward weight operator: im2col(or vol2col) + gemm + int in_step = static_cast(transformed_input.dims()[1]) / groups; + int out_step = static_cast(transformed_output_grad.dims()[1]) / groups; + + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); + + Tensor col; + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix; + if (is_expand) { + col = context.AllocateTmpTensor(col_shape, dev_ctx); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + math::SetConstant set_zero; + + if (input_grad) { + input_grad->mutable_data(context.GetPlace()); + Tensor transformed_input_grad(input_grad->type()); + if (channel_last) { + ResizeToChannelFirst(context, input_grad, + &transformed_input_grad); + + } else { + ResizeToShareLast(context, input_grad, + &transformed_input_grad); + } + // if is_expand is false, the operation of set_zero is unnecessary, + // because math::matmul will reset input_grad. + if (is_expand) { + set_zero(dev_ctx, &transformed_input_grad, static_cast(0)); + } + math::Col2VolFunctor col2vol; + math::Col2ImFunctor col2im; + + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = + transformed_input_grad.Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter_.Slice(g * out_step, (g + 1) * out_step); + + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + + Tensor in_grad_slice_ = SwapedLeadingDims(context, &in_grad_slice); + if (!is_expand) { + col_matrix.ShareDataWith(in_grad_slice); + col_matrix.Resize(col_matrix_shape); + } + Tensor filter_slice_ = SwapedLeadingDims(context, &filter_slice); + Tensor out_grad_slice_ = SwapedLeadingDims(context, &out_grad_slice); + Tensor filter_slice_t = TransposeMpcMat(context, &filter_slice_); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( + &filter_slice_t, &out_grad_slice_, &col_matrix); + + if (is_expand && data_dim == 2U) { + SharesToCols(context, &col, dilations, strides, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, + &in_grad_slice_, col2im); + } else if (is_expand && data_dim == 3U) { + SharesToCols(context, &col, dilations, strides, paddings, &in_grad_slice_, col2vol); + } + TransToSwapedLeadingDims(context, &in_grad_slice_, + &in_grad_slice); + } + } + if (channel_last) { + TransToChannelLast(context, &transformed_input_grad, + input_grad); + } else { + TransToShareFirst(context, &transformed_input_grad, + input_grad); + } + } + + if (filter_grad) { + filter_grad->mutable_data(context.GetPlace()); + auto filter_grad_dims = filter_grad->dims(); + + Tensor 
filter_grad_ = SwapedLeadingDims(context, filter_grad); + filter_grad_.Resize(filter_matrix_shape_); + + set_zero(dev_ctx, filter_grad, static_cast(0)); + math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = transformed_input.Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + Tensor in_slice_ = SwapedLeadingDims(context, &in_slice); + if (!is_expand) { + col.ShareDataWith(in_slice_); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + SharesToCols(context, &in_slice_, dilations, strides, + std::vector{paddings[0], paddings[2], paddings[1], + paddings[3]}, &col, im2col); + + } else if (data_dim == 3U) { + SharesToCols(context, &in_slice_, dilations, strides, paddings, &col, vol2col); + } + + Tensor out_grad_slice_ = SwapedLeadingDims(context, &out_grad_slice); + Tensor col_mat_t = TransposeMpcMat(context, &col_matrix); + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_grad_slice_ = SwapedLeadingDims(context, &filter_grad_slice); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(&out_grad_slice_, &col_mat_t, &filter_grad_slice_); + TransToSwapedLeadingDims(context, &filter_grad_slice_, + &filter_grad_slice); + } + } + TransToSwapedLeadingDims(context, &filter_grad_, + filter_grad); + filter_grad->Resize(filter_grad_dims); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/concat_and_split.cc b/core/paddlefl_mpc/operators/math/concat_and_split.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ad2a07f5e8a2a93b3b3f7391048da171e90c69f --- /dev/null +++ b/core/paddlefl_mpc/operators/math/concat_and_split.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2020 paddlepaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "core/paddlefl_mpc/operators/math/concat_and_split.h" +#include + +namespace paddle { +namespace operators { +namespace math { + +/* + * All tensors' dimension should be the same and the values of + * each dimension must be the same, except the axis dimension. 
+ */ +template +class ConcatFunctor { +public: + void operator()(const platform::CPUDeviceContext& context, + const std::vector& input, int axis, + framework::Tensor* output) { + // TODO(zcd): Add input data validity checking + int num = input.size(); + + int rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (int i = 0; i < num; ++i) { + int t_cols = input[i].numel() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + auto cpu_place = boost::get(context.GetPlace()); + + // computation + auto output_data = output->data(); + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = input[j].data(); + for (int k = 0; k < out_rows; ++k) { + memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place, + input_data + k * col_len, sizeof(T) * col_len); + } + col_idx += col_len; + } + } +}; + +/* + * All tensors' dimension should be the same and the values of + * each dimension must be the same, except the axis dimension. + */ +template +class SplitFunctor { +public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, + const std::vector& ref_inputs, + const int axis, std::vector* outputs) { + // TODO(zcd): Add input data validity checking + size_t num = outputs->size(); + + int input_rows = 1; + auto dim_0 = ref_inputs[0]->dims(); + for (int i = 0; i < axis; ++i) { + input_rows *= dim_0[i]; + } + + int input_cols = 0; + + std::vector output_cols(outputs->size()); + for (size_t i = 0; i < num; ++i) { + int t_cols = ref_inputs[i]->numel() / input_rows; + input_cols += t_cols; + output_cols[i] = t_cols; + } + auto cpu_place = boost::get(context.GetPlace()); + + // computation + for (int k = 0; k < input_rows; ++k) { + const T* src_ptr = input.data() + k * input_cols; + int col_idx = 0; + for (size_t j = 0; j < num; ++j) { + int col_len = output_cols[j]; + auto* out_tensor = outputs->at(j); + if (out_tensor != nullptr) { + T* dst_ptr = out_tensor->data() + k * col_len; + memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx, + sizeof(T) * col_len); + } + col_idx += col_len; + } + } + } +}; +#define DEFINE_FUNCTOR(type) \ + template class ConcatFunctor; \ + template class SplitFunctor; + +FOR_ALL_TYPES(DEFINE_FUNCTOR); + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/concat_and_split.h b/core/paddlefl_mpc/operators/math/concat_and_split.h new file mode 100644 index 0000000000000000000000000000000000000000..2d73353f10cb44b3f210d6f86542e6528ec27e10 --- /dev/null +++ b/core/paddlefl_mpc/operators/math/concat_and_split.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * \brief Concatenate the input tensors along the dimension axis. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input[0] = [[1,2],[3,4]] + * Input[1] = [[5,6]] + * axis = 0 + * + * Output = [[1,2], + * [3,4], + * [5,6]] + */ +template +class ConcatFunctor { +public: + void operator()(const DeviceContext& context, + const std::vector& input, int axis, + framework::Tensor* output); +}; + +/* + * \brief Split the input tensors along the dimension axis into outputs. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input = [[1,2], + * [3,4], + * [5,6]] + * axis = 0 + * + * Output[0] = [[1,2],[3,4]] + * Output[1] = [[5,6]] + */ +template +class SplitFunctor { +public: + void operator()(const DeviceContext& context, const framework::Tensor& input, + const std::vector& ref_inputs, + int axis, std::vector* outputs); +}; + +} // namespace math +} // namespace operators +} // namespace paddle + +#define FOR_ALL_TYPES(macro) \ + macro(int64_t); \ + diff --git a/core/paddlefl_mpc/operators/math/im2col.cc b/core/paddlefl_mpc/operators/math/im2col.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca2ab1d617912b96b2d2d2e0cd3840cc71cbb8a2 --- /dev/null +++ b/core/paddlefl_mpc/operators/math/im2col.cc @@ -0,0 +1,284 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "./im2col.h" +#include +#include "./im2col_cfo_cpu.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * im = [input_channels, input_height, input_width] + * col = + * [input_channels, filter_height, filter_width, output_height, output_width] + */ +template +class Im2ColFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col, + const DataLayout data_layout) { + PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col->dims().size(), 5, + "The dimension of col should be 5."); + + if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 && + dilation[1] == 1) { + if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 && + padding[3] == 0) { + im2col_sh1sw1dh1dw1ph0pw0(im, col, data_layout); + return; + } else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 && + padding[3] == 1) { + im2col_sh1sw1dh1dw1ph1pw1(im, col, data_layout); + return; + } + // TODO(TJ): complete padding >=2 + } + im2col_common(im, dilation, stride, padding, col, data_layout); + } +}; + +/* + * im = [input_channels, input_height, input_width] + * col = + * [input_channels, filter_height, filter_width, output_height, output_width] + */ +template +class Col2ImFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im, + const DataLayout data_layout) { + PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col.dims().size(), 5, + "The dimension of col should be 5."); + int im_channels = + (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]); + int im_height = + (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]); + int im_width = + (data_layout != DataLayout::kNHWC ? 
im->dims()[2] : im->dims()[1]); + int filter_height = col.dims()[1]; + int filter_width = col.dims()[2]; + int col_height = col.dims()[3]; + int col_width = col.dims()[4]; + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + ((dilation[0] * (filter_height - 1) + 1))) / + stride[0] + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + ((dilation[1] * (filter_width - 1) + 1))) / + stride[1] + + 1, + col_width, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + + int channels_col = im_channels * filter_height * filter_width; + + T* im_data = im->data(); + const T* col_data = col.data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); + for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; + for (int w = 0; w < col_width; ++w) { + int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; + if ((im_row_idx) >= 0 && (im_row_idx) < im_height && + (im_col_idx) >= 0 && (im_col_idx) < im_width) { + int im_offset; + if (data_layout != DataLayout::kNHWC) { + im_offset = + (c_im * im_height + im_row_idx) * im_width + im_col_idx; + } else { + im_offset = + (im_row_idx * im_width + im_col_idx) * im_channels + c_im; + } + im_data[im_offset] += + col_data[(c * col_height + h) * col_width + w]; + } + } + } + } + } +}; + +template class Im2ColFunctor; +template class Col2ImFunctor; + +/* + * im = [input_channels, input_height, input_width] + * col = + * [output_height, output_width, input_channels, filter_height, filter_width] + */ +template +class Im2ColFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col, + const DataLayout data_layout) { + PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col->dims().size(), 5, + "The dimension of col should be 5."); + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; + int filter_height = col->dims()[3]; + int filter_width = col->dims()[4]; + int col_height = col->dims()[0]; + int col_width = col->dims()[1]; + + const T* im_data = im.data(); + T* col_data = col->data(); + + for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { + for (int channel = 0; channel < im_channels; ++channel) { + for (int filter_row_idx = 0; filter_row_idx < filter_height; + ++filter_row_idx) { + int im_row_offset = + col_row_idx * stride[0] + filter_row_idx - padding[0]; + for (int filter_col_idx = 0; filter_col_idx < filter_width; + ++filter_col_idx) { + int im_col_offset = + col_col_idx * stride[1] + filter_col_idx - padding[1]; + + int col_offset = + ((((col_row_idx)*col_width + col_col_idx) * im_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + + int im_offset = (channel * im_height + im_row_offset) * im_width + + im_col_offset; + col_data[col_offset] = + (im_row_offset < 0 || im_row_offset >= im_height || + im_col_offset < 0 || im_col_offset >= im_width) + ? 
static_cast(0) + : im_data[im_offset]; + } + } + } + } + } + } +}; + +/* + * im = [input_channels, input_height, input_width] + * col = + * [output_height, output_width, input_channels, filter_height, filter_width] + */ +template +class Col2ImFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im, + const DataLayout data_layout) { + PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3."); + PADDLE_ENFORCE_EQ(col.dims().size(), 5, + "The dimension of col should be 5."); + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; + int filter_height = col.dims()[3]; + int filter_width = col.dims()[4]; + int col_height = col.dims()[0]; + int col_width = col.dims()[1]; + + PADDLE_ENFORCE_EQ( + (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); + + T* im_data = im->data(); + const T* col_data = col.data(); + + for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { + for (int channel = 0; channel < im_channels; ++channel) { + for (int filter_row_idx = 0; filter_row_idx < filter_height; + ++filter_row_idx) { + int im_row_offset = + col_row_idx * stride[0] + filter_row_idx - padding[0]; + for (int filter_col_idx = 0; filter_col_idx < filter_width; + ++filter_col_idx) { + int im_col_offset = + col_col_idx * stride[1] + filter_col_idx - padding[1]; + + int col_offset = + (((col_row_idx * col_width + col_col_idx) * im_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + + if (im_row_offset >= 0 && im_row_offset < im_height && + im_col_offset >= 0 && im_col_offset < im_width) { + int im_offset = + (channel * im_height + im_row_offset) * im_width + + im_col_offset; + im_data[im_offset] += col_data[col_offset]; + } + } + } + } + } + } + } +}; + +template class Im2ColFunctor; +template class Col2ImFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/im2col.h b/core/paddlefl_mpc/operators/math/im2col.h new file mode 100644 index 0000000000000000000000000000000000000000..3865443170481de53ea4679d43075e14d386bb71 --- /dev/null +++ b/core/paddlefl_mpc/operators/math/im2col.h @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +using DataLayout = framework::DataLayout; + +/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */ +enum class ColFormat { kCFO = 0, kOCF = 1 }; + +/* + * \brief Converts the image data of three dimensions(CHW) into a colData of + * five dimensions in the Im2ColFunctor calculation, + * And in the Col2ImFunctor calculation, it is reversed. + * + * \param imData Image data. + * \param imShape The shape of imData, + * [input_channels, input_height, input_width]. + * \param colData Column data. + * \param colShape The shape of colData. + * + * \param dilations dilation data. + * \param 2-dimension [dilation_height, dilation_width]. + * + * \param strides stride data. + * \param 2-dimension [stride_height, stride_width]. + * + * \param paddings padding data. + * \param 4-dimension [up_pad, left_pad, down_pad, right_pad]. + * + * If the template argument Format is kCFO, the shape of colData is: + * [input_channels, filter_height, filter_width, output_height, output_width] + * So, it is easy to reshape into a convolution matrix for convolution + * calculation based on matrix multiplication. + * The shape of convolution matrix is [height, width], where the height is equal + * input_channels * filter_height * filter_width, and the width is equal + * output_height * output_width. + * + * Reshape: + * shape of colData shape of convolution matrix + * [input_channels, + * filter_height, + * filter_width, ======> [height, width] + * output_height, + * output_width] + * + * If the template argument Format is kOCF, the shape of colData is: + * [output_height, output_width, input_channels, filter_height, filter_width] + * So, it is easy to reshape into a sequence matrix for rnn calculation. + * The shape of sequence matrix is [seq_length, step_size], where the seq_length + * is equal output_height * output_width, and the step_size is equal + * input_channels * filter_height * filter_width. + * + * Reshape: + * shape of colData shape of sequence matrix + * [output_height, + * output_width, + * input_channels, ======> [seqLength, stepSize] + * filter_height, + * filter_width] + * + * \note The caller needs to ensure that imShape.inputChannels is equal to + * colShape.inputChannels. + */ +template +class Im2ColFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& im, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col, + const DataLayout data_layout = DataLayout::kNCHW); +}; + +template +class Col2ImFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im, + const DataLayout data_layout = DataLayout::kNCHW); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/im2col_cfo_cpu.h b/core/paddlefl_mpc/operators/math/im2col_cfo_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..01f1e220e65d9de00de5b82d2f7c278494be8f32 --- /dev/null +++ b/core/paddlefl_mpc/operators/math/im2col_cfo_cpu.h @@ -0,0 +1,317 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +/** + * The most common im2col algorithm. + * Support dilation, stride and padding. + */ +template +inline void im2col_common(const framework::Tensor& im, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, + framework::Tensor* col, + const DataLayout data_layout = DataLayout::kNCHW) { + int im_channels = + (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]); + int im_height = + (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]); + int im_width = + (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]); + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int output_height = col->dims()[3]; + int output_width = col->dims()[4]; + int channels_col = im_channels * filter_height * filter_width; + + const T* im_data = im.data(); + T* col_data = col->data(); + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); + for (int h = 0; h < output_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; + for (int w = 0; w < output_width; ++w) { + int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; + int im_idx; + if (data_layout != DataLayout::kNHWC) { + im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; + } else { + im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im; + } + int col_idx = (c * output_height + h) * output_width + w; + + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || + im_col_idx < 0 || im_col_idx >= im_width) + ? static_cast(0) + : im_data[im_idx]; + } + } + } +} + +/** + * im2col algorithm with strides == 1, dilations == 1, paddings == 0 + */ +template +inline void im2col_sh1sw1dh1dw1ph0pw0( + const framework::Tensor& im, framework::Tensor* col, + const DataLayout data_layout = DataLayout::kNCHW) { + int im_channels = + (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]); + int im_height = + (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]); + int im_width = + (data_layout != DataLayout::kNHWC ? 
im.dims()[2] : im.dims()[1]); + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int output_height = col->dims()[3]; + int output_width = col->dims()[4]; + + const T* im_data = im.data(); + T* col_data = col->data(); + int col_matrix_width = output_width * output_height; + int im_size = im_height * im_width; + size_t copy_size = sizeof(T) * output_width; + const T* im_data_oh = im_data; + T* dst_data_oh = col_data; + for (int oh = 0; oh < output_height; ++oh) { + const T* src_data_ic = im_data_oh; + T* dst_data = dst_data_oh; + for (int ic = 0; ic < im_channels; ++ic) { + const T* src_data = src_data_ic; + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + if (data_layout != DataLayout::kNHWC) { + std::memcpy(dst_data, src_data + kw, copy_size); + } else { + for (int kow = 0; kow < output_width; ++kow) { + dst_data[kow] = + im_data[((oh + kh) * im_width + kw + kow) * im_channels + ic]; + } + } + dst_data = dst_data + col_matrix_width; + } + src_data = src_data + im_width; + } + src_data_ic = src_data_ic + im_size; + } + im_data_oh = im_data_oh + im_width; + dst_data_oh = dst_data_oh + output_width; + } +} + +/** + * im2col algorithm with strides == 1, dilations == 1, paddings == 1 + * and filter_width == 1 have a special implementation + */ +template +inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, + framework::Tensor* col, + const DataLayout data_layout) { + int im_channels = + (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]); + int im_height = + (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]); + int im_width = + (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]); + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int output_height = col->dims()[3]; + int output_width = col->dims()[4]; + + constexpr int plh = 1; + constexpr int prh = 1; + constexpr int plw = 1; + constexpr int prw = 1; + + const T* im_data = im.data(); + T* col_data = col->data(); + int im_size = im_height * im_width; + int col_matrix_width = output_width * output_height; + int col_block_fh = filter_width * col_matrix_width; // fw*oh*ow + int col_block_ic = filter_height * col_block_fh; // fh*fw*oh*ow + + // fill height padding + { + size_t copy_size = sizeof(T) * output_width; + T* col_start_l = col_data; + T* col_start_r = col_data + (filter_height - 1) * col_block_fh + + col_matrix_width - output_width; + for (int ic = 0; ic < im_channels; ++ic) { + T* dst_data_l = col_start_l; + T* dst_data_r = col_start_r; + for (int kw = 0; kw < filter_width; ++kw) { + std::memset(dst_data_l, 0, copy_size); + std::memset(dst_data_r, 0, copy_size); + dst_data_l = dst_data_l + col_matrix_width; + dst_data_r = dst_data_r + col_matrix_width; + } + col_start_l = col_start_l + col_block_ic; + col_start_r = col_start_r + col_block_ic; + } + } + + auto pad = static_cast(0); + if (filter_width == 1) { + // fill width padding + T* dst_data_ic = col_data; + for (int ic = 0; ic < im_channels; ++ic) { + T* dst_data_kh = dst_data_ic; + for (int kh = 0; kh < filter_height; ++kh) { + T* dst_data = dst_data_kh; + for (int oh = 0; oh < output_height; ++oh) { + *dst_data = pad; + dst_data = dst_data + output_width - 1; + *dst_data = pad; + ++dst_data; + } + dst_data_kh = dst_data_kh + col_block_fh; + } + dst_data_ic = dst_data_ic + col_block_ic; + } + // fill core + size_t copy_size = sizeof(T) * (output_width - plw - prw); + for (int oh = 0; oh < output_height; ++oh) { + const T* im_data_start 
= + im_data + (oh - plh > 0 ? oh - plh : 0) * im_width; + T* dst_data = col_data + oh * output_width; + for (int ic = 0; ic < im_channels; ++ic) { + const T* src_data = im_data_start + ic * im_size; + for (int kh = 0; kh < filter_height; ++kh) { + if ((oh < plh && kh < plh) || (oh > (output_height - prh - 1) && + kh > (filter_height - prh - 1))) { + dst_data = dst_data + col_matrix_width; + continue; + } + if (data_layout != DataLayout::kNHWC) { + std::memcpy(dst_data + plw, src_data, copy_size); + } else { + for (int kow = 0; kow < output_width - plw - prw; ++kow) { + dst_data[plw + kow] = + im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + + kow) * + im_channels + + ic]; + } + } + dst_data = dst_data + col_matrix_width; + src_data = src_data + im_width; + } + } + } + return; + } + + // filter_width != 1 + // fill width padding + T* dst_data_ic = col_data; + for (int ic = 0; ic < im_channels; ++ic) { + T* dst_data_kh = dst_data_ic; + for (int kh = 0; kh < filter_height; ++kh) { + for (T* dst_data : + {dst_data_kh, dst_data_kh + (filter_width - prw) * col_matrix_width + + output_width - 1}) { + // TODO(TJ): from plh, saving repeated assignment + for (int oh = 0; oh < output_height; ++oh) { + *dst_data = pad; + dst_data = dst_data + output_width; + } + } + dst_data_kh = dst_data_kh + col_block_fh; + } + dst_data_ic = dst_data_ic + col_block_ic; + } + + // TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) * + // (output_width-1)} + // length of copy_size is equal kw. + for (int oh = 0; oh < output_height; ++oh) { + const T* im_data_start = im_data + (oh - plh > 0 ? oh - plh : 0) * im_width; + T* dst_data = col_data + oh * output_width; + for (int ic = 0; ic < im_channels; ++ic) { + const T* src_data = im_data_start + ic * im_size; + for (int kh = 0; kh < filter_height; ++kh) { + if ((oh < plh && kh < plh) || (oh > (output_height - prh - 1) && + kh > (filter_height - prh - 1))) { + dst_data = dst_data + filter_width * col_matrix_width; + continue; + } + // TODO(TJ): reuse plw-kw outside this for + // try to unify + for (int kw = 0; kw < plw; ++kw) { + if (data_layout != DataLayout::kNHWC) { + std::memcpy(dst_data + (plw - kw), src_data, + sizeof(T) * (output_width - (plw - kw))); + } else { + for (int kow = 0; kow < output_width - (plw - kw); ++kow) { + dst_data[plw - kw + kow] = + im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + + kow) * + im_channels + + ic]; + } + } + dst_data = dst_data + col_matrix_width; + } + for (int kw = plw; kw < filter_width - prw; ++kw) { + if (data_layout != DataLayout::kNHWC) { + std::memcpy(dst_data, src_data + (kw - plw), + sizeof(T) * output_width); + } else { + for (int kow = 0; kow < output_width; ++kow) { + dst_data[kow] = + im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + + kw - plw + kow) * + im_channels + + ic]; + } + } + dst_data = dst_data + col_matrix_width; + } + int i = 1; + for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) { + if (data_layout != DataLayout::kNHWC) { + std::memcpy(dst_data, src_data + (kw - plw), + sizeof(T) * (output_width - i)); + } else { + for (int kow = 0; kow < output_width - i; ++kow) { + dst_data[kow] = + im_data[(((oh - plh > 0 ? 
oh - plh : 0) + kh) * im_width + + kw - plw + kow) * + im_channels + + ic]; + } + } + dst_data = dst_data + col_matrix_width; + } + src_data = src_data + im_width; + } + } + } +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/math_function.cc b/core/paddlefl_mpc/operators/math/math_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..adf501c44a092dd3e12327623ef4bd535e1e767b --- /dev/null +++ b/core/paddlefl_mpc/operators/math/math_function.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "core/paddlefl_mpc/operators/math/math_function.h" + +#include +#include "paddle/fluid/framework/data_type.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct RowwiseAdd { + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& vector, framework::Tensor* output) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector.numel(), size); + PADDLE_ENFORCE_EQ(output->dims(), in_dims); + + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenVector::Flatten(vector); + auto out = framework::EigenMatrix::From(*output); + + for (int64_t i = 0; i < in_dims[0]; ++i) { + out.chip(i, 0) = in.chip(i, 0) + vec; + } + } +}; + +template struct RowwiseAdd; + +using float16 = paddle::platform::float16; + +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; + +#define DEFINE_CPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; + +DEFINE_CPU_TRANS(1); +DEFINE_CPU_TRANS(2); +DEFINE_CPU_TRANS(3); +DEFINE_CPU_TRANS(4); +DEFINE_CPU_TRANS(5); +DEFINE_CPU_TRANS(6); + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/math_function.h b/core/paddlefl_mpc/operators/math/math_function.h new file mode 100644 index 0000000000000000000000000000000000000000..adda0c6bb576a7b0b70f7e3dfe4ccb3401bd320a --- /dev/null +++ b/core/paddlefl_mpc/operators/math/math_function.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct RowwiseAdd { + void operator()(const DeviceContext& context, const framework::Tensor& input, + const framework::Tensor& vec, framework::Tensor* output); +}; + +template +struct SetConstant { + void operator()(const DeviceContext& context, framework::Tensor* tensor, + T num); +}; + +template +struct Transpose { + void operator()(const DeviceContext& context, const framework::Tensor& in, + framework::Tensor* out, const std::vector& axis); +}; + +template +struct ColwiseSum { + void operator()(const DeviceContext& context, const framework::Tensor& input, + framework::Tensor* vec); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/math_function_impl.h b/core/paddlefl_mpc/operators/math/math_function_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..51d41e8ea3cac423463f464b3f889bcc9226bfc1 --- /dev/null +++ b/core/paddlefl_mpc/operators/math/math_function_impl.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "paddle/fluid/framework/data_type.h" +#include "core/paddlefl_mpc/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template +void SetConstant::operator()(const DeviceContext& context, + framework::Tensor* tensor, + T num) { + auto t = framework::EigenVector::Flatten(*tensor); + t.device(*context.eigen_device()) = t.constant(static_cast(num)); +} + +template +void Transpose::operator()( + const DeviceContext& context, const framework::Tensor& in, + framework::Tensor* out, const std::vector& axis) { + Eigen::array permute; + for (int i = 0; i < Rank; i++) { + permute[i] = axis[i]; + } + auto eigen_in = framework::EigenTensor::From(in); + auto eigen_out = framework::EigenTensor::From(*out); + auto* dev = context.eigen_device(); + eigen_out.device(*dev) = eigen_in.shuffle(permute); +} + +template +void ColwiseSum::operator()(const DeviceContext& context, + const framework::Tensor& input, + framework::Tensor* out) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(out->numel(), size); + + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenVector::Flatten(*out); + + vec.device(*context.eigen_device()) = in.sum(Eigen::array({{0}})); +} + +// Specialize for CPU, since Eigen implement a general reduce. However, +// colwise-sum can be easily implemented. General reduce has a huge overhead in +// CPU +template +class ColwiseSum { +public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + auto& in_dims = input.dims(); + auto height = in_dims[0]; + auto size = in_dims[1]; + PADDLE_ENFORCE_EQ(out->numel(), size); + + T* out_buf = out->mutable_data(out->place()); + const T* in_buf = input.data(); + + for (size_t i = 0; i < static_cast(height); ++i) { + for (size_t j = 0; j < static_cast(size); ++j) { + if (i == 0) { + out_buf[j] = in_buf[i * size + j]; + } else { + out_buf[j] += in_buf[i * size + j]; + } + } + } + } +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/sequence2batch.cc b/core/paddlefl_mpc/operators/math/sequence2batch.cc new file mode 100644 index 0000000000000000000000000000000000000000..55b24aba94025b420f4a76f3bbe9618a378df8ee --- /dev/null +++ b/core/paddlefl_mpc/operators/math/sequence2batch.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "core/paddlefl_mpc/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class CopyMatrixRowsFunctor { +public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& src, + framework::Vector index_lod, framework::Tensor* dst, + bool is_src_index) { + size_t* index = index_lod.data(); + auto src_dims = src.dims(); + auto dst_dims = dst->dims(); + PADDLE_ENFORCE_EQ(src_dims.size(), 2UL, + "The src must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL, + "The dst must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], + "The width of src and dst must be same."); + auto height = dst_dims[0]; + auto width = dst_dims[1]; + auto* src_data = src.data(); + auto* dst_data = dst->data(); + const int sz = width * sizeof(T); + if (is_src_index) { + for (int i = 0; i < height; ++i) { + memcpy(dst_data + i * width, src_data + index[i] * width, sz); + } + } else { + for (int i = 0; i < height; ++i) { + memcpy(dst_data + index[i] * width, src_data + i * width, sz); + } + } + } +}; + +template class CopyMatrixRowsFunctor; + +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensorFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/sequence2batch.h b/core/paddlefl_mpc/operators/math/sequence2batch.h new file mode 100644 index 0000000000000000000000000000000000000000..3ba75ab28836eddf3df91f96318d70fc5a63c8b0 --- /dev/null +++ b/core/paddlefl_mpc/operators/math/sequence2batch.h @@ -0,0 +1,179 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +template +using EigenMatrix = framework::EigenMatrix; + +template +class CopyMatrixRowsFunctor { +public: + // If is_src_index is true, + // copy the indexed rows of input src to the output dst. + // If is_src_index is false, + // copy the input src to the indexed rows of output dst. + // The indexed rows are based on the input index. + void operator()(const DeviceContext& context, const framework::Tensor& src, + framework::Vector index_lod, framework::Tensor* dst, + bool is_src_index); +}; + +template +class LoDTensor2BatchFunctor { + // Calculate the length of each sequence and + // sort sequence index by the length. 
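The index bookkeeping in this functor is easiest to follow on plain vectors. The sketch below is illustrative only: the LoD offsets {0, 4, 9, 12} and the helper names are assumptions chosen to match the worked example in the comments that follow, and it covers the non-reversed case.

#include <algorithm>
#include <cstddef>
#include <vector>

struct Seq { std::size_t start, length, idx; };

// Sketch: rebuild batch_starts and seq2batch_idx from one-level LoD offsets,
// exactly as the functor below does (longest sequence first, one time step
// of every still-active sequence per batch).
void ToyLoDToBatch(const std::vector<std::size_t>& lod,        // {0, 4, 9, 12}
                   std::vector<std::size_t>* batch_starts,     // -> {0, 3, 6, 9, 11, 12}
                   std::vector<std::size_t>* seq2batch_idx) {  // -> {4,0,9, 5,1,10, 6,2,11, 7,3, 8}
  std::vector<Seq> seqs;
  for (std::size_t i = 0; i + 1 < lod.size(); ++i)
    seqs.push_back({lod[i], lod[i + 1] - lod[i], i});
  std::sort(seqs.begin(), seqs.end(),
            [](const Seq& a, const Seq& b) { return a.length > b.length; });
  const std::size_t max_len = seqs.front().length;
  batch_starts->assign(max_len + 1, 0);
  seq2batch_idx->assign(lod.back(), 0);
  for (std::size_t n = 0; n < max_len; ++n) {
    std::size_t batch_id = (*batch_starts)[n];
    for (const Seq& s : seqs) {
      if (n >= s.length) break;                  // seqs is sorted by length
      (*seq2batch_idx)[batch_id++] = s.start + n;
    }
    (*batch_starts)[n + 1] = batch_id;
  }
}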
+ // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} + // + struct SeqInfo { + SeqInfo(size_t start, size_t length, size_t seq_idx) + : start(start), length(length), seq_idx(seq_idx) {} + size_t start; + size_t length; + size_t seq_idx; + }; + +public: + void operator()(const DeviceContext& context, + const framework::LoDTensor& lod_tensor, + framework::LoDTensor* batch, bool is_cal_batch_lod, + bool is_reverse = false) const { + if (!is_cal_batch_lod) { + auto lods = batch->lod(); + PADDLE_ENFORCE_GT(lods.size(), 2UL, + "The LoD of LoDTensor should inlcude at least 2-level " + "sequence information."); + PADDLE_ENFORCE_EQ( + lods[1].size(), static_cast(lod_tensor.dims()[0]), + "The LoD information should be consistent with the dims."); + CopyMatrixRowsFunctor to_batch; + to_batch(context, lod_tensor, lods[1], batch, true); + return; + } + + auto lods = lod_tensor.lod(); + PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); + + const auto& lod = lods[0]; + + std::vector seq_info; + for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { + size_t length = lod[seq_id + 1] - lod[seq_id]; + seq_info.emplace_back(lod[seq_id], length, seq_id); + } + + std::sort(seq_info.begin(), seq_info.end(), + [](SeqInfo a, SeqInfo b) { + return a.length > b.length; + }); + + // Calculate the start position of each batch. + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // max_seqlen = 5, + // batchIndex = {b0, b1, b2, b3, b4} + // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 + // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} + // batch_start_positions[0] = len(b0) + // batch_start_positions[1] = len(b0) + len(b1) + // batch_start_positions[2] = len(b0) + len(b1) + len(b2) + // ... + // seq2batch_idx[12] = {4, 0, 9, + // 5, 1, 10, + // 6, 2, 11, + // 7, 3, + // 8} + // seq_order = {1, 0, 2}, the sort order. + // where 1 is the second sequence, + // 0 is the first sequence, + // 2 is the third sequence. + // The max_seqlen represents batch size after rearranging the + // input LodTensor. It is also the maximum length of input sequence. + + paddle::framework::LoD batch_lods; + batch_lods.emplace_back(std::vector {0}); + batch_lods.emplace_back(std::vector {0}); + batch_lods.emplace_back(std::vector {0}); + + // batch_lods[0] is the start positions for batch LoDTensor + size_t max_seqlen = seq_info[0].length; + batch_lods[0].resize(max_seqlen + 1); + // batch_lods[1] is the raw index in the input LoDTensor + batch_lods[1].resize(static_cast(lod_tensor.dims()[0])); + // batch_lods[2] is the sort order for the input LoDTensor. + batch_lods[2].resize(seq_info.size()); + + size_t* batch_starts = batch_lods[0].data(); + size_t* seq2batch_idx = batch_lods[1].data(); + batch_starts[0] = 0; + for (size_t n = 0; n < max_seqlen; n++) { + size_t batch_id = batch_starts[n]; + for (size_t i = 0; i < seq_info.size(); ++i) { + size_t seq_len = seq_info[i].length; + size_t start = seq_info[i].start; + if (n < seq_len) { + seq2batch_idx[batch_id] = + is_reverse ? 
start + seq_len - 1 - n : start + n; + batch_id++; + } else { + break; + } + } + batch_starts[n + 1] = batch_id; + } + size_t* seq_order = batch_lods[2].data(); + for (size_t i = 0; i < seq_info.size(); ++i) { + seq_order[i] = seq_info[i].seq_idx; + } + batch->set_lod(batch_lods); + + CopyMatrixRowsFunctor to_batch; + to_batch(context, lod_tensor, batch_lods[1], batch, true); + } +}; + +template +class Batch2LoDTensorFunctor { +public: + void operator()(const DeviceContext& context, + const framework::LoDTensor& batch, + framework::LoDTensor* lod_tensor) const { + auto in_lod = batch.lod(); + PADDLE_ENFORCE_GT(in_lod.size(), 2UL, + "The LoD of LoDTensor should inlcude at least 2-level " + "sequence information."); + PADDLE_ENFORCE_EQ( + in_lod[1].size(), static_cast(lod_tensor->dims()[0]), + "The LoD information should be consistent with the dims."); + CopyMatrixRowsFunctor to_seq; + to_seq(context, batch, in_lod[1], lod_tensor, false); + } +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/vol2col.cc b/core/paddlefl_mpc/operators/math/vol2col.cc new file mode 100644 index 0000000000000000000000000000000000000000..becc1b52793c2deda2e7261b4476f06d31afbc81 --- /dev/null +++ b/core/paddlefl_mpc/operators/math/vol2col.cc @@ -0,0 +1,239 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "./vol2col.h" +#include + +namespace paddle { +namespace operators { +namespace math { + +/* + * vol = [input_channels, input_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Vol2ColFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* col, + const DataLayout data_layout) const { + PADDLE_ENFORCE_EQ(vol.dims().size(), 4, + "The dimension of vol should be 4."); + PADDLE_ENFORCE_EQ(col->dims().size(), 7, + "The dimension of col should be 7."); + + int input_channels = + (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]); + int input_depth = + (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]); + int input_height = + (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]); + int input_width = + (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]); + int filter_depth = col->dims()[1]; + int filter_height = col->dims()[2]; + int filter_width = col->dims()[3]; + int output_depth = col->dims()[4]; + int output_height = col->dims()[5]; + int output_width = col->dims()[6]; + int channels_col = + input_channels * filter_depth * filter_height * filter_width; + + // changed + bool paddings_size_is_6 = (paddings.size() == 6); + int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; + int pad_d_back = paddings_size_is_6 ? 
paddings[1] : paddings[0]; + int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; + int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; + int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; + int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + + PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + + 1, + output_depth, + "input_depth and output_depth are " + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + + 1, + output_height, + "input_height and output_height are " + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + + 1, + output_width, + "input_width and output_width are " + "mismatching."); + const T* vol_data = vol.data(); + T* col_data = col->data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int d_offset = (c / filter_width / filter_height) % filter_depth; + int c_in = c / filter_width / filter_height / filter_depth; + for (int d = 0; d < output_depth; ++d) { + int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0]; + for (int h = 0; h < output_height; ++h) { + int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1]; + for (int w = 0; w < output_width; ++w) { + int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2]; + + int col_idx = + ((c * output_depth + d) * output_height + h) * output_width + w; + int vol_idx; + if (data_layout != DataLayout::kNHWC) { + vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + } else { + vol_idx = ((d_pad * input_height + h_pad) * input_width + w_pad) * + input_channels + + c_in; + } + col_data[col_idx] = + (h_pad < 0 || h_pad >= input_height || w_pad < 0 || + w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) + ? static_cast(0) + : vol_data[vol_idx]; + } + } + } + } + } +}; + +/* + * vol = [input_channels,input_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Col2VolFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* vol, + const DataLayout data_layout) const { + PADDLE_ENFORCE_EQ(vol->dims().size(), 4, + "The dimension of vol should be 4."); + PADDLE_ENFORCE_EQ(col.dims().size(), 7, + "The dimension of col should be 7."); + + int input_channels = + (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]); + int input_depth = + (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]); + int input_height = + (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]); + int input_width = + (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]); + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + int channels_col = + input_channels * filter_depth * filter_height * filter_width; + + bool paddings_size_is_6 = (paddings.size() == 6); + int pad_d_forth = paddings_size_is_6 ? 
paddings[0] : paddings[0]; + int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; + int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; + int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; + int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; + int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + + PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + + 1, + output_depth, + "input_depth and output_depth are " + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + pad_h_up + pad_h_down - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + + 1, + output_height, + "input_height and output_height are " + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + pad_w_left + pad_w_right - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + + 1, + output_width, + "input_width and output_width are " + "mismatching."); + T* vol_data = vol->data(); + const T* col_data = col.data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int d_offset = (c / filter_width / filter_height) % filter_depth; + int cIm = c / filter_width / filter_height / filter_depth; + for (int d = 0; d < output_depth; ++d) { + int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0]; + for (int h = 0; h < output_height; ++h) { + int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1]; + for (int w = 0; w < output_width; ++w) { + int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2]; + + if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && + w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { + int vol_idx; + if (data_layout != DataLayout::kNHWC) { + vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + } else { + vol_idx = + ((d_pad * input_height + h_pad) * input_width + w_pad) * + input_channels + + cIm; + } + int col_idx = + ((c * output_depth + d) * output_height + h) * output_width + + w; + vol_data[vol_idx] += col_data[col_idx]; + } + } + } + } + } + } +}; + +template class Vol2ColFunctor; +template class Col2VolFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/math/vol2col.h b/core/paddlefl_mpc/operators/math/vol2col.h new file mode 100644 index 0000000000000000000000000000000000000000..3122828b2eeba5fb1428235dd3a5f926705bd78e --- /dev/null +++ b/core/paddlefl_mpc/operators/math/vol2col.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
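As a concrete check on the output-extent arithmetic enforced above (the numbers are illustrative and the helper name is an assumption, not part of the patch): with 2 input channels, a 4x4x4 volume, a 2x2x2 filter, unit stride and dilation, and no padding, every output extent is (4 - 2) / 1 + 1 = 3, so colData has shape [2, 2, 2, 2, 3, 3, 3] and reshapes into a [16, 27] convolution matrix.

#include <cassert>

// Sketch: output extent along one axis for the vol2col sliding window.
inline int OutExtent(int in, int filter, int dilation, int pad_lo, int pad_hi,
                     int stride) {
  return (in + pad_lo + pad_hi - (dilation * (filter - 1) + 1)) / stride + 1;
}

void CheckVol2ColShapes() {
  const int c = 2, d = 4, h = 4, w = 4;   // input volume, CDHW
  const int fd = 2, fh = 2, fw = 2;       // filter extents
  const int od = OutExtent(d, fd, 1, 0, 0, 1);
  const int oh = OutExtent(h, fh, 1, 0, 0, 1);
  const int ow = OutExtent(w, fw, 1, 0, 0, 1);
  assert(od == 3 && oh == 3 && ow == 3);
  // colData: [c, fd, fh, fw, od, oh, ow]  ->  conv matrix [c*fd*fh*fw, od*oh*ow]
  assert(c * fd * fh * fw == 16 && od * oh * ow == 27);
}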
*/ + +#pragma once + +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +using DataLayout = framework::DataLayout; + +/* + * \brief Converts the feature data of four dimensions(CDHW) into a colData of + * seven dimensions in the Vol2ColFunctor calculation, + * And in the Col2VolFunctor calculation, it is reversed. + * + * \param volData Vol data. + * \param volShape The shape of volData, + * [input_channels, input_depth, input_height, input_width]. + * \param colData Column data. + * \param colShape The shape of colData. + * + * \param dilations dilation data. + * \param 3-dimension [dilation_depth, dilation_height, dilation_width]. + * + * \param strides stride data. + * \param 3-dimension [stride_depth, stride_height, stride_width]. + * + * \param paddings padding data. + * \param 3-dimension [d_pad, h_pad, w_pad]. + * + * The shape of colData is: + * [input_channels, filter_depth, filter_height, filter_width, output_depth, + * output_height, output_width] + * So, it is easy to reshape into a convolution matrix for convolution + * calculation based on matrix multiplication. + * The shape of convolution matrix is [height, width], where the height is equal + * input_channels * filter_depth * filter_height * filter_width, and the width + * is equal output_depth * output_height * output_width. + * + * Reshape: + * shape of colData shape of convolution matrix + * [input_channels, + * filter_depth, + * filter_height, + * filter_width, ======> [height, width] + * output_depth, + * output_height, + * output_width] + * + * \note The caller needs to ensure that volShape.inputChannels is equal to + * colShape.inputChannels. + */ +template +class Vol2ColFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* col, + const DataLayout data_layout = DataLayout::kNCHW) const; +}; + +template +class Col2VolFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* vol, + const DataLayout data_layout = DataLayout::kNCHW) const; +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/mpc_adam_op.cc b/core/paddlefl_mpc/operators/mpc_adam_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..420958f0e6a07f76d49d2ee5a2a2da8b38e83457 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_adam_op.cc @@ -0,0 +1,220 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "mpc_adam_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class MpcAdamOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override; +}; + +void MpcAdamOp::InferShape(framework::InferShapeContext *ctx) const { + PADDLE_ENFORCE_EQ( + ctx->HasInput("Param"), true, + platform::errors::NotFound("Input(Param) of AdamOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Grad"), true, + platform::errors::NotFound("Input(Grad) of AdamOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Moment1"), true, + platform::errors::NotFound( + "Input(Moment1) of AdamOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Moment2"), true, + platform::errors::NotFound( + "Input(Moment2) of AdamOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true, + platform::errors::NotFound( + "Input(LearningRate) of AdamOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Beta1Pow"), true, + platform::errors::NotFound( + "Input(Beta1Pow) of AdamOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Beta2Pow"), true, + platform::errors::NotFound( + "Input(Beta2Pow) of AdamOp should not be null.")); + + PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true, + platform::errors::NotFound( + "Output(ParamOut) of AdamOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment1Out"), true, + platform::errors::NotFound( + "Output(Moment1Out) of AdamOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment2Out"), true, + platform::errors::NotFound( + "Output(Moment2Out) of AdamOp should not be null.")); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_NE( + framework::product(lr_dims), 0, + platform::errors::InvalidArgument( + "The number of LearningRate shall not be 0, but received %d. Maybe " + "the Input variable LearningRate has not " + "been initialized. 
You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.", + framework::product(lr_dims))); + PADDLE_ENFORCE_EQ( + framework::product(lr_dims), 1, + platform::errors::InvalidArgument( + "Learning rate should have 1 dimension, but received %d", + framework::product(lr_dims))); + auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); + VLOG(3) << "dims of Beta1Pow : [" << beta1_pow_dims << "]"; + PADDLE_ENFORCE_GE(framework::product(beta1_pow_dims), 1, + platform::errors::InvalidArgument( + "The size of Beta1 power accumulator should be greater " + "than 0, but received %d.", + framework::product(beta1_pow_dims))); + auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow"); + VLOG(3) << "dims of Beta2Pow : [" << beta2_pow_dims << "]"; + PADDLE_ENFORCE_GE(framework::product(beta2_pow_dims), 1, + platform::errors::InvalidArgument( + "The size of Beta2 power accumulator should be greater " + "than 0, but received %d.", + framework::product(beta2_pow_dims))); + + auto param_dims = ctx->GetInputDim("Param"); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Grad"), + platform::errors::InvalidArgument( + "Param and Grad input of AdamOp should have same dimension. But " + "received Param dims: [%s], Grad dims: [%s].", + param_dims, ctx->GetInputDim("Grad"))); + } + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment1"), + platform::errors::InvalidArgument( + "Param and Moment1 input of AdamOp should have same dimension. But " + "received Param dims: [%s], Moment1 dims: [%s].", + param_dims, ctx->GetInputDim("Moment1"))); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment2"), + platform::errors::InvalidArgument( + "Param and Moment2 input of AdamOp should have same dimension. 
But " + "received Param dims: [%s], Moment2 dims: [%s].", + param_dims, ctx->GetInputDim("Moment2"))); + + ctx->SetOutputDim("ParamOut", param_dims); + ctx->SetOutputDim("Moment1Out", param_dims); + ctx->SetOutputDim("Moment2Out", param_dims); + ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims); + ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims); +} + +framework::OpKernelType MpcAdamOp::GetExpectedKernelType( + const framework::ExecutionContext &ctx) const { + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); +} + +framework::OpKernelType MpcAdamOp::GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const { + if (var_name == "Beta1Pow" || var_name == "Beta2Pow") { + return expected_kernel_type; + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +} + +class MpcAdamOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", "(Tensor) Input parameter"); + AddInput("Grad", "(Tensor) Input gradient"); + AddInput("LearningRate", "(Tensor) Learning rate"); + AddInput("Moment1", "(Tensor) Input first moment"); + AddInput("Moment2", "(Tensor) Input second moment"); + AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); + AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator"); + + AddInput("Beta1Tensor", + "(Tensor, optional) If provided, Adam will use this " + "as beta1, this has a higher priority than attr(beta1), the " + "shape of this tensor MUST BE [1].") + .AsDispensable(); + AddInput("Beta2Tensor", + "(Tensor, optional) If provided, Adam will use this " + "as beta2, this has a higher priority than attr(beta2), the " + "shape of this tensor MUST BE [1].") + .AsDispensable(); + + AddOutput("ParamOut", "(Tensor) Output parameter"); + AddOutput("Moment1Out", "(Tensor) Output first moment"); + AddOutput("Moment2Out", "(Tensor) Output second moment"); + AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator"); + AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator"); + + AddAttr("beta1", + "(float, default 0.9) " + "Exponential decay rate for the " + "first moment estimates.") + .SetDefault(0.9f); + AddAttr("beta2", + "(float, default 0.999) " + "exponential decay rate for the " + "second moment estimates.") + .SetDefault(0.999f); + AddAttr("epsilon", + "(float, default 1.0e-4) " + "Constant for numerical stability") + .SetDefault(1.0e-4f); + + AddComment(R"DOC( +Adam Optimizer. + +This implements the Adam optimizer from Section 2 of the Adam +paper : https://arxiv.org/abs/1412.6980. +Adam is a first-order gradient-based optimization method based on +adaptive estimates of lower-order moments. 
+ +Adam updates: + +$$ +moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\ +moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\ +learning\_rate = learning\_rate * + \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\ +param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon} +$$ + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + mpc_adam, ops::MpcAdamOp, ops::MpcAdamOpMaker); + +REGISTER_OP_CPU_KERNEL( + mpc_adam, + ops::MpcAdamOpKernel); diff --git a/core/paddlefl_mpc/operators/mpc_adam_op.h b/core/paddlefl_mpc/operators/mpc_adam_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e1ea7930ef50c0680990083bf8421257f8d636a0 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_adam_op.h @@ -0,0 +1,158 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "mpc_op.h" + +#include + +#include "./math/math_function.h" +#include "core/paddlefl_mpc/mpc_protocol/aby3_operators.h" + +namespace paddle { +namespace operators { + +static inline float GetAttrFromTensor(const framework::Tensor* tensor) { + const float* tensor_data = tensor->data(); + framework::Tensor cpu_tensor; + return tensor_data[0]; +} + +template +class MpcAdamOpKernel : public MpcOpKernel { + public: + void ComputeImpl(const framework::ExecutionContext &ctx) const override{ + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.InputNames("Param").front(), + framework::ToTypeName(param_var->Type())); + + using paddle::framework::LoDTensor; + + T1 epsilon = static_cast(ctx.Attr("epsilon")); + auto* param = ctx.Input("Param"); + auto* grad_var = ctx.InputVar("Grad"); + auto* mom1 = ctx.Input("Moment1"); + auto* mom2 = ctx.Input("Moment2"); + auto* lr = ctx.Input("LearningRate"); + + auto* beta1_pow = ctx.Input("Beta1Pow"); + auto* beta2_pow = ctx.Input("Beta2Pow"); + + auto* param_out = ctx.Output("ParamOut"); + auto* mom1_out = ctx.Output("Moment1Out"); + auto* mom2_out = ctx.Output("Moment2Out"); + auto* beta1_pow_out = ctx.Output("Beta1PowOut"); + auto* beta2_pow_out = ctx.Output("Beta2PowOut"); + + T1 beta1 = static_cast(ctx.Attr("beta1")); + if (ctx.HasInput("Beta1Tensor")) { + auto* beta1_tensor = ctx.Input("Beta1Tensor"); + PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(Beta1Tensor) size must be 1, but get %d", + beta1_tensor->numel())); + beta1 = static_cast(GetAttrFromTensor(beta1_tensor)); + } + T1 beta2 = static_cast(ctx.Attr("beta2")); + if (ctx.HasInput("Beta2Tensor")) { + auto* beta2_tensor = ctx.Input("Beta2Tensor"); + PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(Beta2Tensor) size must be 1, but get %d", + beta2_tensor->numel())); + beta2 = static_cast(GetAttrFromTensor(beta2_tensor)); + } + 
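The secret-shared calls below carry out a standard Adam step. As a plaintext reference (a sketch only: the vector types, the loop, and the placement of epsilon inside the square root are assumptions that mirror this kernel rather than a canonical API):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch: plaintext Adam update that the mpc_operators calls below mirror
// on secret-shared fixed-point tensors.
void AdamStepPlain(std::vector<double>* param, const std::vector<double>& grad,
                   std::vector<double>* m1, std::vector<double>* m2,
                   double lr, double beta1, double beta2, double epsilon,
                   double beta1_pow, double beta2_pow) {
  // Bias-corrected learning rate, as in lr_ below.
  const double lr_t = lr * std::sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow);
  for (std::size_t i = 0; i < param->size(); ++i) {
    (*m1)[i] = beta1 * (*m1)[i] + (1.0 - beta1) * grad[i];
    (*m2)[i] = beta2 * (*m2)[i] + (1.0 - beta2) * grad[i] * grad[i];
    // Unlike the op comment, epsilon is added before the square root here,
    // matching the inverse_square_root(mom2_out + epsilon) calls below.
    (*param)[i] -= lr_t * (*m1)[i] / std::sqrt((*m2)[i] + epsilon);
  }
}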
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel() + << "beta2_pow.numel() : " << beta2_pow->numel(); + VLOG(3) << "param.numel(): " << param->numel(); + + PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1, + platform::errors::InvalidArgument( + "beta1 pow output size should be 1, but received " + "value is:%d.", + beta1_pow_out->numel())); + + PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1, + platform::errors::InvalidArgument( + "beta2 pow output size should be 1, but received " + "value is:%d.", + beta2_pow_out->numel())); + + if (grad_var->IsType()) { + auto* grad = ctx.Input("Grad"); + + // AdamFunctor functor( + // beta1, beta2, epsilon, beta1_pow->data(), beta2_pow->data(), + // mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), + // mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), + // lr->data(), grad->data(), param->data(), + // param_out->mutable_data(ctx.GetPlace())); + // functor(param->numel()); + + T1 lr_value = *lr->template data(); + + T1 beta1_pow_ = *beta1_pow->template data(); + T1 beta2_pow_ = *beta2_pow->template data(); + + double lr_ = lr_value * sqrt(1 - beta2_pow_) / (1 - beta1_pow_); + + framework::Tensor temp; + temp.mutable_data(param->dims(), ctx.GetPlace()); + + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, (1 - beta1), &temp); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(mom1, beta1, mom1_out); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(mom1_out, &temp, mom1_out); + + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, (1 - beta2), &temp); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(grad, &temp, &temp); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(mom2, beta2, mom2_out); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(mom2_out, &temp, mom2_out); + + // mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, lr[0], &temp); + // mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(param, &temp, param_out); + + math::SetConstant set_const; + auto& dev_ctx = ctx.template device_context(); + set_const( + dev_ctx, + &temp, + T(epsilon * pow(2, mpc::ABY3_SCALING_FACTOR) / 3)); + + // temp = epsilon + mom2_out + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(mom2_out, &temp, &temp); + // temp = 1 / sqrt(epsilon + mom2_out) + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->inverse_square_root(&temp, &temp); + // temp = mom1_out / sqrt(epsilon + mom2_out) + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(mom1_out, &temp, &temp); + // temp = lr * mom1_out / sqrt(epsilon + mom2_out) + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(&temp, lr_, &temp); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(param, &temp, param_out); + + beta1_pow_out->mutable_data(ctx.GetPlace())[0] = + beta1 * beta1_pow->template data()[0]; + beta2_pow_out->mutable_data(ctx.GetPlace())[0] = + beta2 * beta2_pow->template data()[0]; + + } else { + PADDLE_THROW("Variable type not supported by adam_op"); + } + + } +}; +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/mpc_batch_norm_op.cc b/core/paddlefl_mpc/operators/mpc_batch_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2be7506852bd91b095866bc295c450a6b0a3d1ba --- /dev/null +++ 
b/core/paddlefl_mpc/operators/mpc_batch_norm_op.cc @@ -0,0 +1,384 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/data_layout.h" +#include +#include +#include +#include "mpc_batch_norm_op.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +class MpcBatchNormOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override{ + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNorm"); + OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "BatchNorm"); + OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "BatchNorm"); + OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "BatchNorm"); + OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance", "BatchNorm"); + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BatchNorm"); + + bool is_test = ctx->Attrs().Get("is_test"); + bool trainable_stats = ctx->Attrs().Get("trainable_statistics"); + bool test_mode = is_test && (!trainable_stats); + if (!test_mode) { + OP_INOUT_CHECK(ctx->HasOutput("MeanOut"), "Output", "MeanOut", "BatchNorm"); + OP_INOUT_CHECK(ctx->HasOutput("VarianceOut"), "Output", "VarianceOut", + "BatchNorm"); + OP_INOUT_CHECK(ctx->HasOutput("SavedMean"), "Output", "SavedMean", + "BatchNorm"); + OP_INOUT_CHECK(ctx->HasOutput("SavedVariance"), "Output", "SavedVariance", + "BatchNorm"); + } + + // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python + PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0], + platform::errors::InvalidArgument( + "Mean and MeanOut should share the same memory")); + + PADDLE_ENFORCE_EQ( + ctx->Inputs("Variance")[0], ctx->Outputs("VarianceOut")[0], + platform::errors::InvalidArgument( + "Variance and VarianceOut should share the same memory")); + + const auto x_dims = ctx->GetInputDim("X"); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + + if (ctx->IsRuntime() && ctx->HasInput("MomentumTensor")) { + auto mom = ctx->Inputs("MomentumTensor"); + PADDLE_ENFORCE_EQ(mom.size(), 1, + platform::errors::InvalidArgument( + "The input tensor MomentumTensor's size must be 1" + "But received: MomentumTensor's size is [%d]", + mom.size())); + } + + PADDLE_ENFORCE_GE( + x_dims.size(), 3, + platform::errors::InvalidArgument( + "ShapeError: the dimension of input " + "X must greater than or equal to 3. But received: the shape of input " + "X = [%s], the dimension of input X =[%d]", + x_dims, x_dims.size())); + + PADDLE_ENFORCE_LE( + x_dims.size(), 6, + platform::errors::InvalidArgument( + "ShapeError: the dimension of input X " + "must smaller than or equal to 6. 
But received: the shape of input X " + "= [%s], the dimension of input X = [%d]", + x_dims, x_dims.size())); + + + const int64_t C = + ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW) + ? x_dims[2] + : x_dims[x_dims.size() - 1]); + + auto scale_dim = ctx->GetInputDim("Scale"); + auto bias_dim = ctx->GetInputDim("Bias"); + VLOG(3) << "*** scale_dims: " << scale_dim; + VLOG(3) << "*** bias_dims: " << bias_dim; + VLOG(3) << "*** mean_dims: " << ctx->GetInputDim("Mean"); + VLOG(3) << "*** variance_dims: " << ctx->GetInputDim("Variance"); + //VLOG(3) << "*** Y_dims: " << ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_EQ( + scale_dim.size(), 2UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of scale must equal to 2." + "But received: the shape of scale is [%s], the dimension " + "of scale is [%d]", + scale_dim, scale_dim.size())); + PADDLE_ENFORCE_EQ(bias_dim.size(), 2UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of bias must equal to 2." + "But received: the shape of bias is [%s],the dimension " + "of bias is [%d]", + bias_dim, bias_dim.size())); + + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 || + framework::product(bias_dim) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(scale_dim[1], C, + platform::errors::InvalidArgument( + "ShapeError: the shape of scale must equal to [%d]" + "But received: the shape of scale is [%d]", + C, scale_dim[1])); + PADDLE_ENFORCE_EQ(bias_dim[1], C, + platform::errors::InvalidArgument( + "ShapeError: the shape of bias must equal to [%d]" + "But received: the shape of bias is [%d]", + C, bias_dim[1])); + } + ctx->SetOutputDim("Y", x_dims); + ctx->SetOutputDim("MeanOut", {2, C}); // 2: share_num + ctx->SetOutputDim("VarianceOut", {2, C}); + ctx->SetOutputDim("SavedMean", {2, C}); + ctx->SetOutputDim("SavedVariance", {2, C}); + ctx->ShareLoD("X", "Y"); + } + +protected: + framework::OpKernelType GetExpectedKernelType(const framework::ExecutionContext& ctx) const { + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = "AnyLayout"; + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + + +class MpcBatchNormGradOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override{ + // check input + OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "BatchNormGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input", + framework::GradVarName("Y"), "BatchNormGrad"); + OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean", + "BatchNormGrad"); + OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance", + "BatchNormGrad"); + + // check output + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "BatchNormGrad"); + + const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale")); + const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias")); + + 
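The extra leading dimension running through all of these shape checks is the share dimension (S = 2, the two replicated shares each party holds in the ABY3-based protocol): a plaintext batch-norm input of [N, C, H, W] is stored here as [S, N, C, H, W], so the channel count is read from x_dims[2] for NCHW data, and Scale, Bias, Mean and Variance become [2, C] rather than [C]. A minimal sketch of that bookkeeping (the helper name and the example sizes are assumptions):

#include <cstdint>
#include "paddle/fluid/framework/ddim.h"

// Sketch: channel extent of a share-carrying NCHW tensor laid out as
// [S, N, C, H, W], where S is the share dimension.
inline int64_t MpcNchwChannels(const paddle::framework::DDim& x_dims) {
  return x_dims[2];
}

// Example: a plaintext [N=8, C=3, H=32, W=32] batch is shared as
// [2, 8, 3, 32, 32], so MpcNchwChannels(framework::make_ddim({2, 8, 3, 32, 32}))
// returns 3, and the per-channel statistics tensors have shape [2, 3].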
PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true, + platform::errors::NotFound( + "Output(Scale@GRAD) and Output(Bias@GRAD) must be null " + "or not be null at same time. But now, " + "has Scale@Grad=[%d], has Bias@GRAD=[%d]", + has_scale_grad, has_bias_grad)); + + const bool use_global_stats = ctx->Attrs().Get("use_global_stats"); + if (use_global_stats) { + PADDLE_ENFORCE_EQ( + !ctx->Attrs().Get("use_mkldnn"), true, + platform::errors::InvalidArgument( + "Using global stats during training is not supported " + "in gradient op kernel of batch_norm_mkldnn_op now.")); + } + + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormGrad"); + const auto x_dims = ctx->GetInputDim("X"); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + + const int C = + ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW) + ? x_dims[2] + : x_dims[x_dims.size() - 1]); + + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + // has_scale_grad == has_bias_grad, judge has_scale_grad is enough + if (has_scale_grad) { + ctx->SetOutputDim(framework::GradVarName("Scale"), {2, C}); // 2: share_num + ctx->SetOutputDim(framework::GradVarName("Bias"), {2, C}); + } + } + +protected: + framework::OpKernelType GetExpectedKernelType(const framework::ExecutionContext& ctx) const { + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = "AnyLayout"; + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, library_); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + + +class MpcBatchNormOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() { + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("momentum", "").SetDefault(0.9); + AddAttr("epsilon", "") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_GE( + epsilon, 0.0f, + platform::errors::InvalidArgument( + "'epsilon' should be greater or equal than 0.0.")); + PADDLE_ENFORCE_LE(epsilon, 0.001f, + platform::errors::InvalidArgument( + "'epsilon' should be less or equal than 0.001.")); + }); + AddAttr("data_layout", "").SetDefault("NCHW"); + AddInput("X", "The input tensor"); + AddInput("Scale", + "Scale is a 1-dimensional tensor of size C " + "that is applied to the output"); + AddInput("Bias", + "Bias is a 1-dimensional tensor of size C " + "that is applied to the output"); + AddInput("Mean", + "The global mean (for training) or " + "estimated mean (for testing)"); + AddInput("Variance", + "The global variance (for training) " + "or estimated Variance (for testing)"); + AddInput("MomentumTensor", + "(Tensor, optional) If provided, batch_norm will " + "use this as momentum, this has a higher priority than " + "attr(momentum), the shape of this tensor MUST BE [1].") + .AsDispensable(); + AddOutput("Y", "result after normalization"); + AddOutput("MeanOut", + "Share memory with Mean. 
" + "Store the global mean when training"); + AddOutput("VarianceOut", + "Share memory with Variance. " + "Store the global Variance when training"); + AddOutput("SavedMean", + "Mean of the current mini batch, " + "will apply to output when training") + .AsIntermediate(); + AddOutput("SavedVariance", + "Variance of the current mini batch, " + "will apply to output when training") + .AsIntermediate(); + AddOutput("ReserveSpace", + "Reserve GPU space for triggering the new semi-persistent " + "NHWC kernel") + .AsDispensable(); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("fuse_with_relu", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("use_global_stats", + "(bool, default false) Whether to use global mean and " + "variance. In inference or test mode, set use_global_stats " + "to true or is_test true. the behavior is equivalent. " + "In train mode, when setting use_global_stats True, the " + "global mean and variance are also used during train time, " + "the BN acts as scaling and shiffting.") + .SetDefault(false); + AddAttr("trainable_statistics", + "(bool, default false) Whether to calculate mean and variance " + "in test mode. If setting true in test mode, mean and variace " + "will be calculated by current batch statistics.") + .SetDefault(false); + AddComment(R"DOC( +Batch Normalization. +Batch Norm has been implemented as discussed in the paper: +https://arxiv.org/pdf/1502.03167.pdf +Can be used as a normalizer function for conv2d and fully_connected operations. +The required data format for this layer is one of the following: +1. NHWC `[batch, in_height, in_width, in_channels]` +2. NCHW `[batch, in_channels, in_height, in_width]` +)DOC"); + } +}; + +template +class MpcBatchNormGradOpMaker : public framework::SingleGradOpMaker { +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + +protected: + void Apply(GradOpPtr op) const { + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + + op->SetInput("Scale", this->Input("Scale")); + op->SetInput("Bias", this->Input("Bias")); + op->SetInput("SavedMean", this->Output("SavedMean")); + op->SetInput("SavedVariance", this->Output("SavedVariance")); + if (this->HasOutput("ReserveSpace")) { + op->SetInput("ReserveSpace", this->Output("ReserveSpace")); + } + + // used when setting use_global_stats True during training + if (boost::get(this->GetAttr("use_global_stats"))) { + op->SetInput("Mean", this->Output("MeanOut")); + op->SetInput("Variance", this->Output("VarianceOut")); + } + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale")); + op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); + } +}; + + +class MpcBatchNormOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { +protected: + std::unordered_map& GetInputOutputWithSameType() const override { + static std::unordered_map m{{"X", /*->*/ "Y"}}; + return m; + } +}; + + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR( + mpc_batch_norm, ops::MpcBatchNormOp, ops::MpcBatchNormOpMaker, + ops::MpcBatchNormOpInferVarType, + ops::MpcBatchNormGradOpMaker, + ops::MpcBatchNormGradOpMaker); +REGISTER_OPERATOR(mpc_batch_norm_grad, ops::MpcBatchNormGradOp); + +REGISTER_OP_CPU_KERNEL( 
+ mpc_batch_norm, ops::MpcBatchNormKernel); +REGISTER_OP_CPU_KERNEL( + mpc_batch_norm_grad, ops::MpcBatchNormGradKernel); diff --git a/core/paddlefl_mpc/operators/mpc_batch_norm_op.h b/core/paddlefl_mpc/operators/mpc_batch_norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ba93da597106057c5d780688891c350aabe6d67b --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_batch_norm_op.h @@ -0,0 +1,516 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "mpc_op.h" +#include "./math/math_function.h" +#include "core/paddlefl_mpc/mpc_protocol/mpc_operators.h" + +namespace paddle { +namespace operators { + +using DDim = framework::DDim; +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DataLayout = framework::DataLayout; +std::shared_ptr mpc_operators; +// TODO: remove dependency on aby3 protocol +const int MPC_ONE_SHARE = (1 << paddle::mpc::FIXED_POINTER_SCALING_FACTOR) / 3; + +template +void Expand(const Tensor* input, Tensor* output, int S, int N, int C, int sample_size) { + // Expand tensor into specified shape + // input shape: {S, C} + // outout shape: {S, N, C, H, W}, sample_size = H * W + const T* input_data = input->data(); + T* output_data = output->data(); + int input_share_offset = C; + int output_share_offset = N * C * sample_size; + for (int nc = 0; nc < N * C; ++nc) { + int nc_offset = nc * sample_size; + std::fill(output_data + nc_offset, output_data + nc_offset + sample_size, *(input_data + nc % C)); + std::fill(output_data + nc_offset + output_share_offset, + output_data + nc_offset + output_share_offset + sample_size, + *(input_data + nc % C + input_share_offset)); + } +} + +template +void TransToChannelFirst(const Tensor* input, Tensor* output, const framework::ExecutionContext &ctx) { + // Transpose tensor + // input shape: {S, N, C, H, W} + // output shape: {C, S, N, H, W} + // H and W is optional + auto& dev_ctx = ctx.template device_context(); + auto input_dims = input->dims(); + switch (input_dims.size()) { + case 3: { + std::vector axis{2, 0, 1}; + output->mutable_data({input_dims[2], input_dims[0], input_dims[1]}, ctx.GetPlace()); + math::Transpose trans3; + trans3(dev_ctx, *input, output, axis); + break; + } + case 4: { + std::vector axis{2, 0, 1, 3}; + output->mutable_data({input_dims[2], input_dims[0], input_dims[1], input_dims[3]}, ctx.GetPlace()); + math::Transpose trans4; + trans4(dev_ctx, *input, output, axis); + break; + } + case 5: { + std::vector axis{2, 0, 1, 3, 4}; + output->mutable_data({input_dims[2], input_dims[0], input_dims[1], input_dims[3], input_dims[4]}, + ctx.GetPlace()); + math::Transpose trans5; + trans5(dev_ctx, *input, output, axis); + break; + } + default: + PADDLE_THROW("The size of input X's dimensions should be larger than 2, less than 6."); + } +} + +template +void ComputeSum(const Tensor* input, int C, Tensor* sum, const framework::ExecutionContext &ctx) { + // Compute sum of each 
channel + // input shape: {S, N, C, H, W} + // output shape: {S, C} + // H and W is optional, compute the sum of each channel. + Tensor input_trans; + TransToChannelFirst(input, &input_trans, ctx); + Tensor input_slice; + Tensor sum_slice; + auto sum_slice_data = sum_slice.mutable_data(framework::make_ddim({2, 1}), ctx.GetPlace()); + auto sum_data = sum->data(); + for (size_t i = 0; i < C; ++i) { + input_slice = input_trans.Slice(i, i + 1); + auto shape = paddle::framework::vectorize(input_slice.dims()); + shape.erase(shape.begin()); + std::vector shape_(shape.cbegin(), shape.cend()); + DDim dim(shape_.data(), shape_.size()); + input_slice.Resize(dim); + mpc_operators->sum(&input_slice, &sum_slice); + sum_data[i] = sum_slice_data[0]; + sum_data[i + C] = sum_slice_data[1]; + } +} + + +template +void ComputeMeanVariance(const Tensor* input, int S, int N, int C, int sample_size, + Tensor* saved_mean_e, Tensor* saved_variance_e, + const framework::ExecutionContext &ctx) { + // Compute mean and variance of each channel + // input shape: {S, N, C, H, W} + // output shape: {S, C} + // H and W is optional + VLOG(3) << "Compute the mean and variance of each channel"; + Tensor input_trans; + TransToChannelFirst(input, &input_trans, ctx); + + ComputeSum(input, C, saved_mean_e, ctx); + mpc_operators->scale(saved_mean_e, 1.0 / (N * sample_size), saved_mean_e); // scale + + Tensor saved_mean_e_expand; + T* saved_mean_e_expand_data = saved_mean_e_expand.mutable_data(input->dims(), ctx.GetPlace()); + Expand(saved_mean_e, &saved_mean_e_expand, S, N, C, sample_size); + mpc_operators->sub(input, &saved_mean_e_expand, &saved_mean_e_expand); + mpc_operators->mul(&saved_mean_e_expand, &saved_mean_e_expand, &saved_mean_e_expand); + ComputeSum(&saved_mean_e_expand, C, saved_variance_e, ctx); + mpc_operators->scale(saved_variance_e, 1.0 / (N * sample_size), saved_variance_e); // scale + +} + +template +class MpcBatchNormKernel : public MpcOpKernel { +public: + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + + mpc_operators = mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators(); + VLOG(3) << "Start MpcBatchNormKernel."; + const float epsilon = ctx.Attr("epsilon"); + float momentum = ctx.Attr("momentum"); + const bool is_test = ctx.Attr("is_test"); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool trainable_stats = ctx.Attr("trainable_statistics"); + bool test_mode = is_test && (!trainable_stats); + + bool global_stats = test_mode || use_global_stats; + + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + + const Tensor *x = ctx.Input("X"); + const DDim x_dims = x->dims(); + PADDLE_ENFORCE_GE( + x_dims.size(), 3, + platform::errors::InvalidArgument( + "The size of input X's dimensions should be larger than 2." + "But received: the size of input X's dimensions is [%d]", + x_dims.size())); + PADDLE_ENFORCE_LE( + x_dims.size(), 6, + platform::errors::InvalidArgument( + "The size of input X's dimensions should be less than 6." + "But received: the size of input X's dimensionss is [%d]", + x_dims.size())); + + const int S = 2; // share number + const int N = x_dims[1]; + const int C = (data_layout == DataLayout::kNCHW ? 
x_dims[2] : x_dims[x_dims.size() - 1]); + const int sample_size = x->numel() / S / N / C; + + auto *y = ctx.Output("Y"); + + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_variance = ctx.Output("SavedVariance"); + + // alloc memory + y->mutable_data(ctx.GetPlace()); + mean_out->mutable_data(ctx.GetPlace()); + variance_out->mutable_data(ctx.GetPlace()); + saved_mean->mutable_data(ctx.GetPlace()); + saved_variance->mutable_data(ctx.GetPlace()); + + if (!global_stats) { + if ((N * sample_size) == 1) { + // Only 1 element in normalization dimension, + // we skip the batch norm calculation, let y = x. + framework::TensorCopy(*x, ctx.GetPlace(), y); + return; + } + + // saved_xx is use just in this batch of data + // compute mean and variance + switch (data_layout) { + case DataLayout::kNCHW: { + ComputeMeanVariance(x, S, N, C, sample_size, saved_mean, saved_variance, ctx); + break; + } + default: + PADDLE_THROW("Unknown storage order: %s", data_layout_str); + } + + // updata global mean and variance, for prediction + if (ctx.HasInput("MomentumTensor")) { + const auto *mom_tensor = ctx.Input("MomentumTensor"); + momentum = mom_tensor->data()[0]; + } + + Tensor saved_mean_scale; + Tensor mean_out_scale; + saved_mean_scale.mutable_data(saved_mean->dims(), ctx.GetPlace()); + mean_out_scale.mutable_data(mean_out->dims(), ctx.GetPlace()); + + mpc_operators->scale(mean_out, momentum, &mean_out_scale); + mpc_operators->scale(saved_mean, 1.0 - momentum, &saved_mean_scale); + mpc_operators->add(&mean_out_scale, &saved_mean_scale, mean_out); + + mpc_operators->scale(variance_out, momentum, &mean_out_scale); + mpc_operators->scale(saved_variance, 1.0 - momentum, &saved_mean_scale); + + mpc_operators->add(&mean_out_scale, &saved_mean_scale, variance_out); + } + + + // use SavedMean and SavedVariance to do normalize + // compute output y + Tensor inv_std; + Tensor mean_arr; + inv_std.mutable_data({S, C}, ctx.GetPlace()); + + Tensor epsilon_expand; + T* epsilon_expand_data = epsilon_expand.mutable_data({S, C}, ctx.GetPlace()); + std::fill(epsilon_expand_data, epsilon_expand_data + S * C, MPC_ONE_SHARE * epsilon); // todo + + // inv_std = 1 / sqrt(variance + epsilon) + if (global_stats) { + const Tensor* variance = ctx.Input("Variance"); + Tensor var_plus_epsilon; + var_plus_epsilon.mutable_data({S, C}, ctx.GetPlace()); + + mpc_operators->add(variance, &epsilon_expand, &var_plus_epsilon); + mpc_operators->inverse_square_root(&var_plus_epsilon, &inv_std); + + mean_arr.ShareDataWith(*ctx.Input("Mean")); + } else { + Tensor var_plus_epsilon; + var_plus_epsilon.mutable_data({S, C}, ctx.GetPlace()); + mpc_operators->add(saved_variance, &epsilon_expand, &var_plus_epsilon); + mpc_operators->inverse_square_root(&var_plus_epsilon, &inv_std); + + mean_arr.ShareDataWith(*saved_mean); + } + + // ((x - est_mean) * (inv_var) * scale + bias + // formula transform ====> + // (x * inv_var * scale) + (bias - est_mean * inv_var * scale) + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + const T* scale_data = scale->data(); + const T* bias_data = bias->data(); + + Tensor scale_expand; + auto* scale_expand_data = scale_expand.mutable_data({S, C}, ctx.GetPlace()); + std::fill(scale_expand_data, scale_expand_data + C, scale_data[0]); + std::fill(scale_expand_data + C, scale_expand_data + C + C, scale_data[1]); + + Tensor bias_expand; + auto* bias_expand_data = bias_expand.mutable_data({S, C}, 
ctx.GetPlace()); + std::fill(bias_expand_data, bias_expand_data + C, bias_data[0]); + std::fill(bias_expand_data + C, bias_expand_data + C + C, bias_data[1]); + + Tensor new_scale; + Tensor new_bias; + Tensor new_bias_tmp; + new_scale.mutable_data(scale_expand.dims(), ctx.GetPlace()); + new_bias.mutable_data(scale_expand.dims(), ctx.GetPlace()); + new_bias_tmp.mutable_data(scale_expand.dims(), ctx.GetPlace()); + + mpc_operators->mul(&inv_std, &scale_expand, &new_scale); + mpc_operators->mul(&mean_arr, &new_scale, &new_bias_tmp); + mpc_operators->sub(&bias_expand, &new_bias_tmp, &new_bias); + + switch (data_layout) { + case DataLayout::kNCHW: { + Tensor x_new_scale; + x_new_scale.mutable_data(y->dims(), ctx.GetPlace()); + + Tensor new_scale_expand; + new_scale_expand.mutable_data(x->dims(), ctx.GetPlace()); + Expand(&new_scale, &new_scale_expand, S, N, C, sample_size); + + Tensor new_bias_expand; + new_bias_expand.mutable_data(x->dims(), ctx.GetPlace()); + Expand(&new_bias, &new_bias_expand, S, N, C, sample_size); + + mpc_operators->mul(x, &new_scale_expand, &x_new_scale); + mpc_operators->add(&x_new_scale, &new_bias_expand, y); + break; + } + default: + PADDLE_THROW("Unknown storage order: %d", data_layout); + } + } +}; + + +template +class MpcBatchNormGradKernel : public MpcOpKernel { +public: + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + + mpc_operators = mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators(); + VLOG(3) << "Start MpcBatchNormGradKernel."; + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + const auto *saved_mean = ctx.Input("SavedMean"); + + // SavedVariance have been reverted in forward operator + const auto *saved_inv_variance = ctx.Input("SavedVariance"); + const std::string data_layout_str = ctx.Attr("data_layout"); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool is_test = ctx.Attr("is_test"); + const float epsilon = ctx.Attr("epsilon"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + // batch_norm with inplace as false will take X as grad input, which + // is same as cuDNN batch_norm backward calculation, batch_norm + // with inplace as true only take Y as input and X should be calculate + // by inverse operation of batch_norm on Y + const Tensor *x; + x = ctx.Input("X"); + + PADDLE_ENFORCE_EQ( + is_test, false, + platform::errors::InvalidArgument( + "`is_test = True` CANNOT be used in train program. If " + "you want to use global status in pre_train model, " + "please set `use_global_stats = True`")); + + // Get the size for each dimension. + // NCHW [batch_size, in_channels, in_height, in_width] + const auto &x_dims = x->dims(); + PADDLE_ENFORCE_GE( + x_dims.size(), 3, + platform::errors::InvalidArgument( + "The size of input X's dimensions should be larger than 2." + "But received: the size of input X's dimensions is [%d]", + x_dims.size())); + PADDLE_ENFORCE_LE( + x_dims.size(), 6, + platform::errors::InvalidArgument( + "The size of input X's dimensions should be less than 6." + "But received: the size of input X's dimensionss is [%d]", + x_dims.size())); + const int S = 2; // share number + const int N = x_dims[1]; + const int C = (data_layout == DataLayout::kNCHW ? 
x_dims[2] : x_dims[x_dims.size() - 1]); + const int sample_size = x->numel() / S / N / C; + + d_x->mutable_data(ctx.GetPlace()); + + const T *mean_data = saved_mean->data(); + Tensor inv_var_tensor; + inv_var_tensor.ShareDataWith(*saved_inv_variance); // local variance + + // update mean_data, compute inv_var = 1 / sqrt(variance + epsilon) + if (use_global_stats) { + const auto *running_mean = ctx.Input("Mean"); + const auto *running_variance = ctx.Input("Variance"); + mean_data = running_mean->data(); + + Tensor inv_var_tmp; + inv_var_tmp.Resize({S, C}); + + Tensor var_plus_epsilon; + var_plus_epsilon.mutable_data(running_variance->dims(), ctx.GetPlace()); + + Tensor epsilon_expand; + T* epsilon_expand_data = epsilon_expand.mutable_data({S, C}, ctx.GetPlace()); + std::fill(epsilon_expand_data, epsilon_expand_data + S * C, MPC_ONE_SHARE * epsilon); + + mpc_operators->add(running_variance, &epsilon_expand, &var_plus_epsilon); + mpc_operators->inverse_square_root(&var_plus_epsilon, &inv_var_tmp); + framework::TensorCopy(inv_var_tmp, ctx.GetPlace(), &inv_var_tensor); + } + + + if (d_scale && d_bias) { + d_scale->mutable_data(ctx.GetPlace()); + d_bias->mutable_data(ctx.GetPlace()); + } + + // d_bias = np.sum(d_y, axis=0) + // d_scale = np.sum((X - mean) / inv_std * dy, axis=0) + if ((N * sample_size) == 1 && !use_global_stats) { + framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); + return; + } + + + switch (data_layout) { + case DataLayout::kNCHW: { + // d_bias = np.sum(d_y, axis=0) + Tensor dy_sum; + dy_sum.Resize({S, C}); + dy_sum.mutable_data(ctx.GetPlace()); + + ComputeSum(d_y, C, &dy_sum, ctx); // dy_sum + + // d_scale = np.sum((X - mean) / inv_std * dy, axis=0) + // = [np.sum(X * dy) - mean * dy_sum] * inv_std + Tensor x_mul_dy; + x_mul_dy.mutable_data(x->dims(), ctx.GetPlace()); + const DDim d_y_dim = d_y->dims(); + mpc_operators->mul(x, d_y, &x_mul_dy); // X * dy + + Tensor dy_mul_x_sub_mean_mul_invstd_sum; + dy_mul_x_sub_mean_mul_invstd_sum.mutable_data({S, C}, ctx.GetPlace()); + ComputeSum(&x_mul_dy, C, &dy_mul_x_sub_mean_mul_invstd_sum, ctx); // sum(X * dy) + + Tensor dy_sum_mul_mean; + dy_sum_mul_mean.mutable_data({S, C}, ctx.GetPlace()); + mpc_operators->mul(&dy_sum, saved_mean, &dy_sum_mul_mean); // mean * dy_sum + + Tensor tmp; + tmp.mutable_data({S, C}, ctx.GetPlace()); + // [np.sum(X * dy) - mean * dy_sum] + mpc_operators->sub(&dy_mul_x_sub_mean_mul_invstd_sum, &dy_sum_mul_mean, &tmp); + // [np.sum(X * dy) - mean * dy_sum] * inv_std + mpc_operators->mul(&tmp, saved_inv_variance, &dy_mul_x_sub_mean_mul_invstd_sum); + + + if (d_scale && d_bias) { + framework::TensorCopy(dy_sum, ctx.GetPlace(), d_bias); + framework::TensorCopy(dy_mul_x_sub_mean_mul_invstd_sum, ctx.GetPlace(), d_scale); + } + + // d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0) + // - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0)) + int scale_coefff = use_global_stats ? 1 : N * sample_size; + + Tensor scale_inv_var_nhw; + T* scale_inv_var_nhw_data = scale_inv_var_nhw.mutable_data({S, C}, ctx.GetPlace()); + // scale * inv_var + mpc_operators->mul(scale, saved_inv_variance, &scale_inv_var_nhw); + // (1. 
/ N) * scale * inv_var + mpc_operators->scale(&scale_inv_var_nhw, 1.0 / scale_coefff, &scale_inv_var_nhw); + Tensor scale_inv_var_nhw_expand; + scale_inv_var_nhw_expand.mutable_data(d_y_dim, ctx.GetPlace()); + Expand(&scale_inv_var_nhw, &scale_inv_var_nhw_expand, S, N, C, sample_size); + + if (!use_global_stats) { + Tensor dy_scale; + dy_scale.mutable_data(d_y_dim, ctx.GetPlace()); + // N * dy + mpc_operators->scale(d_y, N * sample_size, &dy_scale); + + Tensor dy_sum_expand; + dy_sum_expand.mutable_data(d_y_dim, ctx.GetPlace()); + Expand(&dy_sum, &dy_sum_expand, S, N, C, sample_size); + + Tensor dy_scale_minus_dy; + dy_scale_minus_dy.mutable_data(d_y_dim, ctx.GetPlace()); + // N * dy - np.sum(d_y, axis=0) + mpc_operators->sub(&dy_scale, &dy_sum_expand, &dy_scale_minus_dy); + + Tensor mean_expand; + mean_expand.mutable_data(d_y_dim, ctx.GetPlace()); + Expand(saved_mean, &mean_expand, S, N, C, sample_size); + + Tensor x_minus_mean; + x_minus_mean.mutable_data(d_y_dim, ctx.GetPlace()); + // (X - mean) + mpc_operators->sub(x, &mean_expand, &x_minus_mean); + // inv_var * inv_var * np.sum(d_y * (X - mean), axis=0)) + mpc_operators->mul(&dy_mul_x_sub_mean_mul_invstd_sum, saved_inv_variance, &tmp); + + Tensor tmp_expand; + tmp_expand.mutable_data(d_y_dim, ctx.GetPlace()); + Expand(&tmp, &tmp_expand, S, N, C, sample_size); + + Tensor tmp_expand2; + tmp_expand2.mutable_data(d_y_dim, ctx.GetPlace()); + // (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0) + mpc_operators->mul(&tmp_expand, &x_minus_mean, &tmp_expand2); + mpc_operators->sub(&dy_scale_minus_dy, &tmp_expand2, &dy_scale); + mpc_operators->mul(&scale_inv_var_nhw_expand, &dy_scale, d_x); + } else { + mpc_operators->mul(&scale_inv_var_nhw_expand, d_y, d_x); + } + break; + } + default: + PADDLE_THROW("Unknown storage order: %s", data_layout_str); + } // switch + } // void ComputeImpl +}; // class MpcBatchNormGradKernel + +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/mpc_elementwise_add_op.h b/core/paddlefl_mpc/operators/mpc_elementwise_add_op.h index d9345c757ec42f9dcc220566fe6c8a78dd64cc35..c495eac77c39e3cae925eed98ee64de6357ec501 100644 --- a/core/paddlefl_mpc/operators/mpc_elementwise_add_op.h +++ b/core/paddlefl_mpc/operators/mpc_elementwise_add_op.h @@ -69,6 +69,119 @@ private: int64_t n_; }; +template +class MidWiseTransformIterator; + +template +class MidWiseTransformIterator + : public std::iterator { + public: + MidWiseTransformIterator(const T *ptr, int n, int post) + : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} + + MidWiseTransformIterator &operator++() { + ++j_; + if (UNLIKELY(j_ == post_)) { + ++i_; + j_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + return *this; + } + + MidWiseTransformIterator &operator+(int n) { + while (n-- > 0) { + ++j_; + if (UNLIKELY(j_ == post_)) { + ++i_; + j_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + } + return *this; + } + + bool operator==(const MidWiseTransformIterator + &rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=(const MidWiseTransformIterator + &rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T &operator*() { return ptr_[i_]; } + + private: + const T *ptr_; + int64_t i_; + int64_t j_; + int64_t n_; + int64_t post_; +}; + +template +class TransformFunctor { + public: + TransformFunctor(const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z, const DeviceContext &ctx, Functor func, + const bool is_xsize_larger = true) + : x_(x->data()), + y_(y->data()), + 
z_(z->mutable_data(ctx.GetPlace())), + nx_(x->numel()), + ctx_(ctx), + func_(func), + is_xsize_larger_(is_xsize_larger) { + if (is_xsize_larger_ == false) { + nx_ = y->numel(); + } + } + + inline void Run() const { + platform::Transform trans; + trans(ctx_, x_, x_ + nx_, y_, z_, func_); + } + + inline void RunRowWise(int n, int pre) const { + platform::Transform trans; + if (is_xsize_larger_) { + trans(ctx_, x_, x_ + nx_, + RowwiseTransformIterator(y_, n), z_, func_); + } else { + trans(ctx_, y_, y_ + nx_, + RowwiseTransformIterator(x_, n), z_, func_); + } + } + + inline void RunMidWise(int n, int pre, int post) const { + platform::Transform trans; + if (is_xsize_larger_) { + trans(ctx_, x_, x_ + nx_, + MidWiseTransformIterator(y_, n, post), z_, func_); + } else { + trans(ctx_, y_, y_ + nx_, + MidWiseTransformIterator(x_, n, post), z_, func_); + } + } + + private: + const T *x_; + const T *y_; + OutType *z_; + int64_t nx_; + const DeviceContext &ctx_; + Functor func_; + bool is_xsize_larger_; +}; + template struct AddFunctor { inline HOSTDEVICE T operator()(T x, T y) { return x + y; } @@ -114,38 +227,45 @@ public: if (in_x_t->dims() == in_y_t->dims()) { mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(in_x_t, in_y_t, out_t); } else { - Tensor in_x_t_slice; - Tensor in_y_t_slice; - Tensor out_t_slice; + Tensor in_x_t_slice; + Tensor in_y_t_slice; + Tensor out_t_slice; - for (size_t i = 0; i < SHARE_NUM; ++i) { - in_x_t_slice = in_x_t->Slice(i, i + 1); - in_y_t_slice = in_y_t->Slice(i, i + 1); - out_t_slice = out_t->Slice(i, i + 1); + for (size_t i = 0; i < SHARE_NUM; ++i) { + in_x_t_slice = in_x_t->Slice(i, i + 1); + in_y_t_slice = in_y_t->Slice(i, i + 1); + out_t_slice = out_t->Slice(i, i + 1); - auto x_dims = in_x_t_slice.dims(); - auto y_dims = in_y_t_slice.dims(); - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); + auto x_dims = in_x_t_slice.dims(); + auto y_dims = in_y_t_slice.dims(); - int pre, n, post; - GetMidDims get_mid_dims; - get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); - PADDLE_ENFORCE_EQ(post, 1, - "post should be equal 1, but received post is [%s]", post); - - auto x_ = in_x_t_slice.data(); - auto y_ = in_y_t_slice.data(); - auto out_ = out_t_slice.data(); - auto nx_ = in_x_t_slice.numel(); - paddle::platform::Transform trans; + axis = (axis == -1 ? 
x_dims.size() - y_dims.size() : axis); + + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + GetMidDims get_mid_dims; + get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); + + auto x_ = in_x_t_slice.data(); + auto y_ = in_y_t_slice.data(); + auto out_ = out_t_slice.data(); + auto nx_ = in_x_t_slice.numel(); + + paddle::platform::Transform trans; + if (post == 1) { trans(ctx.template device_context(), x_, x_ + nx_, - RowwiseTransformIterator(y_, n), - out_, AddFunctor()); + RowwiseTransformIterator(y_, n), + out_, AddFunctor()); + } else { + trans(ctx.template device_context(), x_, x_ + nx_, + MidWiseTransformIterator(y_, n, post), + out_, AddFunctor()); } } + } } }; @@ -185,17 +305,15 @@ public: int pre, n, post; GetMidDims get_mid_dims; get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); - PADDLE_ENFORCE_EQ(post, 1, - "post should be equal 1, but received post is [%s]", post); + + std::fill(dy_data, dy_data + dy->numel(), static_cast(0)); for (size_t i = 0; i < SHARE_NUM; ++i) { int y_offset = i * n; for (size_t j = 0; j < pre; ++j) { for (size_t k = 0; k < n; ++k) { - int out_offset = i * pre * n + j * n + k; - if (0 == j) { - dy_data[k + y_offset] = dout_data[out_offset]; - } else { + for (size_t m = 0; m < post; ++m) { + int out_offset = i * pre * n * post + j * n * post + k * post + m; dy_data[k + y_offset] += dout_data[out_offset]; } } diff --git a/core/paddlefl_mpc/operators/mpc_gru_op.cc b/core/paddlefl_mpc/operators/mpc_gru_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..83c460a95af3763f1bf574fb7fce279b05034b09 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_gru_op.cc @@ -0,0 +1,472 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "mpc_gru_op.h" + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/platform/device_context.h" +#include "mpc_op.h" + +#include +#include +#include "core/paddlefl_mpc/operators/math/math_function.h" + +namespace paddle +{ +namespace operators +{ + +using framework::DDim; +using framework::Tensor; +using framework::LoD; + +class MpcGRUOp : public framework::OperatorWithKernel +{ +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override + { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(%s) of MpcGRUOp should not be null.", "Input"); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(%s) of MpcGRUOp should not be null.", "Weight"); + PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), + "Output(%s) of MpcGRUOp should not be null.", "BatchGate"); + PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"), + "Output(%s) of MpcGRUOp should not be null.", + "BatchResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"), + "Output(%s) of MpcGRUOp should not be null.", "BatchHidden"); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(%s) of MpcGRUOp should not be null.", "Hidden"); + auto input_dims_trans = ctx->GetInputDim("Input"); + auto input_dims = framework::make_ddim({input_dims_trans[1], + input_dims_trans[0], input_dims_trans[2]}); + auto weight_dims = ctx->GetInputDim("Weight"); + int input_size = input_dims[2]; + int frame_size = weight_dims[1]; + if (ctx->IsRuntime()) + { + PADDLE_ENFORCE_EQ( + input_size, frame_size * 3, + "The input_size must be 3 times of frame_size in MpcGRUOp."); + } + PADDLE_ENFORCE_EQ( + weight_dims[2], frame_size * 3, + "The shape of mpc Weight matrix must be [frame_size, frame_size * 3]."); + if (ctx->HasInput("H0")) + { + auto h0_dims = ctx->GetInputDim("H0"); + PADDLE_ENFORCE_EQ(h0_dims[2], frame_size, + "The width of H0 must be equal to frame_size."); + } + if (ctx->HasInput("Bias")) + { + auto bias_dims = ctx->GetInputDim("Bias"); + int bias_height = bias_dims[1]; + int bias_width = bias_dims[2]; + PADDLE_ENFORCE_EQ(bias_height, 1, + "The shape of Bias must be [1, frame_size * 3]."); + PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, + "The shape of Bias must be [1, frame_size * 3]."); + } + ctx->SetOutputDim("BatchGate", input_dims); + ctx->SetOutputDim("BatchResetHiddenPrev", {2, input_dims[1], frame_size}); + ctx->SetOutputDim("BatchHidden", {2, input_dims[1], frame_size}); + ctx->SetOutputDim("Hidden", {2, input_dims[1], frame_size}); + ctx->ShareLoD("Input", "Hidden"); + } +}; + +class MpcGRUOpMaker : public framework::OpProtoAndCheckerMaker +{ +public: + void Make() override + { + AddInput("Input", + "(LoDTensor) The first input is a LodTensor, which supports " + "variable-time length input sequence. The underlying tensor in " + "this LoDTenosr is a matrix with shape (T x 2 x 3D), where, T is the " + "total time steps in this mini-batch, D is the hidden size." + "Note: before call this OP, " + "Yout must transpose input shape of mini-batch dim to first dim," + "that is, (2, T, 3D) is transpose to (T, 2, 3D), " + "so that its lod information of shares can be set correctly"); + AddInput("H0", + "(Tensor, optional) The initial hidden state is an optional " + "input. 
This is a tensor with shape (2 x N x D), where N is the " + "batch size, D is the hidden size.") + .AsDispensable(); + AddInput( + "Weight", + "(Tensor) The learnable hidden-hidden weight matrix with shape " + "(2 x D x 3D), where D is the hidden size. The elements continuous in " + "memory can be divided into two parts. The first part are weights of " + "the update gate and reset gate with shape (2 x D x 2D), and the second " + "part are weights of output candidate with shape (2 x D x D)."); + AddInput("Bias", + "(Tensor, optional) Bias vector with shape (2 x 1 x 3D) concating " + "bias of the update gate, reset gate and output candidate.") + .AsDispensable(); + AddOutput("BatchGate", + "(LoDTensor) To compute with batches, sequence data will be " + "reorganized into several successive batches each containing " + "data from the same time step. The LoDTensor BatchGate contains " + "the update gate, reset gate and output candidate values " + "organized in batches. The LoD size is 2. The first LoD contains " + "the batch offsets and the second LoD contains the indexes in " + "the raw sequence data.") + .AsIntermediate(); + AddOutput( + "BatchResetHiddenPrev", + "(LoDTensor) The reset hidden state LoDTensor organized in batches. " + "This LoDTensor is a matrix with shape (2 x T x D) and has the same LoD " + "with `BatchGate`.") + .AsIntermediate(); + AddOutput( + "BatchHidden", + "(LoDTensor) The hidden state LoDTensor organized in batches. " + "This LoDTensor is a matrix with shape (2 x T x D) and has the same LoD " + "with `BatchGate`.") + .AsIntermediate(); + AddOutput( + "Hidden", + "(LoDTensor) the hidden state LoDTensor organized in sequences. " + "This LoDTensor is a matrix with shape (2 x T x D) and has the same LoD " + "with `BatchGate`."); + AddAttr("activation", + "(string, default tanh) " + "The activation type used for output candidate {h}_t.") + .SetDefault("relu"); + AddAttr( + "gate_activation", + "(string, default sigmoid) " + "The activation type used in update gate and reset gate.") + .SetDefault("sigmoid"); + AddAttr("is_reverse", + "(bool, default: False) " + "whether to compute reversed GRU.") + .SetDefault(false); + AddAttr("origin_mode", + "bool" + "use origin mode in article https://arxiv.org/abs/1412.3555") + .SetDefault(false); + AddComment(R"DOC( +GRU Operator implements part calculations of the complete GRU as following: + +$$ +update\_gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\ +reset\_gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\ +output\_candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\ +output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t) +$$ + +@note To implement the complete GRU, fully-connected operator must be used +before to feed xu, xr and xc as the Input of GRU operator. 
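+
+Note (added for clarity, based on the attribute definitions above and the kernel in
+mpc_gru_op.h): in this MPC implementation the output-candidate activation defaults to
+relu (see the `activation` attribute), and `gate_activation` may be set to sigmoid,
+sigmoid_chebyshev or sigmoid_enhanced.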
+)DOC"); + } +}; + +class MpcGRUGradOp : public framework::OperatorWithKernel +{ +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override + { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(%s) of MpcGRUGradOp should not be null.", "Input"); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(%s) of MpcGRUGradOp should not be null.", "Weight"); + PADDLE_ENFORCE(ctx->HasInput("BatchGate"), + "Input(%s) of MpcGRUGradOp should not be null.", "BatchGate"); + PADDLE_ENFORCE(ctx->HasInput("BatchResetHiddenPrev"), + "Input(%s) of MpcGRUGradOp should not be null.", + "BatchResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput("BatchHidden"), + "Input(%s) of MpcGRUOp should not be null.", "BatchHidden"); + PADDLE_ENFORCE(ctx->HasInput("Hidden"), + "Input(%s) of MpcGRUGradOp should not be null.", "Hidden"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), + "Input(%s@GRAD) of MpcGRUGradOp should not be null.", "Hidden"); + auto input_dims_trans = ctx->GetInputDim("Input"); + auto input_dims = framework::make_ddim({input_dims_trans[1], + input_dims_trans[0], input_dims_trans[2]}); + auto weight_dims = ctx->GetInputDim("Weight"); + int input_size = input_dims[2]; + int frame_size = weight_dims[1]; + int weight_height = weight_dims[1]; + int weight_width = weight_dims[2]; + PADDLE_ENFORCE_EQ(input_size, frame_size * 3, + "The input_size must be 3 times of frame_size in MpcGRUOp."); + PADDLE_ENFORCE_EQ( + weight_height, frame_size, + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + PADDLE_ENFORCE_EQ( + weight_width, frame_size * 3, + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + if (ctx->HasInput("H0")) + { + auto h0_dims = ctx->GetInputDim("H0"); + PADDLE_ENFORCE_EQ(h0_dims[2], frame_size, + "The width of H0 must be equal to frame_size."); + auto h0_grad_name = framework::GradVarName("H0"); + if (ctx->HasOutput(h0_grad_name)) + ctx->SetOutputDim(h0_grad_name, h0_dims); + } + if (ctx->HasInput("Bias")) + { + auto bias_dims = ctx->GetInputDim("Bias"); + int bias_height = bias_dims[1]; + int bias_width = bias_dims[2]; + PADDLE_ENFORCE_EQ(bias_height, 1, + "The shape of Bias must be [1, frame_size * 3]."); + PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, + "The shape of Bias must be [1, frame_size * 3]."); + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + auto input_grad_name = framework::GradVarName("Input"); + if (ctx->HasOutput(input_grad_name)) + //transpose input's shape + ctx->SetOutputDim(input_grad_name, input_dims); + auto weight_grad_name = framework::GradVarName("Weight"); + if (ctx->HasOutput(weight_grad_name)) + ctx->SetOutputDim(weight_grad_name, weight_dims); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override + { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Hidden")), + ctx.device_context()); + } +}; + +template +class MpcGRUCPUKernel : public MpcOpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + using DeviceContext = paddle::platform::CPUDeviceContext; + bool origin_mode = context.Attr("origin_mode"); + auto* input_trans = context.Input("Input"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = 
context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + auto hidden_dims = hidden->dims(); + const auto place = context.GetPlace(); + + bool is_reverse = context.Attr("is_reverse"); + + math::LoDTensor2BatchFunctor to_batch; + auto& dev_ctx = context.template device_context(); + // get input lod + auto input_lod = input_trans->lod(); + LoD gate_lod; + // transpose input to corrected mpc_input + // (T, 2, 3D) to (2, T, 3D) + math::Transpose transpose; + Tensor input; + auto input_dim = input_trans->dims(); + auto in_dim = framework::make_ddim({input_dim[1], input_dim[0], input_dim[2]}); + input.mutable_data( + in_dim, + context.GetPlace()); + transpose(dev_ctx, *input_trans, &input, {1, 0, 2}); + + for (int i = 0; i < 2; ++i) { + // mpc LoDTensor to Batch + Tensor input_s; + Tensor batch_gate_s; + SliceAndReshape(&input, input_s, i); + SliceAndReshape(batch_gate, batch_gate_s, i); + LoDTensor lod_input_s; + LoDTensor lod_batch_gate_s; + lod_input_s.ShareBufferWith(input_s); + lod_input_s.mutable_data(input_s.dims(), place); + lod_batch_gate_s.ShareBufferWith(batch_gate_s); + lod_batch_gate_s.mutable_data(batch_gate_s.dims(), place); + lod_input_s.set_lod(input_lod); + to_batch(dev_ctx, lod_input_s, &lod_batch_gate_s, true, is_reverse); + gate_lod = lod_batch_gate_s.lod(); + } + + if (bias) { + // add mpc bias + math::RowwiseAdd add_bias; + for (int i = 0; i < 2; ++i) { + Tensor batch_gate_s; + Tensor bias_s; + SliceAndReshape(batch_gate, batch_gate_s, i); + SliceAndReshape(bias, bias_s, i); + add_bias(dev_ctx, batch_gate_s, bias_s, &batch_gate_s); + } + } + // split mpc weight from shape (2, D, 3D) to 3 * (2, D, D) + std::vector mpc_splitted_weights_t; + //Split3Dim(context, &mpc_splitted_weights_t, *weight); + SplitWeight(context, mpc_splitted_weights_t, *weight); + + Tensor ordered_h0; + framework::Vector order((gate_lod)[2]); + Tensor mpc_hidden_prev_t; + bool has_hidden_prev = false; + + if (h0) { + // reordered h0 based on lod + ordered_h0.Resize(h0->dims()); + for (int i = 0; i < 2; ++i) { + Tensor h0_s; + Tensor ordered_h0_s; + SliceAndReshape(h0, h0_s, i); + SliceAndReshape(&ordered_h0, ordered_h0_s, i); + ReorderInitState( + context.template device_context(), h0_s, order, + &ordered_h0_s, true); + } + // copy ordered_h0 to mpc_hidden_prev_t + mpc_hidden_prev_t = ordered_h0; + has_hidden_prev = true; + } + auto batch_starts = (gate_lod)[0]; + size_t seq_len = batch_starts.size() - 1; + + std::vector mpc_gate_t_list; + std::vector mpc_reset_hidden_prev_t_list; + std::vector mpc_hidden_t_list; + // compute gru + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + std::vector mpc_splitted_gate_t; + Tensor mpc_batch_gate_t; + Tensor mpc_reset_hidden_prev_t; + Tensor mpc_hidden_t; + + ToMpcBatchTensor(context, mpc_batch_gate_t, *batch_gate, bstart, bend); + Split3Dim(context, mpc_splitted_gate_t, mpc_batch_gate_t); + ToMpcBatchTensor(context, mpc_reset_hidden_prev_t, *batch_reset_hidden_prev, bstart, bend); + ToMpcBatchTensor(context, mpc_hidden_t, *batch_hidden, bstart, bend); + 
+ ComputGRUUint(context, mpc_splitted_gate_t, mpc_splitted_weights_t, mpc_reset_hidden_prev_t, + mpc_hidden_t, mpc_hidden_prev_t, origin_mode, has_hidden_prev); + + Tensor mpc_gate_t; + Concat3Dim(context, &mpc_gate_t, mpc_splitted_gate_t); + //mpc_hidden_prev_t = mpc_hidden_t; + mpc_hidden_prev_t.mutable_data(mpc_hidden_t.dims(), place); + framework::TensorCopy(mpc_hidden_t, context.GetPlace(), &mpc_hidden_prev_t); + mpc_gate_t_list.emplace_back(mpc_gate_t); + mpc_reset_hidden_prev_t_list.emplace_back(mpc_reset_hidden_prev_t); + mpc_hidden_t_list.emplace_back(mpc_hidden_t); + } + // Concat output variables + ConcatBatchAll(context, batch_gate, mpc_gate_t_list); + ConcatBatchAll(context, batch_reset_hidden_prev, mpc_reset_hidden_prev_t_list); + ConcatBatchAll(context, batch_hidden, mpc_hidden_t_list); + // mpc batch tensor to mpc LoDTensor + for (int i = 0; i < 2; ++i) + { + Tensor batch_hidden_s; + SliceAndReshape(batch_hidden, batch_hidden_s, i); + Tensor hidden_s; + SliceAndReshape(hidden, hidden_s, i); + LoDTensor lod_batch_hidden_s; + LoDTensor lod_hidden_s; + + lod_batch_hidden_s.ShareBufferWith(batch_hidden_s); + lod_batch_hidden_s.mutable_data(batch_hidden_s.dims(), place); + lod_hidden_s.ShareBufferWith(hidden_s); + lod_hidden_s.mutable_data(hidden_s.dims(), place); + math::Batch2LoDTensorFunctor to_seq; + lod_batch_hidden_s.set_lod(gate_lod); + lod_hidden_s.set_lod(gate_lod); + to_seq(dev_ctx, lod_batch_hidden_s, &lod_hidden_s); + } + // set batch_gate_lod for grad op + batch_gate->set_lod(gate_lod); + } + + void ComputeImpl(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + +template +class MpcGRUGradOpMaker : public framework::SingleGradOpMaker +{ +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + +protected: + void Apply(GradOpPtr grad_op) const override + { + grad_op->SetType("mpc_gru_grad"); + grad_op->SetInput("Input", this->Input("Input")); + grad_op->SetInput("H0", this->Input("H0")); + grad_op->SetInput("Bias", this->Input("Bias")); + grad_op->SetInput("Weight", this->Input("Weight")); + + grad_op->SetInput("BatchGate", this->Output("BatchGate")); + grad_op->SetInput("BatchResetHiddenPrev", + this->Output("BatchResetHiddenPrev")); + grad_op->SetInput("BatchHidden", this->Output("BatchHidden")); + grad_op->SetInput("Hidden", this->Output("Hidden")); + + grad_op->SetInput(framework::GradVarName("Hidden"), + this->OutputGrad("Hidden")); + + grad_op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0")); + grad_op->SetOutput(framework::GradVarName("Input"), + this->InputGrad("Input")); + grad_op->SetOutput(framework::GradVarName("Weight"), + this->InputGrad("Weight")); + grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); + + grad_op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER(MpcGRUGradOpNoNeedBufferVarInference, "Input", + "Bias"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(mpc_gru, ops::MpcGRUOp, ops::MpcGRUOpMaker, + ops::MpcGRUGradOpMaker, + ops::MpcGRUGradOpMaker); +REGISTER_OPERATOR(mpc_gru_grad, ops::MpcGRUGradOp, + ops::MpcGRUGradOpNoNeedBufferVarInference); +REGISTER_OP_CPU_KERNEL(mpc_gru, ops::MpcGRUCPUKernel); +REGISTER_OP_CPU_KERNEL( + mpc_gru_grad, ops::MpcGRUGradKernel); diff --git a/core/paddlefl_mpc/operators/mpc_gru_op.h b/core/paddlefl_mpc/operators/mpc_gru_op.h new file mode 100644 index 
0000000000000000000000000000000000000000..d9f1d6f8f73252cad232fb363e38defc0ec039bf --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_gru_op.h @@ -0,0 +1,772 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "core/paddlefl_mpc/operators/math/sequence2batch.h" +#include "core/paddlefl_mpc/operators/math/concat_and_split.h" +#include "core/paddlefl_mpc/operators/math/math_function.h" +#include "mpc_op.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; +typedef std::function GateActivation; + +template +inline void ComputeSigmoidGrad(const framework::ExecutionContext& context, + Tensor& dy, Tensor& y, Tensor& dx); +template +inline void BackwardStateGrad(const framework::ExecutionContext& context, + std::vector& mpc_splitted_gate_t, + std::vector& mpc_splitted_gate_grad_t, + Tensor& mpc_hidden_prev_t, Tensor& mpc_hidden_prev_grad_t, + Tensor& mpc_hidden_grad_t, + bool origin_mode, bool has_hidden_prev, + bool has_hidden_prev_grad); + +template +inline void BackwarsResetGrad(const framework::ExecutionContext& context, + std::vector& mpc_splitted_gate_t, + std::vector& mpc_splitted_gate_grad_t, + Tensor& mpc_hidden_prev_t, Tensor& mpc_hidden_prev_grad_t, + Tensor& mpc_reset_hidden_prev_grad_t, + bool has_hidden_prev, bool has_hidden_prev_grad); + +template +inline void ReorderInitState(const DeviceContext& ctx, + const framework::Tensor& src, + framework::Vector index_lod, + framework::Tensor* dst, bool indexed_src) { + math::CopyMatrixRowsFunctor row_shuffle; + dst->mutable_data(src.dims(), ctx.GetPlace()); + row_shuffle(ctx, src, index_lod, dst, indexed_src); +} + +template +inline void ComputGRUUint(const framework::ExecutionContext& context, + std::vector& gate_t, + std::vector& weight_t, + Tensor &reset_hidden_prev_t, + Tensor &hidden_t, + Tensor &hidden_prev_t, + bool origin_mode, + bool& has_hidden_prev) { + // compute GRUUnit + Tensor u_h_t; + Tensor r_h_t; + // gate_t[x] shape (2, B, D) + // weight_t[x] shape (2, D, D) + // hidden_prev_t shape (2, B, D) + // hidden_t shape (2, B, D) + u_h_t.mutable_data(gate_t[0].dims(), context.GetPlace()); + r_h_t.mutable_data(gate_t[1].dims(), context.GetPlace()); + auto mpc_operator = mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators(); + if (has_hidden_prev) { + // compute update gate and reset gate: gate_t += hidden_prev_t matmul gate_weight + mpc_operator->matmul(&hidden_prev_t, &weight_t[0], &u_h_t); + mpc_operator->add(&u_h_t, &gate_t[0], &gate_t[0]); + + mpc_operator->matmul(&hidden_prev_t, &weight_t[1], &r_h_t); + mpc_operator->add(&r_h_t, &gate_t[1], &gate_t[1]); + } + + auto GateActProcess = [&gate_t](const GateActivation fun) { + fun(&gate_t[0], &gate_t[0]); + fun(&gate_t[1], &gate_t[1]); + }; + GateActivation 
activ_functor; + std::string active_gate = context.Attr("gate_activation"); + if (active_gate == "sigmoid_chebyshev") { + activ_functor = std::bind(&paddle::mpc::MpcOperators::sigmoid_chebyshev, + mpc_operator.get(), + std::placeholders::_1, + std::placeholders::_2); + } else if (active_gate == "sigmoid") { + activ_functor = std::bind(&paddle::mpc::MpcOperators::sigmoid, + mpc_operator.get(), + std::placeholders::_1, + std::placeholders::_2); + } else if (active_gate == "sigmoid_enhanced") { + activ_functor = std::bind(&paddle::mpc::MpcOperators::sigmoid_enhanced, + mpc_operator.get(), + std::placeholders::_1, + std::placeholders::_2); + } else { + PADDLE_THROW("gate activation of %s is not implemented yet.", active_gate); + } + GateActProcess(activ_functor); + + if (has_hidden_prev) { + // reset_hidden_prev_t = gate[1] * hidden_prev_t + // compute candidate gate: gate_t[2] += reset_hidden_prev_t matmul state_weight + Tensor r_h_tmp; + r_h_tmp.mutable_data(gate_t[2].dims(), context.GetPlace()); + mpc_operator->mul(&gate_t[1], &hidden_prev_t, &reset_hidden_prev_t); + mpc_operator->matmul(&reset_hidden_prev_t, &weight_t[2], &r_h_tmp); + mpc_operator->add(&r_h_tmp, &gate_t[2], &gate_t[2]); + } else { + //initialize reset_hidden_prev_t and hidden_prev_t as 0 + math::SetConstant zero; + auto& dev_ctx = context.template device_context(); + reset_hidden_prev_t.mutable_data(gate_t[0].dims(), context.GetPlace()); + hidden_prev_t.mutable_data(gate_t[0].dims(), context.GetPlace()); + zero(dev_ctx, &reset_hidden_prev_t, static_cast(0)); + zero(dev_ctx, &hidden_prev_t, static_cast(0)); + has_hidden_prev = true; + } + + mpc_operator->relu(&gate_t[2], &gate_t[2]); + + Tensor u_h_tmp; + Tensor ops_u_h_tmp; + u_h_tmp.mutable_data(hidden_t.dims(), context.GetPlace()); + ops_u_h_tmp.mutable_data(hidden_t.dims(), context.GetPlace()); + if (origin_mode) { + // compute output hidden_t = (gate[0] * hidden_prev_t + gate[2] - gate[0] * gate[2]) + mpc_operator->mul(&gate_t[0], &hidden_prev_t, &u_h_tmp); + mpc_operator->add(&gate_t[2], &u_h_tmp, &u_h_tmp); + mpc_operator->mul(&gate_t[0], &gate_t[2], &ops_u_h_tmp); + mpc_operator->sub(&u_h_tmp, &ops_u_h_tmp, &hidden_t); + } else { + // compute output hidden_t = (gate[0] * gate[2] + hidden_prev_t - gate[0] * hidden_prev_t) + mpc_operator->mul(&gate_t[0], &gate_t[2], &u_h_tmp); + mpc_operator->add(&hidden_prev_t, &u_h_tmp, &u_h_tmp); + mpc_operator->mul(&gate_t[0], &hidden_prev_t, &ops_u_h_tmp); + mpc_operator->sub(&u_h_tmp, &ops_u_h_tmp, &hidden_t); + } +} + +inline void SliceAndReshape(const Tensor* input, Tensor &output, int i) { + // Slice mpc tensor to share[i] + output = input->Slice(i, i + 1); + auto dims = output.dims(); + output.Resize(paddle::framework::slice_ddim(dims, 1, dims.size())); +} + +template +inline void ToMpcBatchTensor(const framework::ExecutionContext& context, + Tensor& output, const Tensor& input, + int start, int end) { + //input : (2 , T, x) -> output: (2, end - start, x) + auto dims = input.dims(); + auto& dev_ctx = context. 
template device_context(); + math::Transpose transpose; + Tensor tmp; + tmp.mutable_data(framework::make_ddim({dims[1], dims[0], dims[2]}), context.GetPlace()); + transpose(dev_ctx, input, &tmp, {1, 0, 2}); + Tensor tmp_slice = tmp.Slice(start, end); + output.mutable_data(framework::make_ddim({dims[0], end - start, dims[2]}), context.GetPlace()); + transpose(dev_ctx, tmp_slice, &output, {1, 0, 2}); +} + +template +inline void Split3Dim(const framework::ExecutionContext& context, + std::vector& output, + const Tensor& input) { + // input : (2, x, 3D) -> output : 3 * (2, x, D) + auto& dev_ctx = context. template device_context(); + Tensor tmp_trans; + auto dims = input.dims(); + int frame_size = dims[2] / 3; + tmp_trans.mutable_data(framework::make_ddim({dims[2], dims[0], dims[1]}), context.GetPlace()); + math::Transpose transpose; + transpose(dev_ctx, input, &tmp_trans, {2, 0, 1}); + for (int i = 0; i < 3; ++i) { + Tensor tmp_slice = tmp_trans.Slice(i * frame_size, (i + 1) * frame_size); + Tensor tmp_re_trans; + tmp_re_trans.mutable_data(framework::make_ddim({dims[0], dims[1], dims[2] / 3}), + context.GetPlace()); + transpose(dev_ctx, tmp_slice, &tmp_re_trans, {1, 2, 0}); + output.emplace_back(tmp_re_trans); + } +} + + +template +inline void Concat3Dim(const framework::ExecutionContext& context, + Tensor* output, + std::vector& input) { + // input 3 * (2, x, D) -> (2, x, 3D) + math::ConcatFunctor concat; + auto& input_dims = input[0].dims(); + std::vector output_dim{input_dims[0], input_dims[1], input_dims[2] * 3}; + output->mutable_data(framework::make_ddim(output_dim), context.GetPlace()); + auto& dev_ctx = context. template device_context(); + concat(dev_ctx, input, 3, output); +} + +template +inline void SplitWeight(const framework::ExecutionContext& context, + std::vector& splitted_weights, + const Tensor& weight) { + // split weight[0]、weight[1]、weight[2] with shape (2, D, D) from weight(2, D, 3D) + // note that weight[2]'s data start at offset 2 * D * D of weight's data + auto& dev_ctx = context. 
template device_context(); + auto dims = weight.dims(); + auto frame_size = dims[2] / 3; + splitted_weights.resize(3); + auto place = context.GetPlace(); + + // copy weight[0] weight[1] from weight + Tensor update_weight; + update_weight.mutable_data(framework::make_ddim({2, frame_size, 2 * frame_size}), + place); + //splitted_weights->at(2) = new Tensor(); + splitted_weights[2].mutable_data(framework::make_ddim({2, frame_size, frame_size}), + place); + for (int i = 0; i < 2; ++i) { + Tensor weight_s; + Tensor update_weight_s; + Tensor weight_3_s; + SliceAndReshape(&weight, weight_s, i); + SliceAndReshape(&update_weight, update_weight_s, i); + SliceAndReshape(&splitted_weights[2], weight_3_s, i); + T* update_s_data = update_weight_s.mutable_data(place); + T* weight_s_data = weight_s.data(); + memcpy(update_s_data, weight_s_data, update_weight_s.numel() * sizeof(T)); + // weight[3] + memcpy(weight_3_s.mutable_data(place), weight_s_data + 2 * frame_size * frame_size, + weight_3_s.numel() * sizeof(T)); + } + // split update_weight to weight[0] and weight[1] + math::Transpose transpose; + Tensor weight_trans; + weight_trans.mutable_data(framework::make_ddim({2 * frame_size, 2, frame_size}), place); + transpose(dev_ctx, update_weight, &weight_trans, {2, 0, 1}); + for (int i = 0; i < 2; ++i) { + //splitted_weights->at(i) = new Tensor(); + splitted_weights[i].mutable_data(framework::make_ddim({2, frame_size, frame_size}), place); + transpose(dev_ctx, weight_trans.Slice(frame_size * i, frame_size * (i + 1)), + &splitted_weights[i], {1, 2, 0}); + } +} + +template +inline void ConcatWeight(const framework::ExecutionContext& context, + Tensor* weight, + std::vector& splitted_weights) { + // concat weight[0]、weight[1]、weight[2] with shape (2, D, D) to weight(2, D, 3D) + // note that weight[2]'s data append after weight[0] and weight[1] + // weight[0] and weight[1] are concat as shape (2, D, 2D) in axis 2 + math::ConcatFunctor concat; + std::vector update_weight_list; + update_weight_list.resize(2); + auto place = context.GetPlace(); + auto& splitted_weights_dims = splitted_weights[0].dims(); + std::vector weight_dim{splitted_weights_dims[0], splitted_weights_dims[1], + splitted_weights_dims[2] * 3}; + weight->mutable_data(framework::make_ddim(weight_dim), context.GetPlace()); + for (int i = 0; i < 2; ++i) { + update_weight_list[i] = splitted_weights[i]; + } + auto& dev_ctx = context. 
template device_context(); + // Concat update weight and reset weight as update weights + Tensor update_weights; + update_weights.mutable_data( + framework::make_ddim({splitted_weights_dims[0], + splitted_weights_dims[1], + splitted_weights_dims[2] * 2}), + place); + concat(dev_ctx, update_weight_list, 3, &update_weights); + // Concat candidate weight + for (int i = 0; i < 2; ++i) { + Tensor weight_s = weight->Slice(i, i + 1); + Tensor update_weights_s = update_weights.Slice(i, i + 1); + Tensor reset_weight_s = splitted_weights[i].Slice(i, i + 1); + + T* weight_s_data = weight_s.mutable_data(place); + T* update_weights_s_data = update_weights_s.data(); + T* reset_weight_s_data = reset_weight_s.data(); + + size_t numel_update = update_weights_s.numel(); + memcpy(weight_s_data, update_weights_s_data, numel_update * sizeof(T)); + memcpy(weight_s_data + numel_update, reset_weight_s_data, reset_weight_s.numel()); + } +} + +template +inline void ConcatBatchOne(const framework::ExecutionContext& context, + Tensor* output, + Tensor& input, + int start, + int end) { + // replace output[2, start:end, x] with input (2, end - start, x) + + auto& dev_ctx = context. template device_context(); + Tensor tmp_trans; + auto dims = output->dims(); + tmp_trans.mutable_data(framework::make_ddim({dims[1], dims[0], dims[2]}), context.GetPlace()); + math::Transpose transpose; + transpose(dev_ctx, *output, &tmp_trans, {1, 0, 2}); + Tensor splitted_t0; + Tensor splitted_t2; + Tensor splitted_t0_rec; + Tensor splitted_t2_rec; + std::vector concat_in; + if (start > 0) { + splitted_t0 = tmp_trans.Slice(0, start); + auto t0_dims = splitted_t0.dims(); + splitted_t0_rec.mutable_data(framework::make_ddim({t0_dims[1], t0_dims[0], t0_dims[2]}), + context.GetPlace()); + transpose(dev_ctx, splitted_t0, &splitted_t0_rec, {1, 0, 2}); + concat_in.emplace_back(splitted_t0_rec); + } + concat_in.emplace_back(input); + if (end < dims[1]) { + splitted_t2 = tmp_trans.Slice(end, dims[1]); + auto t2_dims = splitted_t2.dims(); + splitted_t2_rec.mutable_data(framework::make_ddim({t2_dims[1], t2_dims[0], t2_dims[2]}), + context.GetPlace()); + transpose(dev_ctx, splitted_t2, &splitted_t2_rec, {1, 0, 2}); + concat_in.emplace_back(splitted_t2_rec); + } + + math::ConcatFunctor concat; + concat(dev_ctx, concat_in, 1, output); +} + +template +inline void ConcatBatchAll(const framework::ExecutionContext& context, + Tensor* output, + std::vector& input) { + // Concat all input tensors in dims[1] + math::ConcatFunctor concat; + auto& dev_ctx = context. 
template device_context(); + concat(dev_ctx, input, 1, output); +} + +template +inline void GRUUnitGradCompute(const framework::ExecutionContext& context, + std::vector& mpc_splitted_gate_t, + std::vector& mpc_splitted_gate_grad_t, + Tensor& mpc_hidden_prev_t, Tensor& mpc_hidden_prev_grad_t, + std::vector& mpc_splitted_weights_t, + std::vector& mpc_splitted_weights_grad_t, + Tensor& mpc_reset_hidden_prev_t, Tensor& mpc_reset_hidden_prev_grad_t, + Tensor& mpc_hidden_grad_t, bool origin_mode, + bool& has_hidden_prev, bool& has_hidden_prev_grad, + bool& has_weight_grad) { + // compute GRUUnitGrad + BackwardStateGrad(context, + mpc_splitted_gate_t, mpc_splitted_gate_grad_t, + mpc_hidden_prev_t, mpc_hidden_prev_grad_t, + mpc_hidden_grad_t, + origin_mode, has_hidden_prev, has_hidden_prev_grad); + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, + "Protocol %s is not yet created in MPC Protocol."); + auto mpc_operator = mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators(); + math::Transpose transpose; + auto& dev_ctx = context. template device_context(); + std::vector trans_axis{0, 2, 1}; + if (has_hidden_prev && has_hidden_prev_grad) { + auto res_hidden_dims = mpc_reset_hidden_prev_grad_t.dims(); + // (B, D) * (D, D)^T + (B, D) : + //reset_hidden_prev_grad = batch_gate_grad[2] * state_weight[2] + reset_hidden_prev_grad + Tensor tmp; + tmp.mutable_data(res_hidden_dims, context.GetPlace()); + + transpose(dev_ctx, mpc_splitted_weights_t[2], &tmp, trans_axis); + mpc_operator->matmul(&mpc_splitted_gate_t[2], &tmp, &tmp); + mpc_operator->add(&mpc_reset_hidden_prev_grad_t, &tmp, &mpc_reset_hidden_prev_grad_t); + + if (has_weight_grad) { + // (B, D)^T * (B, D) + (D, D) + // state_weight_grad[2] = reset_hidden_prev * batch_gate_grad[2] + state_weight_grad[2] + Tensor tmp1, tmp2; + tmp1.mutable_data( + framework::make_ddim( + std::vector({res_hidden_dims[0], res_hidden_dims[2], res_hidden_dims[1]})), + context.GetPlace()); + tmp2.mutable_data(mpc_splitted_weights_t[2].dims(), context.GetPlace()); + transpose(dev_ctx, mpc_reset_hidden_prev_t, &tmp1, trans_axis); + mpc_operator->matmul(&tmp1, &mpc_splitted_gate_grad_t[2], &tmp2); + mpc_operator->add(&mpc_splitted_weights_grad_t[2], &tmp2, &mpc_splitted_weights_grad_t[2]); + } + } + BackwarsResetGrad(context, + mpc_splitted_gate_t, mpc_splitted_gate_grad_t, + mpc_hidden_prev_t, mpc_hidden_prev_grad_t, + mpc_reset_hidden_prev_grad_t, + has_hidden_prev, has_hidden_prev_grad); + if (has_hidden_prev && has_hidden_prev_grad) { + // (B, 2D) * (D, 2D)^T + (B, D) + // hidden_prev_grad = batch_gate_grad * gate_weight + hidden_prev_grad + // block matrix multiplication: A=[block_A1, block_A2], B^T=[block_B1, block_B2] + // A*B = block_A1*block_B1 + block_A2*block_B2 + Tensor tmp1, tmp2; + tmp1.mutable_data(mpc_splitted_weights_t[0].dims(), context.GetPlace()); + tmp2.mutable_data(mpc_hidden_prev_t.dims(), context.GetPlace()); + transpose(dev_ctx, mpc_splitted_weights_t[0], &tmp1, trans_axis); + mpc_operator->matmul(&mpc_splitted_gate_grad_t[0], &tmp1, &tmp2); + mpc_operator->add(&mpc_hidden_prev_grad_t, &tmp2, &mpc_hidden_prev_grad_t); + + transpose(dev_ctx, mpc_splitted_weights_t[1], &tmp1, trans_axis); + mpc_operator->matmul(&mpc_splitted_gate_grad_t[1], &tmp1, &tmp2); + mpc_operator->add(&mpc_hidden_prev_grad_t, &tmp2, &mpc_hidden_prev_grad_t); + + if (has_weight_grad) { + // (B, D)^T * (B, 2D) + (D, 2D) + // gate_weight_grad = hidden_prev * batch_gate_grad + gate_weight_grad + auto hid_dims = mpc_hidden_prev_t.dims(); + Tensor tmp3, tmp4; + 
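+            // tmp3 holds hidden_prev transposed to (2, D, B) so that the share-wise matmul
+            // with each (2, B, D) gate gradient yields a (2, D, D) block in tmp4, which is
+            // then accumulated into mpc_splitted_weights_grad_t[0] and [1].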
tmp3.mutable_data( + framework::make_ddim({hid_dims[0], hid_dims[2], hid_dims[1]}), + context.GetPlace()); + tmp4.mutable_data(mpc_splitted_weights_t[0].dims(), context.GetPlace()); + transpose(dev_ctx, mpc_hidden_prev_t, &tmp3, trans_axis); + mpc_operator->matmul(&tmp3, &mpc_splitted_gate_grad_t[0], &tmp4); + mpc_operator->add(&mpc_splitted_weights_grad_t[0], &tmp4, &mpc_splitted_weights_grad_t[0]); + + mpc_operator->matmul(&tmp3, &mpc_splitted_gate_grad_t[1], &tmp4); + mpc_operator->add(&mpc_splitted_weights_grad_t[1], &tmp4, &mpc_splitted_weights_grad_t[1]); + } + } +} + +template +inline void BackwardStateGrad(const framework::ExecutionContext& context, + std::vector& mpc_splitted_gate_t, + std::vector& mpc_splitted_gate_grad_t, + Tensor& mpc_hidden_prev_t, Tensor& mpc_hidden_prev_grad_t, + Tensor& mpc_hidden_grad_t, + bool origin_mode, bool has_hidden_prev, + bool has_hidden_prev_grad) { + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, + "Protocol %s is not yet created in MPC Protocol."); + auto mpc_operator = mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators(); + math::SetConstant zero; + auto& dev_ctx = context.template device_context(); + if (!has_hidden_prev) { + zero(dev_ctx, &mpc_hidden_prev_t, static_cast(0)); + } + if (!has_hidden_prev_grad) { + zero(dev_ctx, &mpc_hidden_prev_grad_t, static_cast(0)); + } + + if (origin_mode) { + // batch_gate_grad[0] = hidden_grad * (hidden_prev - batch_gate[2]) + mpc_operator->sub(&mpc_hidden_prev_t, &mpc_splitted_gate_t[2], &mpc_splitted_gate_grad_t[0]); + mpc_operator->mul(&mpc_hidden_grad_t, &mpc_splitted_gate_grad_t[0], &mpc_splitted_gate_grad_t[0]); + // hidden_prev_grad += hidden_grad * batch_gate[0] + Tensor tmp; + tmp.mutable_data(mpc_hidden_prev_grad_t.dims(), context.GetPlace()); + mpc_operator->mul(&mpc_hidden_grad_t, &mpc_splitted_gate_t[0], &tmp); + mpc_operator->add(&mpc_hidden_prev_grad_t, &tmp, &mpc_hidden_prev_grad_t); + + // batch_gate_grad[2] = activation(hidden_grad * (1-batch_gate[0]), batch_gate[2]) + // activation = grad_relu (return a * (b > 0.0 ? 
1.0 : 0.0);) + Tensor tmp1; + tmp1.mutable_data(mpc_splitted_gate_grad_t[2].dims(), context.GetPlace()); + mpc_operator->mul(&mpc_hidden_grad_t, &mpc_splitted_gate_t[0], &tmp1); + mpc_operator->sub(&mpc_hidden_grad_t, &tmp1, &tmp1); + mpc_operator->relu_grad(&mpc_splitted_gate_t[2], &tmp1, &mpc_splitted_gate_grad_t[2], 0); + + } else { + // batch_gate_grad[0] = hidden_grad * (batch_gate[2] - hidden_prev) + mpc_operator->sub(&mpc_splitted_gate_t[2], &mpc_hidden_prev_t, &mpc_splitted_gate_grad_t[0]); + mpc_operator->mul(&mpc_hidden_grad_t, &mpc_splitted_gate_grad_t[0], &mpc_splitted_gate_grad_t[0]); + // hidden_prev_grad += hidden_grad * (1 - batch_gate[0]) + Tensor tmp; + tmp.mutable_data(mpc_hidden_prev_grad_t.dims(), context.GetPlace()); + mpc_operator->mul(&mpc_hidden_grad_t, &mpc_splitted_gate_t[0], &tmp); + mpc_operator->sub(&mpc_hidden_grad_t, &tmp, &tmp); + mpc_operator->add(&mpc_hidden_prev_grad_t, &tmp, &mpc_hidden_prev_grad_t); + + // batch_gate_grad[2] = activation(hidden_grad*batch_gate[0], batch_gate[2]) + // activation = grad_relu + Tensor tmp1; + tmp1.mutable_data(mpc_splitted_gate_grad_t[2].dims(), context.GetPlace()); + mpc_operator->mul(&mpc_hidden_grad_t, &mpc_splitted_gate_t[0], &tmp1); + mpc_operator->relu_grad(&mpc_splitted_gate_t[2], &tmp1, &mpc_splitted_gate_grad_t[2], 0); + } +} + +template +inline void BackwarsResetGrad(const framework::ExecutionContext& context, + std::vector& mpc_splitted_gate_t, + std::vector& mpc_splitted_gate_grad_t, + Tensor& mpc_hidden_prev_t, Tensor& mpc_hidden_prev_grad_t, + Tensor& mpc_reset_hidden_prev_grad_t, + bool has_hidden_prev, bool has_hidden_prev_grad) { + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, + "Protocol %s is not yet created in MPC Protocol."); + auto mpc_operator = mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators(); + math::SetConstant zero; + auto& dev_ctx = context.template device_context(); + if (!has_hidden_prev) { + zero(dev_ctx, &mpc_hidden_prev_t, static_cast(0)); + } + if (!has_hidden_prev_grad) { + zero(dev_ctx, &mpc_hidden_prev_grad_t, static_cast(0)); + } + if (!has_hidden_prev || !has_hidden_prev_grad) { + zero(dev_ctx, &mpc_reset_hidden_prev_grad_t, static_cast(0)); + } + // batch_gate_grad[1] = reset_hidden_grad * hidden_prev + mpc_operator->mul(&mpc_reset_hidden_prev_grad_t, &mpc_hidden_prev_t, &mpc_splitted_gate_grad_t[1]); + // hidden_prev_grad += reset_hidden_grad * batch_gate_grad[1] + Tensor tmp; + tmp.mutable_data(mpc_hidden_prev_grad_t.dims(), context.GetPlace()); + mpc_operator->mul(&mpc_reset_hidden_prev_grad_t, &mpc_splitted_gate_grad_t[1], &tmp); + mpc_operator->add(&mpc_hidden_prev_grad_t, &tmp, &mpc_hidden_prev_grad_t); + // batch_gate_grad[0] = sigmoid_grad(batch_gate_grad[0], batch_gate[0]) + ComputeSigmoidGrad(context, mpc_splitted_gate_grad_t[0], + mpc_splitted_gate_t[0], mpc_splitted_gate_grad_t[0]); + // batch_gate_grad[1] = sigmoid_grad(batch_gate_grad[1], batch_gate[1]) + ComputeSigmoidGrad(context, mpc_splitted_gate_grad_t[1], + mpc_splitted_gate_t[1], mpc_splitted_gate_grad_t[1]); +} + +template +inline void ComputeSigmoidGrad(const framework::ExecutionContext& context, + Tensor& dy, Tensor& y, Tensor& dx) { + // dx = dy * (1.0 - y * y); + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, + "Protocol %s is not yet created in MPC Protocol."); + auto mpc_operator = mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators(); + Tensor tmp; + tmp.mutable_data(dx.dims(), context.GetPlace()); + mpc_operator->mul(&y, &y, &tmp); + mpc_operator->mul(&dy, 
&tmp, &tmp); + mpc_operator->sub(&dy, &tmp, &dx); +} + +template +class MpcGRUGradKernel : public MpcOpKernel { +public: + void BatchCompute(const framework::ExecutionContext& context) const { + bool origin_mode = context.Attr("origin_mode"); + auto* h0 = context.Input("H0"); + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* batch_gate = context.Input("BatchGate"); + auto* batch_reset_hidden_prev = + context.Input("BatchResetHiddenPrev"); + auto* batch_hidden = context.Input("BatchHidden"); + auto* hidden = context.Input("Hidden"); + auto* hidden_grad = + context.Input(framework::GradVarName("Hidden")); + auto* input_grad = + context.Output(framework::GradVarName("Input")); + auto* h0_grad = context.Output(framework::GradVarName("H0")); + auto* weight_grad = + context.Output(framework::GradVarName("Weight")); + auto* bias_grad = context.Output(framework::GradVarName("Bias")); + + auto gate_dims = batch_gate->dims(); + auto hidden_dims = hidden->dims(); + auto gate_lod = batch_gate->lod(); + const auto& place = context.GetPlace(); + bool has_hidden_prev = false; + bool has_hidden_prev_grad = false; + bool has_weight_grad = false; + + math::LoDTensor2BatchFunctor to_batch; + LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad; + batch_hidden_grad.mutable_data(hidden_dims, context.GetPlace()); + batch_gate_grad.mutable_data(gate_dims, context.GetPlace()); + batch_reset_hidden_prev_grad.mutable_data(hidden_dims, + context.GetPlace()); + math::SetConstant zero; + auto& dev_ctx = context.template device_context(); + zero(dev_ctx, &batch_hidden_grad, static_cast(0)); + zero(dev_ctx, &batch_gate_grad, static_cast(0)); + zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast(0)); + + Tensor ordered_h0, ordered_h0_grad; + + framework::Vector order(gate_lod[2]); + + if (h0) { + // Reorder mpc h0 + ordered_h0.mutable_data(h0->dims(), place); + for (int i = 0; i < 2; ++i) { + Tensor h0_s; + SliceAndReshape(h0, h0_s, i); + Tensor ordered_h0_s; + SliceAndReshape(&ordered_h0, ordered_h0_s, i); + ReorderInitState(dev_ctx, h0_s, order, &ordered_h0_s, + true); + } + } + if (h0_grad) { + ordered_h0_grad.mutable_data(h0_grad->dims(), context.GetPlace()); + zero(context.template device_context(), &ordered_h0_grad, + static_cast(0)); + } + + bool is_reverse = context.Attr("is_reverse"); + for (int i = 0; i < 2; ++i) { + // mpc LoDTensor to mpc batch + Tensor batch_hidden_grad_s; + SliceAndReshape(&batch_hidden_grad, batch_hidden_grad_s, i); + Tensor hidden_grad_s; + SliceAndReshape(hidden_grad, hidden_grad_s, i); + LoDTensor lod_batch_hidden_grad_s; + LoDTensor lod_hidden_grad_s; + lod_batch_hidden_grad_s.ShareBufferWith(batch_hidden_grad_s); + lod_batch_hidden_grad_s.mutable_data(batch_hidden_grad_s.dims(), place); + lod_hidden_grad_s.ShareBufferWith(hidden_grad_s); + lod_hidden_grad_s.mutable_data(hidden_grad_s.dims(), place); + lod_hidden_grad_s.set_lod(gate_lod); + lod_batch_hidden_grad_s.set_lod(gate_lod); + to_batch(dev_ctx, lod_hidden_grad_s, &lod_batch_hidden_grad_s, false, is_reverse); + } + if (weight_grad) { + T* gate_weight_grad = + weight_grad->mutable_data(context.GetPlace()); + zero(dev_ctx, weight_grad, static_cast(0)); + has_weight_grad = true; + } + // split weights + std::vector mpc_splitted_weights_t; + SplitWeight(context, mpc_splitted_weights_t, *weight); + + auto batch_starts = gate_lod[0]; + size_t num_batch = batch_starts.size() - 1; + for (int n = static_cast(num_batch) - 1; n >= 0; n--) { + int bstart = 
static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + int bstart_pre = static_cast(batch_starts[n - 1]); + + // Split mpc tensors + Tensor mpc_hidden_grad_t; + Tensor mpc_hidden_prev_t; + Tensor mpc_hidden_prev_grad_t; + Tensor mpc_reset_hidden_prev_t; + Tensor mpc_reset_hidden_prev_grad_t; + std::vector splitted_batch_gate_t; + std::vector mpc_splitted_gate_t; + std::vector splitted_batch_gate_grad_t; + std::vector mpc_splitted_gate_grad_t; + std::vector mpc_splitted_weights_grad_t; + + if (weight_grad) { + SplitWeight(context, mpc_splitted_weights_grad_t, *weight_grad); + } + ToMpcBatchTensor(context, mpc_hidden_grad_t, batch_hidden_grad, bstart, bend); + ToMpcBatchTensor(context, mpc_reset_hidden_prev_t, *batch_reset_hidden_prev, bstart, bend); + ToMpcBatchTensor(context, mpc_reset_hidden_prev_grad_t, + batch_reset_hidden_prev_grad, bstart, bend); + + Split3Dim(context, splitted_batch_gate_grad_t, batch_gate_grad); + Split3Dim(context, splitted_batch_gate_t, *batch_gate); + for (int i = 0; i < 3; ++i) { + ToMpcBatchTensor(context, mpc_splitted_gate_grad_t[i], + splitted_batch_gate_grad_t[i], bstart, bend); + ToMpcBatchTensor(context, mpc_splitted_gate_t[i], + splitted_batch_gate_t[i], bstart, bend); + } + if (n == 0) { + if (h0) { + // hidden_prev_t = ordered_h0 + mpc_hidden_prev_t.mutable_data( + ordered_h0.dims(), place); + framework::TensorCopy(ordered_h0, place, &mpc_hidden_prev_t); + has_hidden_prev = true; + if (h0_grad) { + // hidden_prev_grad_t = ordered_h0_grad + mpc_hidden_prev_grad_t.mutable_data( + ordered_h0_grad.dims(), place); + framework::TensorCopy(ordered_h0_grad, place, &mpc_hidden_prev_grad_t); + has_hidden_prev_grad = true; + } + } + } else { + ToMpcBatchTensor(context, mpc_hidden_prev_t, *batch_hidden, bstart_pre, bstart); + ToMpcBatchTensor(context, mpc_hidden_prev_grad_t, batch_hidden_grad, bstart_pre, bstart); + + } + // compute GRUUnitGrad + GRUUnitGradCompute(context, + mpc_splitted_gate_t, mpc_splitted_gate_grad_t, + mpc_hidden_prev_t, mpc_hidden_prev_grad_t, + mpc_splitted_weights_t, mpc_splitted_weights_grad_t, + mpc_reset_hidden_prev_t, mpc_reset_hidden_prev_grad_t, + mpc_hidden_grad_t, origin_mode, has_hidden_prev, + has_hidden_prev_grad, has_weight_grad); + // cancat mpc tensor to gru_grad output variables + if (weight_grad) { + ConcatWeight(context, weight_grad, mpc_splitted_weights_grad_t); + } + Tensor mpc_batch_gate_grad_t; + Concat3Dim(context, &mpc_batch_gate_grad_t, mpc_splitted_gate_grad_t); + ConcatBatchOne(context, &batch_gate_grad, mpc_batch_gate_grad_t, bstart, bend); + ConcatBatchOne(context, &batch_hidden_grad, mpc_hidden_prev_grad_t, bstart_pre, bstart); + ConcatBatchOne(context, &batch_reset_hidden_prev_grad, mpc_reset_hidden_prev_grad_t, bstart, bend); + } + if (input_grad) { + // batch to lodTensor for mpc input_grad + input_grad->mutable_data(context.GetPlace()); + math::Batch2LoDTensorFunctor to_seq; + batch_gate_grad.set_lod(gate_lod); + for (int i = 0; i < 2; ++i) { + Tensor batch_gate_grad_s; + SliceAndReshape(&batch_gate_grad, batch_gate_grad_s, i); + Tensor input_grad_s; + SliceAndReshape(input_grad, input_grad_s, i); + + LoDTensor lod_batch_gate_grad_s; + LoDTensor lod_input_grad_s; + lod_batch_gate_grad_s.ShareBufferWith(batch_gate_grad_s); + lod_batch_gate_grad_s.mutable_data(batch_gate_grad_s.dims(), place); + lod_batch_gate_grad_s.set_lod(gate_lod); + lod_input_grad_s.ShareBufferWith(input_grad_s); + lod_input_grad_s.mutable_data(input_grad_s.dims(), place); + 
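// to_seq (the math::Batch2LoDTensorFunctor declared above) scatters the
+ // batched gate-gradient rows back into their original sequence order for
+ // this share; because lod_input_grad_s shares its buffer with the i-th
+ // slice of input_grad, the reordered result lands directly in Input@GRAD. +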
to_seq(dev_ctx, lod_batch_gate_grad_s, &lod_input_grad_s); + } + } + if (bias_grad) { + // col_sum mpc bias_grad + bias_grad->mutable_data(context.GetPlace()); + math::ColwiseSum col_sum; + for (int i = 0; i < 2; ++i) { + Tensor batch_gate_grad_s; + SliceAndReshape(&batch_gate_grad, batch_gate_grad_s, i); + Tensor bias_grad_s; + SliceAndReshape(bias_grad, bias_grad_s, i); + col_sum(dev_ctx, batch_gate_grad_s, &bias_grad_s); + } + } + if (h0 && h0_grad) { + // Reorder mpc h0_grad + for (int i = 0; i < 2; ++i) { + Tensor ordered_h0_grad_s; + SliceAndReshape(&ordered_h0_grad, ordered_h0_grad_s, i); + Tensor h0_grad_s; + SliceAndReshape(h0_grad, h0_grad_s, i); + ReorderInitState(dev_ctx, ordered_h0_grad_s, order, + &h0_grad_s, false); + } + } + } + + void ComputeImpl(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + +} // namespace operators +} // namespace paddle + + diff --git a/core/paddlefl_mpc/operators/mpc_lookup_table_v2_op.cc b/core/paddlefl_mpc/operators/mpc_lookup_table_v2_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bada784621b94da335aa41facafacd7edb20e98 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_lookup_table_v2_op.cc @@ -0,0 +1,218 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/no_need_buffer_vars_inference.h" +#include "paddle/fluid/framework/var_type_inference.h" +#include "paddle/fluid/framework/op_registry.h" + +#include "mpc_lookup_table_v2_op.h" + +namespace paddle { +namespace operators { + +class MpcLookupTableV2Op : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, + "Input(W) of LookupTableV2Op should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true, + "Input(Ids) of LookupTableV2Op should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of LookupTableV2Op should not be null."); + + auto table_dims = ctx->GetInputDim("W"); + auto ids_dims = ctx->GetInputDim("Ids"); + int ids_rank = ids_dims.size(); + VLOG(5) << "ids rank is " << ids_rank << std::endl; + PADDLE_ENFORCE_EQ( + table_dims.size(), 3, + "ShapeError: The dimensions of the 'mpc lookup table' must be 3. " + "But received lookup table's dimensions = %d, " + "lookup table's shape = [%s].", + table_dims.size(), table_dims); + PADDLE_ENFORCE_EQ( + ids_dims.size(), 3, + "ShapeError: The dimensions of the 'idexes' must be 3, " + "Other dimensions are not supported temporarily. " + "Received idexes' dimensions = %d, " + "idexes's shape = [%s].", + table_dims.size(), table_dims); + PADDLE_ENFORCE_EQ( + table_dims[0], 2, + "ShapeError: The first dimensions of the 'mpc lookup table' must be 2. 
" + "But received lookup table's first dimensions = %d.", + table_dims[0]); + PADDLE_ENFORCE_EQ( + ids_dims[0], 2, + "ShapeError: The first dimensions of the 'indexes' must be 2. " + "But received indexes' first dimensions = %d.", + ids_dims[0]); + + auto output_dims = framework::vectorize(ids_dims); + output_dims[output_dims.size() - 1] = table_dims[2]; + auto out_dims = framework::make_ddim(output_dims); + ctx->SetOutputDim("Out", out_dims); + + if (ctx->GetOutputsVarType("Out")[0] == + framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("Ids", /*->*/ "Out"); + } + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class MpcLookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("W", + "(Tensor) The input represents embedding tensors, " + "which is a learnable parameter."); + AddInput("Ids", + "An input with type int64 " + "contains the ids to be looked up in W."); + AddOutput("Out", "The lookup results, which have the same type as W."); + AddAttr("is_sparse", + "(boolean, default false) " + "Sparse update.") + .SetDefault(false); + AddAttr("is_distributed", + "(boolean, default false) distributed lookup table.") + .SetDefault(false); + AddAttr("padding_idx", + "(int64, default -1) " + "If the value is -1, it makes no effect to lookup. " + "Otherwise the given value indicates padding the output " + "with zeros whenever lookup encounters it in Ids.") + .SetDefault(kNoPadding); + + // for parameter prefetch + AddAttr("remote_prefetch", "").SetDefault(false); + AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); + AddAttr>("height_sections", + "Height for each output SelectedRows.") + .SetDefault(std::vector({})); + AddAttr>( + "epmap", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints in the order of input variables for mapping") + .SetDefault({}); + AddAttr>( + "table_names", + "(string vector, the splited table names that will be fetched from " + "parameter server)" + "in the order of input variables for mapping") + .SetDefault({}); + + AddComment(R"DOC( +Lookup Table V2 Operator. + +This operator is used to perform lookups on the parameter W, +then concatenated into a dense tensor. + +The input Ids can carry the LoD (Level of Details) information, +or not. And the output only shares the LoD information with input Ids. 
+ +)DOC"); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER(MpcLookupTableV2GradOpNoBuffer, "W"); + +template +class MpcLookupTableV2GradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("mpc_lookup_table_v2_grad"); + + op->SetInput("W", this->Input("W")); + op->SetInput("Ids", this->Input("Ids")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + + op->SetOutput(framework::GradVarName("W"), this->InputGrad("W")); + + op->SetAttrMap(this->Attrs()); + } +}; + +class MpcLookupTableV2OpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto table_dims = ctx->GetInputDim("W"); + ctx->SetOutputDim(framework::GradVarName("W"), table_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class MpcLookupTableV2OpGradVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext* ctx) const override { + auto out_var_name = framework::GradVarName("W"); + auto attr = ctx->GetAttr("is_sparse"); + bool is_sparse = boost::get(attr); + if (is_sparse) { + VLOG(3) << "mpc_lookup_table_v2_grad op " << framework::GradVarName("W") + << " is set to SelectedRows"; + ctx->SetOutputType(out_var_name, + framework::proto::VarType::SELECTED_ROWS); + } else { + VLOG(3) << "mpc_lookup_table_v2_grad op " << framework::GradVarName("W") + << " is set to LoDTensor"; + ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR); + } + ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(mpc_lookup_table_v2, ops::MpcLookupTableV2Op, + ops::MpcLookupTableV2OpMaker, + ops::MpcLookupTableV2GradOpMaker, + ops::MpcLookupTableV2GradOpMaker); + +REGISTER_OPERATOR(mpc_lookup_table_v2_grad, ops::MpcLookupTableV2OpGrad, + ops::MpcLookupTableV2GradOpNoBuffer, + ops::MpcLookupTableV2OpGradVarTypeInference); + +REGISTER_OP_CPU_KERNEL(mpc_lookup_table_v2, ops::MpcLookupTableV2Kernel); +REGISTER_OP_CPU_KERNEL(mpc_lookup_table_v2_grad, + ops::MpcLookupTableV2GradKernel); + diff --git a/core/paddlefl_mpc/operators/mpc_lookup_table_v2_op.h b/core/paddlefl_mpc/operators/mpc_lookup_table_v2_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0ef48c410fa33a3e84c8f356609f4daf79b62e40 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_lookup_table_v2_op.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "mpc_op.h" +#include +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/platform/device_context.h" +#include "core/paddlefl_mpc/operators/math/math_function_impl.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; +using DDim = framework::DDim; + +constexpr int64_t kNoPadding = -1; + +template +class MpcLookupTableV2Kernel : public MpcOpKernel { +public: + void ComputeImpl(const framework::ExecutionContext &context) const override { + auto *ids_t = context.Input("Ids"); // int tensor + auto *output_t = context.Output("Out"); // float tensor + auto *table_var = context.Input("W"); + auto *ids = ids_t->data(); + auto *table = table_var->data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, + "Protocol %s is not yet created in MPC Protocol."); + mpc::MpcInstance::mpc_instance()->mpc_protocol()-> + mpc_operators()->matmul(ids_t, table_var, output_t); + } +}; + +template +class MpcLookupTableV2GradKernel : public MpcOpKernel { +public: + void ComputeImpl(const framework::ExecutionContext &context) const override { + auto *ids_t = context.Input("Ids"); + auto id_dim = ids_t->dims(); + auto col_width = id_dim[1]; + auto row_width = id_dim[2]; + auto *d_output_t = context.Input(framework::GradVarName("Out")); + auto *d_table_t = context.Output(framework::GradVarName("W")); + + // transpose ids_t + auto *ids = ids_t->data(); + auto *table = d_table_t->mutable_data(context.GetPlace()); + auto *output = d_output_t->data(); + + Tensor ids_trans_t; + auto *ids_trans = ids_trans_t.mutable_data({2, row_width, col_width}, context.GetPlace()); + + math::Transpose transpose; + auto& dev_ctx = context. template device_context(); + transpose(dev_ctx, *ids_t, &ids_trans_t, {0, 2, 1}); + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol."); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(&ids_trans_t, d_output_t, d_table_t); + } +}; + +} // namespace operators +} // namespace paddle + diff --git a/core/paddlefl_mpc/operators/mpc_mul_op.h b/core/paddlefl_mpc/operators/mpc_mul_op.h index 67a8b065ada96eb316d10f4114c125005b812e50..4018330742a43419cd705b4a14217cf920a37d28 100644 --- a/core/paddlefl_mpc/operators/mpc_mul_op.h +++ b/core/paddlefl_mpc/operators/mpc_mul_op.h @@ -150,6 +150,7 @@ public: if (dx) { dx->mutable_data(ctx.GetPlace()); + auto dx_dim = dx->dims(); if (dx->dims().size() > 3) { dx->Resize({2, x_mat_width, x_mat_height}); } @@ -160,7 +161,6 @@ public: // dx = dout * y'. dx: M x K, dout : M x N, y : K x N mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( &dout_matrix, &y_matrix_trans, dx); - auto dx_dim = dx->dims(); if (dx_dim.size() > 3) { dx->Resize(dx_dim); } @@ -168,6 +168,7 @@ public: if (dy) { dy->mutable_data(ctx.GetPlace()); + auto dy_dim = dy->dims(); if (dy->dims().size() > 3) { dy->Resize({2, y_mat_width, y_mat_height}); } @@ -179,7 +180,6 @@ public: // dy = x' * dout. 
dy K x N, dout : M x N, x : M x K mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( &x_matrix_trans, &dout_matrix, dy); - auto dy_dim = dy->dims(); if (dy_dim.size() > 3) { dy->Resize(dy_dim); } diff --git a/core/paddlefl_mpc/operators/mpc_pool_op.cc b/core/paddlefl_mpc/operators/mpc_pool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b25cda3db9d09fb2d426f5606a81e286af98357 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_pool_op.cc @@ -0,0 +1,277 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include +#include "mpc_pool_op.h" + +namespace paddle { +namespace operators { + +int PoolOutputSize(int input_size, int filter_size, int padding_1, + int padding_2, int stride, bool ceil_mode) { + int output_size; + if (!ceil_mode) { + output_size = (input_size - filter_size + padding_1 + padding_2) / stride + 1; + } else { + output_size = (input_size - filter_size + padding_1 + padding_2 + stride - 1) / stride + 1; + } + PADDLE_ENFORCE_GT( + output_size, 0, + "ShapeError: the output size must be greater than 0. But received: " + "output_size = %d due to the settings of input_size(%d), padding(%d,%d), " + "k_size(%d) and stride(%d). Please check again!", + output_size, input_size, padding_1, padding_2, filter_size, stride); + return output_size; +} + + +class MpcPoolOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override{ + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "X(Input) of Pooling should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Out(Output) of Pooling should not be null."); + + std::string pooling_type = ctx->Attrs().Get("pooling_type"); + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + bool ceil_mode = ctx->Attrs().Get("ceil_mode"); + // bool adaptive = ctx->Attrs().Get("adaptive"); + bool global_pooling = ctx->Attrs().Get("global_pooling"); + std::string data_format = ctx->Attrs().Get("data_format"); + std::string padding_algorithm = ctx->Attrs().Get("padding_algorithm"); + + auto in_x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(in_x_dims.size(), 5, + "ShapeError: the input of Op(pool) should be 5-D Tensor (ciphertext). " + "But received: %u-D Tensor and it's shape is [%s].", + in_x_dims.size(), in_x_dims); + + PADDLE_ENFORCE_EQ(in_x_dims.size() - ksize.size(), 3U, + "ShapeError: the dimension of input(ciphertext) minus the size of " + "Attr(ksize)(plaintext) must be euqal to 3 in Op(pool). 
" + "But received: the dimension of input minus the size " + "of Attr(ksize) is %d, the " + "input's dimension is %d, the shape of input " + "is [%s], the Attr(ksize)'s size is %d, the Attr(ksize) is [%s].", + in_x_dims.size() - ksize.size(), in_x_dims.size(), in_x_dims, + ksize.size(), framework::make_ddim(ksize)); + + PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), + "ShapeError: the size of Attr(ksize) and Attr(strides) in " + "Op(pool) must be equal. " + "But received: Attr(ksize)'s size is %d, Attr(strides)'s " + "size is %d, Attr(ksize) is [%s], Attr(strides)is [%s].", + ksize.size(), strides.size(), framework::make_ddim(ksize), + framework::make_ddim(strides)); + + PADDLE_ENFORCE_EQ(data_format, "NCHW", + "data format can only be 'NCHW' ", + in_x_dims.size(), in_x_dims); + + // update paddings if "SAME" or global_pooling + framework::DDim data_dims; + data_dims = framework::slice_ddim(in_x_dims, 3, in_x_dims.size()); + UpdatePadding(&paddings, global_pooling, padding_algorithm, + data_dims, strides, ksize); + + if (global_pooling) { + UpdateKsize(&ksize, data_dims); + } + + std::vector output_shape; + std::vector one_hot_tensor_shape; + for (int i = 0; i < data_dims.size(); ++i) { + if ((!ctx->IsRuntime()) && (data_dims[i] < 0)) { + output_shape.push_back(data_dims[i]); + } else { + output_shape.push_back( + PoolOutputSize(data_dims[i], ksize[i], paddings[2 * i], + paddings[2 * i + 1], strides[i], ceil_mode)); + } + } + + output_shape.insert(output_shape.begin(), in_x_dims[0]); // share size + output_shape.insert(output_shape.begin() + 1, in_x_dims[1]); // output_N = input_N + output_shape.insert(output_shape.begin() + 2, in_x_dims[2]); // output_C = input_C + + one_hot_tensor_shape.push_back(in_x_dims[0]); // share size + one_hot_tensor_shape.push_back(in_x_dims[1]); // input_N + one_hot_tensor_shape.push_back(in_x_dims[2]); // input_C + one_hot_tensor_shape.push_back(ksize[0] * ksize[1]); + one_hot_tensor_shape.push_back(output_shape[3] * output_shape[4]); + + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->ShareLoD("X", "Out"); + + ctx->SetOutputDim("One_hot_tensor", framework::make_ddim(one_hot_tensor_shape)); + ctx->ShareLoD("X", "One_hot_tensor"); + } + +protected: + framework::OpKernelType GetExpectedKernelType(const framework::ExecutionContext& ctx) const { + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = "AnyLayout"; + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + + +class MpcPoolOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override{ + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + +protected: + framework::OpKernelType GetExpectedKernelType(const framework::ExecutionContext& ctx) const { + framework::LibraryType 
library_{framework::LibraryType::kPlain}; + std::string data_format = "AnyLayout"; + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, library_); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + + +class MpcPool2dOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override{ + AddInput("X", + "(Tensor) The input tensor of pooling operator. " + "The format of input tensor is NCHW, where N is batch size, C is the " + "number of channels, H is the height of the feature, " + "and W is the width of the feature."); + AddOutput("Out", + "(Tensor) The output tensor of pooling operator. " + "The format of output tensor is also NCHW, " + "where N is batch size, C is the number of channels, " + "H is the height of the feature, " + "and W is the width of the feature."); + AddOutput("One_hot_tensor", + "one hot tensor"); + AddAttr("pooling_type", + "(string), pooling type, can be \"max\" for max-pooling " + "and \"avg\" for average-pooling.") + .InEnum({"max", "avg"}); + AddAttr>("ksize", + "(vector) The pooling window " + "size(height, width) of the pooling operator. " + "If global_pooling = true, ksize and paddings will " + "be ignored."); + AddAttr("global_pooling", + "(bool) Whether to use the global pooling. " + "If global_pooling = true, kernel size and paddings will be ignored. " + "Default False.") + .SetDefault(false); + AddAttr>("strides", + "(vector, default {1, 1}), strides(height, " + "width) of pooling operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", + "(vector, default {0,0}), paddings(height_top, height_bottom, " + "width_left, wifth_right) of pooling operator." + "If global_pooling = true, paddings and kernel size will be ignored.") + .SetDefault({0, 0}); + AddAttr("exclusive", + "(bool) When true, will exclude the zero-padding in the " + "averaging calculating, otherwise, include the zero-padding. Note, it " + "is only used when pooling_type is avg. The default is True. " + "Default True.") + .SetDefault(true); + AddAttr("ceil_mode", + "(bool) Whether to use the ceil function to calculate " + "output height and width. False is the default. If it is set to False, " + "the floor function will be used. Default False") + .SetDefault(false); + AddAttr("data_format", + "(string, default NCHW) Only used in " + "An optional string from: \"NHWC\", \"NCHW\". " + "Defaults to \"NHWC\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault("NCHW"); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("padding_algorithm", + "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\"," + "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. " + "Set to \"SAME\" or \"VALID\" for algorithm of padding. ") + .SetDefault("EXPLICIT"); + AddComment(R"DOC( +This operation calculates the pooling output based on +the input, pooling_type and pool_size, pool_stride, pool_padding parameters. 
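+The spatial output size follows the rule implemented in PoolOutputSize
+(integer division):
+    ceil_mode = false:  out_size = (in_size - ksize + pad_1 + pad_2) / stride + 1
+    ceil_mode = true:   out_size = (in_size - ksize + pad_1 + pad_2 + stride - 1) / stride + 1
+For example, in_size = 32, ksize = 2, paddings = {0, 0} and stride = 2 give out_size = 16.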
+Input(X) and Output(Out) are in NCHW or NHWC format, where N is batch size, C is the +number of channels, H is the height of the feature, and W is the width of the feature. +Parameters(pool_size, pool_stride, pool_padding) hold two integer elements. +These two elements represent height and width, respectively. +The input(X) size and output(Out) size may be different. +)DOC"); + } +}; + +class MpcPoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { +protected: + std::unordered_map& GetInputOutputWithSameType() const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } +}; + + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR( + mpc_pool2d, ops::MpcPoolOp, ops::MpcPool2dOpMaker, ops::MpcPoolOpInferVarType, + paddle::framework::DefaultGradOpMaker, + paddle::framework::DefaultGradOpMaker); +REGISTER_OPERATOR(mpc_pool2d_grad, ops::MpcPoolOpGrad); + +REGISTER_OP_CPU_KERNEL( + mpc_pool2d, ops::MpcPoolKernel); +REGISTER_OP_CPU_KERNEL( + mpc_pool2d_grad, ops::MpcPoolGradKernel); diff --git a/core/paddlefl_mpc/operators/mpc_pool_op.h b/core/paddlefl_mpc/operators/mpc_pool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..87dac6233e051aebc9d4fff9bbe6758a2579b0e3 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_pool_op.h @@ -0,0 +1,381 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include "mpc_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +template +inline void UpdatePadding(std::vector* paddings, const bool global_pooling, + const std::string& padding_algorithm, + const framework::DDim data_dims, + const std::vector& strides, + const std::vector& ksize) { + // set padding size == data_dims.size() * 2 + auto data_shape = framework::vectorize(data_dims); + if (static_cast(paddings->size()) == data_dims.size()) { + for (int i = 0; i < data_dims.size(); ++i) { + T copy_pad = *(paddings->begin() + 2 * i); + paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); + } + } else { + PADDLE_ENFORCE_EQ(data_dims.size() * 2, paddings->size(), + "Paddings size should be the same or twice as the pooling size."); + } + + // when padding_algorithm is "VALID" or "SAME" + if (padding_algorithm == "SAME") { + for (int i = 0; i < data_dims.size(); ++i) { + T out_size = (data_dims[i] + strides[i] - 1) / strides[i]; + T pad_sum = std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], static_cast(0)); + T pad_0 = pad_sum / 2; + T pad_1 = pad_sum - pad_0; + *(paddings->begin() + i * 2) = pad_0; + *(paddings->begin() + i * 2 + 1) = pad_1; + } + } else if (padding_algorithm == "VALID") { + for (auto it = paddings->begin(); it != paddings->end(); it++) { + *it = 0; + } + } + + // if global_pooling == true, padding will be ignore + if (global_pooling) { + for (auto it = paddings->begin(); it != paddings->end(); it++) { + *it = 0; + } + } +} + +template +inline void UpdateKsize(std::vector* ksize, + const framework::DDim data_dims) { + ksize->resize(static_cast(data_dims.size())); + for (size_t i = 0; i < ksize->size(); ++i) { + *(ksize->begin() + i) = static_cast(data_dims[i]); + } +} + +template +void VisitDataStrideWise(DDim in_dims, DDim out_dims, + std::vector& ksize, std::vector& strides, std::vector& paddings, + const T* src, T* target, int src_stride, int target_stride, Func visitor) { + + const int share_size = in_dims[0]; + const int batch_size = in_dims[1]; + const int channel_size = in_dims[2]; + const int input_height = in_dims[3]; + const int input_width = in_dims[4]; + const int out_height = out_dims[3]; + const int out_width = out_dims[4]; + const int out_mat_numel = out_height * out_width; + + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int filter_numel = ksize_height * ksize_width; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + int hstart, hend; + int wstart, wend; + + int idx = 0; + while (idx++ < batch_size * channel_size) { + for (size_t ph = 0; ph < out_height; ++ph) { + hstart = ph * stride_height - padding_height; + hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + + for (size_t pw = 0; pw < out_width; ++pw) { + wstart = pw * stride_width - padding_width; + wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + visitor(ph, pw, input_height, input_width, out_height, out_width, hstart, hend, + wstart, wend, src, target); + } + } + src += src_stride; + target += target_stride; + } +} + + +template +class MpcPoolKernel : public MpcOpKernel { +public: + void ComputeImpl(const framework::ExecutionContext &context) const override { + + const Tensor* in_x = context.Input("X"); + Tensor* out = context.Output("Out"); + 
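// High-level flow of this kernel: the ciphertext input (S, B, C, H, W) is
+ // expanded im2col-style into (S, B, C, ksize_h * ksize_w, H_out * W_out),
+ // transposed to (S, ksize_h * ksize_w, B, C, H_out * W_out) so the pooling
+ // window follows the share dimension, and fed to the MPC max_pooling
+ // primitive, which returns both the per-window maxima and a one-hot
+ // selection tensor. The maxima are transposed back into Out, and the
+ // selection tensor is stored in One_hot_tensor so that MpcPoolGradKernel
+ // can route Out@GRAD back to the selected positions via arith_bool_mul. +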
Tensor* out_one_hot_tensor = context.Output("One_hot_tensor"); + + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + std::string data_format = context.Attr("data_format"); // NCHW + bool global_pooling = context.Attr("global_pooling"); + std::string padding_algorithm = + context.Attr("padding_algorithm"); + + const T* in_x_data = in_x->data(); + T* output_data = out->mutable_data(context.GetPlace()); + T* one_hot_tensor_data = out_one_hot_tensor->mutable_data(context.GetPlace()); + + // update paddings + auto in_x_dims = in_x->dims(); + auto out_dims = out->dims(); + + const int input_stride = in_x_dims[3] * in_x_dims[4]; + const int output_stride = out_dims[3] * out_dims[4]; + const int one_hot_tensor_stride = ksize[0] * ksize[1] * out_dims[3] * out_dims[4]; + + // create temp tensor + auto& dev_ctx = context.template device_context(); + Tensor input2col = context.AllocateTmpTensor(out_one_hot_tensor->dims(), dev_ctx); + T* input2col_data = input2col.data(); + std::fill(input2col_data, input2col_data + input2col.numel(), static_cast(0)); + + framework::DDim data_dims; + data_dims = framework::slice_ddim(in_x_dims, 3, in_x_dims.size()); + + // update padding => h, w + UpdatePadding(&paddings, global_pooling, padding_algorithm, + data_dims, strides, ksize); + if (data_dims.size() * 2 == static_cast(paddings.size())) { + for (int i = 0; i < data_dims.size(); ++i) { + paddings.erase(paddings.begin() + i + 1); + } + } + + if (global_pooling) { + UpdateKsize(&ksize, data_dims); + } + + // share0, share1 + const int input_plaintext_size = in_x->numel() / 2; + const int input2col_plaintext_size = out_one_hot_tensor->numel() / 2; + + // im2col + auto get_im2col = [=] (int ph, int pw, int input_height, int input_width, int out_height, int out_width, + int hstart, int hend, int wstart, int wend, const T* src, T* target) { + + size_t out_index = ph * out_width + pw; + size_t offset = out_height * out_width; + size_t index = 0; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + target[out_index + index * offset] = src[h * input_width + w]; // share0 + target[out_index + index * offset + input2col_plaintext_size] = + src[h * input_width + w + input_plaintext_size]; // share1 + ++index; + } + } + }; + + // input2col + // convert in_x_data (S, B, C, H, W) into (S, B, C, filter_size * filter_size, H_output * W_output) + VisitDataStrideWise(in_x_dims, out_dims, ksize, strides, paddings, in_x_data, input2col_data, input_stride, one_hot_tensor_stride, get_im2col); + + const T* input2col_data2 = input2col.data(); + + // maxpooling(input2col_trans), return(max2col, out_one_hot_tensor_trans) + // input2col_trans: (S, filter_size * filter_size, B, C, H_output * W_output) + // max2col: (S, , B, C, H_output * W_output) + // out_one_hot_tensor_trans: (S, filter_size * filter_size, B, C, H_output * W_output) + Tensor input2col_trans; + DDim in2col_dims = input2col.dims(); + T* input2col_trans_data = input2col_trans.mutable_data(in2col_dims, context.GetPlace()); + input2col_trans.Resize({in2col_dims[0], in2col_dims[3], in2col_dims[1], in2col_dims[2], in2col_dims[4]}); + + Tensor max2col; + max2col.ShareDataWith(*out); + max2col.Resize({in2col_dims[0], 1, in2col_dims[1], in2col_dims[2], in2col_dims[4]}); + + Tensor out_one_hot_tensor_trans; + out_one_hot_tensor_trans.mutable_data(out_one_hot_tensor->dims(), context.GetPlace()); + out_one_hot_tensor_trans.Resize({in2col_dims[0], 
in2col_dims[3], in2col_dims[1], in2col_dims[2], in2col_dims[4]}); + + // convert input2col (S, B, C, filter_size * filter_size, H_output * W_output) + // into input2col_trans (S, filter_size * filter_size, B, C, H_output * W_output) + const int Rank = 5; + Eigen::array permute; + permute = {0, 3, 1, 2, 4}; + + auto eigen_in = framework::EigenTensor::From(input2col); + auto eigen_out = framework::EigenTensor::From(input2col_trans); + auto* dev = dev_ctx.eigen_device(); + eigen_out.device(*dev) = eigen_in.shuffle(permute); + + // maxpooling + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->max_pooling( + &input2col_trans, &max2col, &out_one_hot_tensor_trans); + + permute = {0, 2, 3, 1, 4}; + + // convert out_one_hot_tensor_trans: (S, filter_size * filter_size, B, C, H_output * W_output) + // into out_one_hot_tensor (S, B, C, filter_size * filter_size, H_output * W_output) + auto eigen_in2 = framework::EigenTensor::From(out_one_hot_tensor_trans); + auto eigen_out2 = framework::EigenTensor::From(*out_one_hot_tensor); + eigen_out2.device(*dev) = eigen_in2.shuffle(permute); + + // convert max2col: (S, 1, B, C, H_output * W_output) + // into out_one_hot_tensor (S, B, C, 1, H_output * W_output) + auto eigen_in3 = framework::EigenTensor::From(max2col); + + // flatten height & width + auto flatten_out_dims = out_dims; + flatten_out_dims[3] = 1; + flatten_out_dims[4] = out_dims[3] * out_dims[4]; + out->Resize(flatten_out_dims); + + auto eigen_out3 = framework::EigenTensor::From(*out); + eigen_out3.device(*dev) = eigen_in3.shuffle(permute); + + // reshape out (S, 1, B, C, H_output * W_output) + // into (S, B, C, H_output * W_output) + out->Resize(out_dims); + } +}; + + +template +class MpcPoolGradKernel : public MpcOpKernel { +public: + void ComputeImpl(const framework::ExecutionContext &context) const override { + + const Tensor* one_hot_tensor = context.Input("One_hot_tensor"); + const Tensor* out_grad = context.Input(framework::GradVarName("Out")); + Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + std::string data_format = context.Attr("data_format"); // NCHW + bool global_pooling = context.Attr("global_pooling"); + std::string padding_algorithm = + context.Attr("padding_algorithm"); + + if (in_x_grad) { + // update padding => h, w + auto in_x_dims = in_x_grad->dims(); + auto out_dims = out_grad->dims(); + framework::DDim data_dims; + data_dims = framework::slice_ddim(in_x_dims, 3, in_x_dims.size()); + + UpdatePadding(&paddings, global_pooling, padding_algorithm, + data_dims, strides, ksize); + if (data_dims.size() * 2 == static_cast(paddings.size())) { + for (int i = 0; i < data_dims.size(); ++i) { + paddings.erase(paddings.begin() + i + 1); + } + } + + if (global_pooling) { + UpdateKsize(&ksize, data_dims); + } + + // create temp tensor + auto& dev_ctx = context.template device_context(); + Tensor expanded_out_grad_tensor = + context.AllocateTmpTensor(one_hot_tensor->dims(), dev_ctx); + Tensor mul_result_tensor = + context.AllocateTmpTensor(one_hot_tensor->dims(), dev_ctx); + + // create data var of input and output variable + T* in_x_grad_data = in_x_grad->mutable_data(context.GetPlace()); + std::fill(in_x_grad_data, in_x_grad_data + in_x_grad->numel(), static_cast(0)); + const T* one_hot_tensor_data = one_hot_tensor->data(); + const T* out_grad_data = out_grad->data(); + T* expanded_out_grad_data = 
expanded_out_grad_tensor.data(); + T* mul_result_data = mul_result_tensor.data(); + + const int filter_numel = ksize[0] * ksize[1]; + + // stride = h * w + const int input_stride = in_x_dims[3] * in_x_dims[4]; + const int output_stride = out_dims[3] * out_dims[4]; + const int one_hot_tensor_stride = ksize[0] * ksize[1] * out_dims[3] * out_dims[4]; + + // stride: share0, share1 + const int input_plaintext_size = in_x_grad->numel() / 2; + const int output_plaintext_size = out_grad->numel() / 2; + const int one_hot_tensor_plaintext_size = one_hot_tensor->numel() / 2; + + // expand out grad + auto get_expand_out_grad = [=] (int ph, int pw, int input_height, int input_width, + int out_height, int out_width, int hstart, int hend, + int wstart, int wend, const T* src, T* target) { + + size_t out_grad_index = ph * out_width + pw; + size_t offset = out_height * out_width; + + for (size_t index = 0; index < filter_numel; ++index) { + target[out_grad_index + index * offset] = src[out_grad_index]; //share0 + target[out_grad_index + index * offset + one_hot_tensor_plaintext_size] = + src[out_grad_index + output_plaintext_size]; // share1 + } + }; + + // expand [S, B, C, H_poolout, W_poolout] into [S, B, C, ksize * ksize, H_poolout*W_poolout] + VisitDataStrideWise(in_x_dims, out_dims, ksize, strides, paddings, out_grad_data, + expanded_out_grad_data, output_stride, one_hot_tensor_stride, get_expand_out_grad); + + // compute mul result = out_grad.expand * one_hot_tensor + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->arith_bool_mul( + &expanded_out_grad_tensor, one_hot_tensor, &mul_result_tensor); + + + // updata input X's grad + auto update_in_grad = [=] (int ph, int pw, + int input_height, int input_width, + int out_height, int out_width, + int hstart, int hend, int wstart, int wend, + const T* src, T* target) { + + size_t index = 0; + size_t in_pos = 0; + size_t out_grad_index = ph * out_width + pw; + size_t res_offset = out_height * out_width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + in_pos = h * input_width + w; + target[in_pos] += src[out_grad_index + index * res_offset]; // share0 + target[in_pos + input_plaintext_size] += + src[out_grad_index + index * res_offset + one_hot_tensor_plaintext_size]; // share1 + ++index; + } + } + }; + // convert [S, B, C, filter_size * filter_size, ] into [S, B, C, H, W] + VisitDataStrideWise(in_x_dims, out_dims, ksize, strides, paddings, mul_result_data, + in_x_grad_data, one_hot_tensor_stride, input_stride, update_in_grad); + + } //if (in_x_grad) + } // void ComputeImpl +}; // class MpcPooliGradKernel + +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/mpc_relu_op.cc b/core/paddlefl_mpc/operators/mpc_relu_op.cc index 420abeb635088174465c8e764ca522f8ec6f91b2..b52724ecbf9e1f9a3ec94dda7b4075cde2531904 100644 --- a/core/paddlefl_mpc/operators/mpc_relu_op.cc +++ b/core/paddlefl_mpc/operators/mpc_relu_op.cc @@ -25,16 +25,18 @@ class MpcReluOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Y", in_dims); + ctx->SetOutputDim("Out", in_dims); + ctx->SetOutputDim("Derivative", in_dims); } }; -//forward input & output defination +//forward input & output defination class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "The input tensor."); - AddOutput("Y", "Output of relu_op"); + AddOutput("Out", 
"Output of relu_op"); + AddOutput("Derivative", "Derivative of relu_op"); AddComment(R"DOC( Mpc Relu Operator. )DOC"); @@ -47,7 +49,7 @@ class MpcReluGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - auto in_dims = ctx->GetInputDim(framework::GradVarName("Y")); + auto in_dims = ctx->GetInputDim(framework::GradVarName("Out")); ctx->SetOutputDim(framework::GradVarName("X"), in_dims); } }; @@ -61,8 +63,9 @@ public: protected: void Apply(GradOpPtr grad) const override { grad->SetType("mpc_relu_grad"); - grad->SetInput("Y", this->Output("Y")); - grad->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + grad->SetInput("Out", this->Output("Out")); + grad->SetInput("Derivative", this->Output("Derivative")); + grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad->SetAttrMap(this->Attrs()); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); } diff --git a/core/paddlefl_mpc/operators/mpc_relu_op.h b/core/paddlefl_mpc/operators/mpc_relu_op.h index f0a39264513bbfbfe4a9ba162d5bace2235b559b..a2a22d30d1dc93088917d88e0145eca68d1e4b68 100644 --- a/core/paddlefl_mpc/operators/mpc_relu_op.h +++ b/core/paddlefl_mpc/operators/mpc_relu_op.h @@ -25,11 +25,14 @@ class MpcReluKernel : public MpcOpKernel { public: void ComputeImpl(const framework::ExecutionContext& ctx) const override { const Tensor* in_t = ctx.Input("X"); - Tensor* out_t = ctx.Output("Y"); + Tensor* out_t = ctx.Output("Out"); + Tensor* der_t = ctx.Output("Derivative"); auto x = in_t->data(); auto y = out_t->mutable_data(ctx.GetPlace()); + auto der = der_t->mutable_data(ctx.GetPlace()); PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol."); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(in_t,out_t); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators() + ->relu_with_derivative(in_t,out_t, der_t); } }; @@ -38,11 +41,12 @@ template class MpcReluGradKernel : public MpcOpKernel { public: void ComputeImpl(const framework::ExecutionContext& ctx) const override { - auto* dy_t = ctx.Input(framework::GradVarName("Y")); - auto* y_t = ctx.Input("Y"); + auto* dy_t = ctx.Input(framework::GradVarName("Out")); + auto* y_t = ctx.Input("Out"); + auto* der_t = ctx.Input("Derivative"); auto* dx_t = ctx.Output(framework::GradVarName("X")); auto dx = dx_t->mutable_data(ctx.GetPlace()); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu_grad(y_t, dy_t, dx_t, 0.0); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->arith_bool_mul(dy_t, der_t, dx_t); } }; diff --git a/core/paddlefl_mpc/operators/mpc_sgd_op.cc b/core/paddlefl_mpc/operators/mpc_sgd_op.cc index c2f7cc9a5ef593d79ec5078b0eb48953d82e9773..8d1b13e01a26e7dff0b7080efa04ec55a6a56a56 100644 --- a/core/paddlefl_mpc/operators/mpc_sgd_op.cc +++ b/core/paddlefl_mpc/operators/mpc_sgd_op.cc @@ -72,12 +72,6 @@ public: " but the received var(%s)'s type is %s", ctx->InputVarName("Param"), in_var_type); ctx->SetOutputType("ParamOut", in_var_type); - - //for (auto &out_var_n : framework::StaticGraphVarTypeInference::Output(ctx, "ParamOut")) { - // if (ctx->GetVarType(out_var_n) != in_var_type) { - // ctx->SetType(out_var_n, in_var_type); - //} - //} } }; @@ -111,4 +105,4 @@ REGISTER_OPERATOR( ops::MpcSGDOpInferVarType); REGISTER_OP_CPU_KERNEL( mpc_sgd, - ops::MpcSGDOpKernel); + ops::MpcSGDOpKernel); diff --git 
a/core/paddlefl_mpc/operators/mpc_sgd_op.h b/core/paddlefl_mpc/operators/mpc_sgd_op.h index 805b74d04d38f4c8fdb040a958f77fea12e6f8ed..3d2e0ac979040ab71e80c5a226be6034b23dc813 100644 --- a/core/paddlefl_mpc/operators/mpc_sgd_op.h +++ b/core/paddlefl_mpc/operators/mpc_sgd_op.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { -template +template class MpcSGDOpKernel : public MpcOpKernel { public: void ComputeImpl(const framework::ExecutionContext &ctx) const override{ @@ -47,14 +47,14 @@ class MpcSGDOpKernel : public MpcOpKernel { PADDLE_ENFORCE_EQ(param->numel(), sz); PADDLE_ENFORCE_EQ(grad->numel(), sz); - const double *lr = learning_rate->data(); + double lr = *learning_rate->data(); param_out->mutable_data(ctx.GetPlace()); PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol."); // update parameters framework::Tensor temp; temp.mutable_data(param->dims(), ctx.GetPlace()); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, lr[0], &temp); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, lr, &temp); mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(param, &temp, param_out); } }; diff --git a/core/paddlefl_mpc/operators/mpc_softmax_with_cross_entropy_op.cc b/core/paddlefl_mpc/operators/mpc_softmax_with_cross_entropy_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c454ca5c488edbeb0a8fae82a0d284c1106dd427 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_softmax_with_cross_entropy_op.cc @@ -0,0 +1,241 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "mpc_softmax_with_cross_entropy_op.h" + +namespace paddle { +namespace operators { + +class MpcSoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("Logits"), true, + platform::errors::InvalidArgument("Input(Logits) should be not null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Label"), true, + platform::errors::InvalidArgument("Input(Label) should be not null.")); + + PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true, + platform::errors::InvalidArgument( + "Output(Softmax) should be not null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Loss"), true, + platform::errors::InvalidArgument("Output(Loss) should be not null.")); + + auto axis = ctx->Attrs().Get("axis"); + auto logits_dims = ctx->GetInputDim("Logits"); + auto labels_dims = ctx->GetInputDim("Label"); + auto logits_rank = logits_dims.size(); + + axis = CanonicalAxis(axis, logits_rank); + PADDLE_ENFORCE_GE(axis, logits_rank - 1, + platform::errors::InvalidArgument( + "Attr(axis) value should be -1 or R-1, " + "R is the rank of Input(Logits).")); + for (int i = 0; i < logits_rank; i++) { + if (i != axis) { + if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ(logits_dims[i], labels_dims[i], + platform::errors::InvalidArgument( + "Input(Logits) and Input(Label) should in " + "same shape in dimensions except axis.")); + } + } + } + + bool soft_label = ctx->Attrs().Get("soft_label"); + PADDLE_ENFORCE_EQ(soft_label, true, + platform::errors::InvalidArgument( + "soft_label can only be true! 
")); + if (soft_label) { + if (ctx->IsRuntime() || + (logits_dims[axis] > 0 && labels_dims[axis] > 0)) { + PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis], + platform::errors::InvalidArgument( + "If Attr(soft_label) == true, " + "the axis dimension of " + "Input(X) and Input(Label) should be equal.")); + } + } + ctx->SetOutputDim("Softmax", logits_dims); + + logits_dims[axis] = 1; + ctx->SetOutputDim("Loss", logits_dims); + + ctx->ShareLoD("Logits", /*->*/ "Softmax"); + ctx->ShareLoD("Logits", /*->*/ "Loss"); + } +}; + + +class MpcSoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Loss")), true, + platform::errors::InvalidArgument( + "Input(Loss@Grad) should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true, + platform::errors::InvalidArgument( + "Input(Softmax) should be not null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Label"), true, + platform::errors::InvalidArgument("Input(Label) should be not null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Logits")), true, + platform::errors::InvalidArgument( + "Output(Logits@Grad) should be not null.")); + + auto axis = ctx->Attrs().Get("axis"); + auto softmax_dims = ctx->GetInputDim("Softmax"); + auto labels_dims = ctx->GetInputDim("Label"); + auto softmax_rank = softmax_dims.size(); + + axis = CanonicalAxis(axis, softmax_rank); + PADDLE_ENFORCE_GE(axis, softmax_rank - 1, + platform::errors::InvalidArgument( + "Attr(axis) value should be -1 or R-1, " + "R is the rank of Input(Logits).")); + for (int i = 0; i < softmax_rank; i++) { + if (i != axis) { + if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ( + softmax_dims[i], labels_dims[i], + platform::errors::InvalidArgument( + "Input(Logits) and Input(Label) should in same shape in " + "dimensions except axis.")); + } + } + } + + bool soft_label = ctx->Attrs().Get("soft_label"); + PADDLE_ENFORCE_EQ(soft_label, true, + platform::errors::InvalidArgument( + "soft_label can only be true! ")); + if (soft_label) { + if (ctx->IsRuntime() || (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) { + PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis], + platform::errors::InvalidArgument( + "If Attr(soft_label) == true, " + "the axis dimension of " + "Input(X) and Input(Label) should be equal.")); + } + } + + ctx->SetOutputDim(framework::GradVarName("Logits"), + ctx->GetInputDim("Softmax")); + } +}; + + +class MpcSoftmaxWithCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("Logits", + "(Tensor, default: Tensor), The input tensor of unscaled " + "log probabilities, whose dimension :attr:`axis` should be scaled " + "by softmax."); + AddInput( + "Label", + "(Tensor) The input tensor of groud truth label. If :attr:`soft_label` " + "is set to false, Label is a Tensor in same shape with " + "Input(Logits) except the shape in dimension :attr:`axis` as 1. If " + "soft_label is set to true, Label is a Tensor in same " + "shape with Input(Logits)."); + AddOutput( + "Softmax", + "(Tensor, default: Tensor), A tensor in same shape with " + "Input(Logits). 
" + "The outputs value of softmax activation by given the input batch, " + "which will be used in backward calculation.") + .AsIntermediate(); + AddOutput("Loss", + "(Tensor, default: Tensor), A tensor in same shape with " + "Input(Logits) " + "except the shape in dimension :attr:`axis` as 1. The cross " + "entropy loss."); + AddAttr( + "soft_label", + "(bool, default: false), A flag to indicate whether to interpretant " + "the given labels as soft labels.") + .SetDefault(false); + AddAttr("axis", + "The dimension index of Input(Logits) to perform softmax," + "default -1 for last dimension") + .SetDefault(-1); + AddAttr("use_relu", "").SetDefault(false); + AddAttr("use_long_div", "").SetDefault(true); + AddComment(R"DOC( +Softmax With Cross Entropy Operator. +Cross entropy loss with softmax is used as the output layer extensively. This +operator computes the softmax normalized values for each row of the input +tensor. +Conputing cross-entropy loss is not supported now. +Now, we only support soft_label=true, axis=-1 or (rank-1). +Forward: out = softmax(x). todo: add cross_entropy +backward: dx = dout.expand * (softmax(x) - label) +)DOC"); + } +}; + + +template +class MpcSoftmaxGradMaker : public framework::SingleGradOpMaker { +public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + +protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("mpc_softmax_with_cross_entropy_grad"); + grad_op->SetInput("Label", this->Input("Label")); + grad_op->SetInput("Softmax", this->Output("Softmax")); + grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss")); + grad_op->SetOutput(framework::GradVarName("Logits"), + this->InputGrad("Logits")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + + +DECLARE_INPLACE_OP_INFERER(MpcSoftmaxWithCrossEntropyInplaceInference, + {"Logits", "Softmax"}); + +DECLARE_INPLACE_OP_INFERER(MpcSoftmaxWithCrossEntropyGradInplaceInference, + {"Softmax", framework::GradVarName("Logits")}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(mpc_softmax_with_cross_entropy, ops::MpcSoftmaxWithCrossEntropyOp, + ops::MpcSoftmaxWithCrossEntropyOpMaker, + ops::MpcSoftmaxGradMaker, + ops::MpcSoftmaxGradMaker, + ops::MpcSoftmaxWithCrossEntropyInplaceInference); +REGISTER_OPERATOR(mpc_softmax_with_cross_entropy_grad, + ops::MpcSoftmaxWithCrossEntropyOpGrad, + ops::MpcSoftmaxWithCrossEntropyGradInplaceInference); +REGISTER_OP_CPU_KERNEL(mpc_softmax_with_cross_entropy, + ops::MpcSoftmaxWithCrossEntropyKernel); +REGISTER_OP_CPU_KERNEL(mpc_softmax_with_cross_entropy_grad, + ops::MpcSoftmaxWithCrossEntropyGradKernel); + diff --git a/core/paddlefl_mpc/operators/mpc_softmax_with_cross_entropy_op.h b/core/paddlefl_mpc/operators/mpc_softmax_with_cross_entropy_op.h new file mode 100644 index 0000000000000000000000000000000000000000..04dbd077a699e5771874d2305916d9da16ba4e6a --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_softmax_with_cross_entropy_op.h @@ -0,0 +1,107 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "mpc_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +static inline int CanonicalAxis(const int axis, const int rank) { + if (axis < 0) { + return axis + rank; + } + return axis; +} + +static inline int SizeToAxis(const int axis, DDim dims) { + int size = 1; + for (int i = 0; i < axis; i++) { + size *= dims[i]; + } + return size; +} + +static inline int SizeFromAxis(const int axis, DDim dims) { + int size = 1; + for (int i = axis; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + + +// Out = softmax(Logits) = relu(Logits_i) / sum(relu(Logits_i)): prediction of input. +// todo: loss=? +template +class MpcSoftmaxWithCrossEntropyKernel : public MpcOpKernel { +public: + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto *in_x_t = ctx.Input("Logits"); + auto *out_softmax_t = ctx.Output("Softmax"); + auto *out_loss_t = ctx.Output("Loss"); + out_softmax_t->mutable_data(ctx.GetPlace()); + out_loss_t->mutable_data(ctx.GetPlace()); + bool use_relu = ctx.Attr("use_relu"); + bool use_long_div = ctx.Attr("use_long_div"); + + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->softmax( + in_x_t, out_softmax_t, use_relu, use_long_div); + } +}; + +// dx = dout.expand * (softmax(x) - labels) +template +class MpcSoftmaxWithCrossEntropyGradKernel : public MpcOpKernel { +public: + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto *dout = ctx.Input(framework::GradVarName("Loss")); + auto *in_label_t = ctx.Input("Label"); + auto *in_softmax_t = ctx.Input("Softmax"); + auto *dx = ctx.Output(framework::GradVarName("Logits")); + const bool soft_label = ctx.Attr("soft_label"); + PADDLE_ENFORCE_EQ(soft_label, true, "soft_label can only be true."); + + const int rank = dx->dims().size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + int axis_dim = dx->dims()[axis]; + const int n = SizeToAxis(axis, dx->dims()); + const int d = SizeFromAxis(axis, dx->dims()); + + T* dx_data = dx->mutable_data(ctx.GetPlace()); + const T* dout_data = dout->data(); + + // expand dout + Tensor dout_expand; + T* dout_expand_data = dout_expand.mutable_data(dx->dims(), ctx.GetPlace()); + + for (size_t i = 0; i < n; ++i) { + for (size_t j = 0; j < d; ++j) { + dout_expand_data[i * d + j] = dout_data[i]; + } + } + + // dx = dout.expand * (softmax - label) + Tensor softmax_minus_label; + T* softmax_minus_label_data = softmax_minus_label.mutable_data(dx->dims(), ctx.GetPlace()); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(in_softmax_t, in_label_t, &softmax_minus_label); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(&dout_expand, &softmax_minus_label, dx); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/core/privc3/CMakeLists.txt b/core/privc3/CMakeLists.txt index cbc1d1b9b4864b2f944fb459da49de5bb8674095..c560ef28be4002acec5283c5b1e5dd480e1c40b3 100644 --- a/core/privc3/CMakeLists.txt +++ b/core/privc3/CMakeLists.txt @@ -1,5 +1,3 @@ -add_compile_options(-msse4.2 -maes) - set(PRIVC3_SRCS "./aes.cc" "./paddle_tensor.cc" @@ -8,16 +6,23 @@ set(PRIVC3_SRCS "./tensor_adapter_factory.cc" ) +if (USE_AES_NI) + add_compile_definitions(USE_AES_NI) +endif (USE_AES_NI) + add_library(privc3_o OBJECT ${PRIVC3_SRCS}) add_dependencies(privc3_o fluid_framework) add_library(privc3 
STATIC $) -target_link_libraries(privc3 fluid_framework) + +if (USE_OPENMP) + target_link_libraries(privc3 fluid_framework OpenMP::OpenMP_CXX OpenMP::OpenMP_C crypto) +else() + target_link_libraries(privc3 fluid_framework crypto) +endif (USE_OPENMP) cc_test(fixedpoint_util_test SRCS fixedpoint_util_test.cc DEPS privc3) cc_test(paddle_tensor_test SRCS paddle_tensor_test.cc DEPS privc3) cc_test(boolean_tensor_test SRCS boolean_tensor_test.cc DEPS privc3) cc_test(fixedpoint_tensor_test SRCS fixedpoint_tensor_test.cc DEPS privc3) - -#set(CMAKE_BUILD_TYPE "Debug") diff --git a/core/privc3/boolean_tensor.h b/core/privc3/boolean_tensor.h index 8a3de5595223d1fd043364cbfebaeb07d6d94523..36418a34dd1d56968aff9fcc7c4e55e03f9fb301 100644 --- a/core/privc3/boolean_tensor.h +++ b/core/privc3/boolean_tensor.h @@ -23,117 +23,124 @@ namespace aby3 { -template class FixedPointTensor; +template +class FixedPointTensor; -template class BooleanTensor { +template +class BooleanTensor { public: - BooleanTensor(TensorAdapter *share_tensor[2]); + BooleanTensor(TensorAdapter* share_tensor[2]); - BooleanTensor(TensorAdapter *tensor0, TensorAdapter *tensor1); + BooleanTensor(TensorAdapter* tensor0, TensorAdapter* tensor1); - BooleanTensor(); + BooleanTensor(); - // ABY3 a2b - template - BooleanTensor &operator=(const FixedPointTensor *other); + // ABY3 a2b + template + BooleanTensor& operator=(const FixedPointTensor* other); - ~BooleanTensor() {} + ~BooleanTensor() {} - // get share - TensorAdapter *share(size_t idx); + //get share + TensorAdapter* share(size_t idx); - const TensorAdapter *share(size_t idx) const; + const TensorAdapter* share(size_t idx) const; - // reveal boolean tensor to one party - void reveal_to_one(size_t party_num, TensorAdapter *ret) const; + // reveal boolean tensor to one party + void reveal_to_one(size_t party_num, TensorAdapter* ret) const; - // reveal boolean tensor to all parties - void reveal(TensorAdapter *ret) const; + // reveal boolean tensor to all parties + void reveal(TensorAdapter* ret) const; - const std::vector shape() const; + const std::vector shape() const; - size_t numel() const; + size_t numel() const; - // //convert TensorAdapter to shares - // static void share(const TensorAdapter* input, - // TensorAdapter* output_shares[3], - // const std::string& rnd_seed = ""); + // //convert TensorAdapter to shares + // static void share(const TensorAdapter* input, + // TensorAdapter* output_shares[3], + // const std::string& rnd_seed = ""); - // element-wise xor with BooleanTensor - void bitwise_xor(const BooleanTensor *rhs, BooleanTensor *ret) const; + // element-wise xor with BooleanTensor + void bitwise_xor(const BooleanTensor* rhs, BooleanTensor* ret) const; - // element-wise xor with TensorAdapter - void bitwise_xor(const TensorAdapter *rhs, BooleanTensor *ret) const; + // element-wise xor with TensorAdapter + void bitwise_xor(const TensorAdapter* rhs, BooleanTensor* ret) const; - // element-wise and with BooleanTensor - void bitwise_and(const BooleanTensor *rhs, BooleanTensor *ret) const; + // element-wise and with BooleanTensor + void bitwise_and(const BooleanTensor* rhs, BooleanTensor* ret) const; - // element-wise and with TensorAdapter - void bitwise_and(const TensorAdapter *rhs, BooleanTensor *ret) const; + // element-wise and with TensorAdapter + void bitwise_and(const TensorAdapter* rhs, BooleanTensor* ret) const; - // element-wise or with BooleanTensor - void bitwise_or(const BooleanTensor *rhs, BooleanTensor *ret) const; + // element-wise or + // for both tensor 
adapter and boolean tensor + template class CTensor> + void bitwise_or(const CTensor* rhs, BooleanTensor* ret) const; - // element-wise or with TensorAdapter - void bitwise_or(const TensorAdapter *rhs, BooleanTensor *ret) const; + // element-wise not + void bitwise_not(BooleanTensor* ret) const; - // element-wise not - void bitwise_not(BooleanTensor *ret) const; + // element-wise lshift + void lshift(size_t rhs, BooleanTensor* ret) const; - // element-wise lshift - void lshift(size_t rhs, BooleanTensor *ret) const; + // element-wise rshift + void rshift(size_t rhs, BooleanTensor* ret) const; - // element-wise rshift - void rshift(size_t rhs, BooleanTensor *ret) const; + // element-wise logical_rshift + void logical_rshift(size_t rhs, BooleanTensor* ret) const; - // element-wise logical_rshift - void logical_rshift(size_t rhs, BooleanTensor *ret) const; + // element-wise ppa with BooleanTensor + void ppa(const BooleanTensor* rhs, BooleanTensor*ret , size_t nbits) const; - // element-wise ppa with BooleanTensor - void ppa(const BooleanTensor *rhs, BooleanTensor *ret, size_t nbits) const; + // ABY3 b2a + template + void b2a(FixedPointTensor* ret) const; - // ABY3 b2a - template void b2a(FixedPointTensor *ret) const; + // ABY3 ab mul + // this is an one-bit boolean share + template + void mul(const TensorAdapter* rhs, FixedPointTensor* ret, size_t rhs_party) const; - // ABY3 ab mul - // this is an one-bit boolean share - template - void mul(const TensorAdapter *rhs, FixedPointTensor *ret, - size_t rhs_party) const; + // ABY3 ab mul + // this is an one-bit boolean share + template + void mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const; - // ABY3 ab mul - // this is an one-bit boolean share - template - void mul(const FixedPointTensor *rhs, - FixedPointTensor *ret) const; + // extract to this + template + void bit_extract(size_t i, const FixedPointTensor* in); - // extract to this - template - void bit_extract(size_t i, const FixedPointTensor *in); + // extract from this to ret + void bit_extract(size_t i, BooleanTensor* ret) const; - // extract from this to ret - void bit_extract(size_t i, BooleanTensor *ret) const; + // turn all 1s to 0s except the last 1 in a col + // given cmp result from max pooling, generate one hot tensor + // indicating which element is max + // inplace transform + void onehot_from_cmp(); private: - static inline std::shared_ptr aby3_ctx() { - return paddle::mpc::ContextHolder::mpc_ctx(); - } + static inline std::shared_ptr aby3_ctx() { + return paddle::mpc::ContextHolder::mpc_ctx(); + } - static inline std::shared_ptr tensor_factory() { - return paddle::mpc::ContextHolder::tensor_factory(); - } + static inline std::shared_ptr tensor_factory() { + return paddle::mpc::ContextHolder::tensor_factory(); + } - size_t pre_party() const; + size_t pre_party() const; - size_t next_party() const; + size_t next_party() const; - size_t party() const; + size_t party() const; private: - TensorAdapter *_share[2]; + TensorAdapter* _share[2]; + }; -} // namespace aby3 +} //namespace aby3 #include "boolean_tensor_impl.h" diff --git a/core/privc3/boolean_tensor_impl.h b/core/privc3/boolean_tensor_impl.h index 99c6c1f61de284f078c729f17ef87e7b84619f68..012a158f4d8677787f7695035f06cab451e2dea3 100644 --- a/core/privc3/boolean_tensor_impl.h +++ b/core/privc3/boolean_tensor_impl.h @@ -18,465 +18,514 @@ namespace aby3 { -template size_t BooleanTensor::pre_party() const { - return aby3_ctx()->pre_party(); +template +size_t BooleanTensor::pre_party() const { + return 
aby3_ctx()->pre_party(); } -template size_t BooleanTensor::next_party() const { - return aby3_ctx()->next_party(); +template +size_t BooleanTensor::next_party() const { + return aby3_ctx()->next_party(); } -template size_t BooleanTensor::party() const { - return aby3_ctx()->party(); +template +size_t BooleanTensor::party() const { + return aby3_ctx()->party(); } -template -BooleanTensor::BooleanTensor(TensorAdapter *tensor[2]) { - // TODO: check if tensor shape equal - _share[0] = tensor[0]; - _share[1] = tensor[1]; +template +BooleanTensor::BooleanTensor(TensorAdapter* tensor[2]) { + // TODO: check if tensor shape equal + _share[0] = tensor[0]; + _share[1] = tensor[1]; } -template -BooleanTensor::BooleanTensor(TensorAdapter *tensor0, - TensorAdapter *tensor1) { - // TODO: check if tensor shape equal - _share[0] = tensor0; - _share[1] = tensor1; +template +BooleanTensor::BooleanTensor(TensorAdapter* tensor0, + TensorAdapter* tensor1) { + // TODO: check if tensor shape equal + _share[0] = tensor0; + _share[1] = tensor1; } -template BooleanTensor::BooleanTensor() {} +template +BooleanTensor::BooleanTensor() { +} -template TensorAdapter *BooleanTensor::share(size_t idx) { - // TODO: check if idx < 2 - return _share[idx]; +template +TensorAdapter* BooleanTensor::share(size_t idx) { + // TODO: check if idx < 2 + return _share[idx]; } -template -const TensorAdapter *BooleanTensor::share(size_t idx) const { - // TODO: check if idx < 2 - return _share[idx]; +template +const TensorAdapter* BooleanTensor::share(size_t idx) const { + // TODO: check if idx < 2 + return _share[idx]; } -template -void BooleanTensor::reveal_to_one(size_t party_num, - TensorAdapter *ret) const { +template +void BooleanTensor::reveal_to_one(size_t party_num, TensorAdapter* ret) const { - if (party_num == party()) { - // TODO: check if tensor shape equal + if (party_num == party()) { + // TODO: check if tensor shape equal - // incase of this and ret shares tensor ptr - auto buffer = tensor_factory()->template create(ret->shape()); - aby3_ctx()->network()->template recv(pre_party(), *buffer); + // incase of this and ret shares tensor ptr + auto buffer = tensor_factory()->template create(ret->shape()); + aby3_ctx()->network()->template recv(pre_party(), *buffer); - share(0)->bitwise_xor(buffer.get(), ret); - share(1)->bitwise_xor(ret, ret); + share(0)->bitwise_xor(buffer.get(), ret); + share(1)->bitwise_xor(ret, ret); - } else if (party_num == next_party()) { + } else if (party_num == next_party()) { - aby3_ctx()->network()->template send(party_num, *share(0)); - } + aby3_ctx()->network()->template send(party_num, *share(0)); + + } } -template -void BooleanTensor::reveal(TensorAdapter *ret) const { - for (size_t idx = 0; idx < 3; ++idx) { - reveal_to_one(idx, ret); - } +template +void BooleanTensor::reveal(TensorAdapter* ret) const { + for (size_t idx = 0; idx < 3; ++idx) { + reveal_to_one(idx, ret); + } } -template +template const std::vector BooleanTensor::shape() const { - if (share(0)) { - return share(0)->shape(); - } else { - return std::vector(); - } + if (share(0)) { + return share(0)->shape(); + } + else { + return std::vector(); + } } -template size_t BooleanTensor::numel() const { - if (share(0)) { - return share(0)->numel(); - } else { - 0; - } +template +size_t BooleanTensor::numel() const { + if (share(0)) { + return share(0)->numel(); + } + else { + 0; + } } -template -void BooleanTensor::bitwise_xor(const BooleanTensor *rhs, - BooleanTensor *ret) const { - share(0)->bitwise_xor(rhs->share(0), ret->share(0)); - 
share(1)->bitwise_xor(rhs->share(1), ret->share(1)); +template +void BooleanTensor::bitwise_xor(const BooleanTensor* rhs, + BooleanTensor* ret) const { + share(0)->bitwise_xor(rhs->share(0), ret->share(0)); + share(1)->bitwise_xor(rhs->share(1), ret->share(1)); } -template -void BooleanTensor::bitwise_xor(const TensorAdapter *rhs, - BooleanTensor *ret) const { - share(0)->bitwise_xor(rhs, ret->share(0)); - share(1)->bitwise_xor(rhs, ret->share(1)); +template +void BooleanTensor::bitwise_xor(const TensorAdapter* rhs, + BooleanTensor* ret) const { + share(0)->bitwise_xor(rhs, ret->share(0)); + share(1)->bitwise_xor(rhs, ret->share(1)); } -template -void BooleanTensor::bitwise_and(const BooleanTensor *rhs, - BooleanTensor *ret) const { - - auto tmp_zero = tensor_factory()->template create(ret->shape()); - auto tmp0 = tensor_factory()->template create(ret->shape()); - auto tmp1 = tensor_factory()->template create(ret->shape()); - auto tmp2 = tensor_factory()->template create(ret->shape()); - - aby3_ctx()->template gen_zero_sharing_boolean(*tmp_zero.get()); - - share(0)->bitwise_and(rhs->share(0), tmp0.get()); - share(0)->bitwise_and(rhs->share(1), tmp1.get()); - share(1)->bitwise_and(rhs->share(0), tmp2.get()); - - tmp0->bitwise_xor(tmp1.get(), tmp0.get()); - tmp0->bitwise_xor(tmp2.get(), tmp0.get()); - tmp0->bitwise_xor(tmp_zero.get(), ret->share(0)); - - // 3-party msg send recv sequence - // p0 p1 p2 - // t0: 0->2 2<-0 - // t1: 1<-2 2->1 - // t2: 0<-1 1->2 - if (party() > 0) { - aby3_ctx()->network()->template recv(next_party(), *(ret->share(1))); - aby3_ctx()->network()->template send(pre_party(), *(ret->share(0))); - } else { - aby3_ctx()->network()->template send(pre_party(), *(ret->share(0))); - aby3_ctx()->network()->template recv(next_party(), *(ret->share(1))); - } +template +void BooleanTensor::bitwise_and(const BooleanTensor* rhs, + BooleanTensor* ret) const { + + auto tmp_zero = tensor_factory()->template create(ret->shape()); + auto tmp0 = tensor_factory()->template create(ret->shape()); + auto tmp1 = tensor_factory()->template create(ret->shape()); + auto tmp2 = tensor_factory()->template create(ret->shape()); + + aby3_ctx()->template gen_zero_sharing_boolean(*tmp_zero.get()); + + share(0)->bitwise_and(rhs->share(0), tmp0.get()); + share(0)->bitwise_and(rhs->share(1), tmp1.get()); + share(1)->bitwise_and(rhs->share(0), tmp2.get()); + + tmp0->bitwise_xor(tmp1.get(), tmp0.get()); + tmp0->bitwise_xor(tmp2.get(), tmp0.get()); + tmp0->bitwise_xor(tmp_zero.get(), ret->share(0)); + + // 3-party msg send recv sequence + // p0 p1 p2 + // t0: 0->2 2<-0 + // t1: 1<-2 2->1 + // t2: 0<-1 1->2 + if (party() > 0) { + aby3_ctx()->network()->template recv(next_party(), *(ret->share(1))); + aby3_ctx()->network()->template send(pre_party(), *(ret->share(0))); + } else { + aby3_ctx()->network()->template send(pre_party(), *(ret->share(0))); + aby3_ctx()->network()->template recv(next_party(), *(ret->share(1))); + } } -template -void BooleanTensor::bitwise_and(const TensorAdapter *rhs, - BooleanTensor *ret) const { - share(0)->bitwise_and(rhs, ret->share(0)); - share(1)->bitwise_and(rhs, ret->share(1)); +template +void BooleanTensor::bitwise_and(const TensorAdapter* rhs, + BooleanTensor* ret) const { + share(0)->bitwise_and(rhs, ret->share(0)); + share(1)->bitwise_and(rhs, ret->share(1)); } -template -void BooleanTensor::bitwise_or(const BooleanTensor *rhs, - BooleanTensor *ret) const { - // ret = x & y - bitwise_and(rhs, ret); - // ret = x & y ^ x - bitwise_xor(ret, ret); - // ret = x & y ^ x ^ y - 
rhs->bitwise_xor(ret, ret); +template +template class CTensor> +void BooleanTensor::bitwise_or(const CTensor* rhs, + BooleanTensor* ret) const { + + std::vector>> tmp; + + for (int i = 0; i < 2; ++i) { + tmp.emplace_back( + tensor_factory()->template create(shape())); + } + + BooleanTensor buffer(tmp[0].get(), tmp[1].get()); + // ret = x & y + bitwise_and(rhs, &buffer); + // ret = x & y ^ x + bitwise_xor(&buffer, &buffer); + // ret = x & y ^ x ^ y + buffer.bitwise_xor(rhs, ret); } -template -void BooleanTensor::bitwise_or(const TensorAdapter *rhs, - BooleanTensor *ret) const { - // ret = x & y - bitwise_and(rhs, ret); - // ret = x & y ^ x - bitwise_xor(ret, ret); - // ret = x & y ^ x ^ y - ret->bitwise_xor(rhs, ret); -} - -template -void BooleanTensor::bitwise_not(BooleanTensor *ret) const { - if (party() == 0) { - share(0)->bitwise_not(ret->share(0)); - share(1)->copy(ret->share(1)); - } else if (party() == 1) { - share(0)->copy(ret->share(0)); - share(1)->copy(ret->share(1)); - } else { - share(0)->copy(ret->share(0)); - share(1)->bitwise_not(ret->share(1)); - } +template +void BooleanTensor::bitwise_not(BooleanTensor* ret) const { + if (party() == 0) { + share(0)->bitwise_not(ret->share(0)); + share(1)->copy(ret->share(1)); + } else if (party() == 1) { + share(0)->copy(ret->share(0)); + share(1)->copy(ret->share(1)); + } else { + share(0)->copy(ret->share(0)); + share(1)->bitwise_not(ret->share(1)); + } } -template -void BooleanTensor::lshift(size_t rhs, BooleanTensor *ret) const { - share(0)->lshift(rhs, ret->share(0)); - share(1)->lshift(rhs, ret->share(1)); +template +void BooleanTensor::lshift(size_t rhs, BooleanTensor* ret) const { + share(0)->lshift(rhs, ret->share(0)); + share(1)->lshift(rhs, ret->share(1)); } -template -void BooleanTensor::rshift(size_t rhs, BooleanTensor *ret) const { - share(0)->rshift(rhs, ret->share(0)); - share(1)->rshift(rhs, ret->share(1)); +template +void BooleanTensor::rshift(size_t rhs, BooleanTensor* ret) const { + share(0)->rshift(rhs, ret->share(0)); + share(1)->rshift(rhs, ret->share(1)); } -template -void BooleanTensor::logical_rshift(size_t rhs, BooleanTensor *ret) const { - share(0)->logical_rshift(rhs, ret->share(0)); - share(1)->logical_rshift(rhs, ret->share(1)); +template +void BooleanTensor::logical_rshift(size_t rhs, BooleanTensor* ret) const { + share(0)->logical_rshift(rhs, ret->share(0)); + share(1)->logical_rshift(rhs, ret->share(1)); } -template -void BooleanTensor::ppa(const BooleanTensor *rhs, BooleanTensor *ret, +template +void BooleanTensor::ppa(const BooleanTensor* rhs, + BooleanTensor* ret, size_t n_bits) const { - // kogge stone adder from tfe - // https://github.com/tf-encrypted - // TODO: check T is int64_t other native type not support yet - const size_t k = std::ceil(std::log2(n_bits)); - std::vector keep_masks(k); - for (size_t i = 0; i < k; ++i) { - keep_masks[i] = (T(1) << (T)std::exp2(i)) - 1; - } - - std::shared_ptr> tmp[11]; - for (auto &ti : tmp) { - ti = tensor_factory()->template create(ret->shape()); - } - BooleanTensor g(tmp[0].get(), tmp[1].get()); - BooleanTensor p(tmp[2].get(), tmp[3].get()); - BooleanTensor g1(tmp[4].get(), tmp[5].get()); - BooleanTensor p1(tmp[6].get(), tmp[7].get()); - BooleanTensor c(tmp[8].get(), tmp[9].get()); - auto k_mask = tmp[10].get(); - - bitwise_and(rhs, &g); - bitwise_xor(rhs, &p); - - for (size_t i = 0; i < k; ++i) { - - std::transform(k_mask->data(), k_mask->data() + k_mask->numel(), - k_mask->data(), - [&keep_masks, i](T) -> T { return keep_masks[i]; }); - - 
g.lshift(std::exp2(i), &g1); - p.lshift(std::exp2(i), &p1); - - p1.bitwise_xor(k_mask, &p1); - g1.bitwise_and(&p, &c); - - g.bitwise_xor(&c, &g); - p.bitwise_and(&p1, &p); - } - g.lshift(1, &c); - bitwise_xor(rhs, &p); - - c.bitwise_xor(&p, ret); + // kogge stone adder from tfe + // https://github.com/tf-encrypted + // TODO: check T is int64_t other native type not support yet + const size_t k = std::ceil(std::log2(n_bits)); + std::vector keep_masks(k); + for (size_t i = 0; i < k; ++i) { + keep_masks[i] = (T(1) << (T) std::exp2(i)) - 1; + } + + std::shared_ptr> tmp[11]; + for (auto& ti: tmp) { + ti = tensor_factory()->template create(ret->shape()); + } + BooleanTensor g(tmp[0].get(), tmp[1].get()); + BooleanTensor p(tmp[2].get(), tmp[3].get()); + BooleanTensor g1(tmp[4].get(), tmp[5].get()); + BooleanTensor p1(tmp[6].get(), tmp[7].get()); + BooleanTensor c(tmp[8].get(), tmp[9].get()); + auto k_mask = tmp[10].get(); + + bitwise_and(rhs, &g); + bitwise_xor(rhs, &p); + + for (size_t i = 0; i < k; ++i) { + + std::transform(k_mask->data(), k_mask->data() + k_mask->numel(), + k_mask->data(), + [&keep_masks, i](T) -> T { return keep_masks[i]; }); + + g.lshift(std::exp2(i), &g1); + p.lshift(std::exp2(i), &p1); + + + p1.bitwise_xor(k_mask, &p1); + g1.bitwise_and(&p, &c); + + g.bitwise_xor(&c, &g); + p.bitwise_and(&p1, &p); + } + g.lshift(1, &c); + bitwise_xor(rhs, &p); + + c.bitwise_xor(&p, ret); } -template -void a2b(CircuitContext *aby3_ctx, TensorAdapterFactory *tensor_factory, - const FixedPointTensor *a, BooleanTensor *b, size_t n_bits) { +template +void a2b(CircuitContext* aby3_ctx, + TensorAdapterFactory* tensor_factory, + const FixedPointTensor* a, + BooleanTensor* b, + size_t n_bits) { - std::shared_ptr> tmp[4]; - for (auto &ti : tmp) { - ti = tensor_factory->template create(a->shape()); - // set 0 - std::transform(ti->data(), ti->data() + ti->numel(), ti->data(), - [](T) -> T { return 0; }); - } + std::shared_ptr> tmp[4]; + for (auto& ti: tmp) { + ti = tensor_factory->template create(a->shape()); + // set 0 + std::transform(ti->data(), ti->data() + ti->numel(), ti->data(), + [](T) -> T { return 0; }); + } - std::shared_ptr> lhs = - std::make_shared>(tmp[0].get(), tmp[1].get()); - std::shared_ptr> rhs = - std::make_shared>(tmp[2].get(), tmp[3].get()); + std::shared_ptr> lhs = + std::make_shared>(tmp[0].get(), tmp[1].get()); + std::shared_ptr> rhs = + std::make_shared>(tmp[2].get(), tmp[3].get()); - if (aby3_ctx->party() == 0) { - a->share(0)->add(a->share(1), lhs->share(0)); + if (aby3_ctx->party() == 0) { + a->share(0)->add(a->share(1), lhs->share(0)); - // reshare x0 + x1 - aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(1)); - lhs->share(0)->bitwise_xor(lhs->share(1), lhs->share(0)); + // reshare x0 + x1 + aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(1)); + lhs->share(0)->bitwise_xor(lhs->share(1), lhs->share(0)); - aby3_ctx->network()->template send(2, *(lhs->share(0))); - aby3_ctx->network()->template recv(1, *(lhs->share(1))); + aby3_ctx->network()->template send(2, *(lhs->share(0))); + aby3_ctx->network()->template recv(1, *(lhs->share(1))); - } else if (aby3_ctx->party() == 1) { + } else if (aby3_ctx->party() == 1) { - aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(0)); - aby3_ctx->network()->template send(0, *(lhs->share(0))); - aby3_ctx->network()->template recv(2, *(lhs->share(1))); + aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(0)); + aby3_ctx->network()->template send(0, *(lhs->share(0))); + aby3_ctx->network()->template recv(2, 
*(lhs->share(1))); - a->share(1)->copy(rhs->share(1)); + a->share(1)->copy(rhs->share(1)); - } else { // party == 2 + } else { // party == 2 - aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(0)); + aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(0)); - aby3_ctx->network()->template recv(0, *(lhs->share(1))); - aby3_ctx->network()->template send(1, *(lhs->share(0))); + aby3_ctx->network()->template recv(0, *(lhs->share(1))); + aby3_ctx->network()->template send(1, *(lhs->share(0))); - a->share(0)->copy(rhs->share(0)); - } + a->share(0)->copy(rhs->share(0)); + } - lhs->ppa(rhs.get(), b, n_bits); + lhs->ppa(rhs.get(), b, n_bits); } -template -template -BooleanTensor &BooleanTensor:: -operator=(const FixedPointTensor *other) { - a2b(aby3_ctx().get(), tensor_factory().get(), other, this, sizeof(T) * 8); - return *this; +template +template +BooleanTensor& BooleanTensor::operator=(const FixedPointTensor* other) { + a2b(aby3_ctx().get(), tensor_factory().get(), other, this, sizeof(T) * 8); + return *this; } template -void tensor_rshift_transform(const TensorAdapter *lhs, size_t rhs, - TensorAdapter *ret) { - const T *begin = lhs->data(); - std::transform(begin, begin + lhs->numel(), ret->data(), - [rhs](T in) { return (in >> rhs) & 1; }); +void tensor_rshift_transform(const TensorAdapter* lhs, + size_t rhs, TensorAdapter* ret) { + const T* begin = lhs->data(); + std::transform(begin, begin + lhs->numel(), ret->data(), + [rhs](T in) { return (in >> rhs) & 1; }); }; -template -template -void BooleanTensor::bit_extract(size_t i, const FixedPointTensor *in) { - a2b(aby3_ctx().get(), tensor_factory().get(), in, this, i + 1); +template +template +void BooleanTensor::bit_extract(size_t i, const FixedPointTensor* in) { + a2b(aby3_ctx().get(), tensor_factory().get(), in, this, i + 1); - tensor_rshift_transform(share(0), i, share(0)); - tensor_rshift_transform(share(1), i, share(1)); + tensor_rshift_transform(share(0), i, share(0)); + tensor_rshift_transform(share(1), i, share(1)); } -template -void BooleanTensor::bit_extract(size_t i, BooleanTensor *ret) const { - tensor_rshift_transform(share(0), i, ret->share(0)); - tensor_rshift_transform(share(1), i, ret->share(1)); +template +void BooleanTensor::bit_extract(size_t i, BooleanTensor* ret) const { + tensor_rshift_transform(share(0), i, ret->share(0)); + tensor_rshift_transform(share(1), i, ret->share(1)); } -template -template -void BooleanTensor::b2a(FixedPointTensor *ret) const { - std::shared_ptr> tmp[2]; - for (auto &ti : tmp) { - ti = tensor_factory()->template create(shape()); - // set 0 - std::transform(ti->data(), ti->data() + ti->numel(), ti->data(), - [](T) -> T { return 0; }); - } - BooleanTensor bt(tmp[0].get(), tmp[1].get()); - - if (party() == 1) { - aby3_ctx()->template gen_random(*ret->mutable_share(0), 0); - aby3_ctx()->template gen_random(*ret->mutable_share(1), 1); - ret->share(0)->add(ret->share(1), tmp[0].get()); - tmp[0]->negative(tmp[0].get()); - aby3_ctx()->network()->template send(0, *(tmp[0].get())); - } else if (party() == 0) { - aby3_ctx()->network()->template recv(1, *(tmp[1].get())); - // dummy gen random, for prng sync - aby3_ctx()->template gen_random(*ret->mutable_share(1), 1); - } else { // party == 2 - aby3_ctx()->template gen_random(*ret->mutable_share(0), 0); - } - - bt.ppa(this, &bt, sizeof(T) * 8); - - TensorAdapter *dest = nullptr; - if (party() == 0) { - dest = ret->mutable_share(0); - } - - bt.reveal_to_one(0, dest); - - if (party() == 0) { - aby3_ctx()->network()->template recv(1, 
*(ret->mutable_share(1))); - aby3_ctx()->network()->template send(2, *(ret->mutable_share(0))); - } else if (party() == 1) { - aby3_ctx()->network()->template send(0, *(ret->mutable_share(0))); - } else { // party == 2 - aby3_ctx()->network()->template recv(0, *(ret->mutable_share(1))); - } +template +template +void BooleanTensor::b2a(FixedPointTensor* ret) const { + std::shared_ptr> tmp[2]; + for (auto& ti: tmp) { + ti = tensor_factory()->template create(shape()); + // set 0 + std::transform(ti->data(), ti->data() + ti->numel(), ti->data(), + [](T) -> T { return 0; }); + } + BooleanTensor bt(tmp[0].get(), tmp[1].get()); + + if (party() == 1) { + aby3_ctx()->template gen_random(*ret->mutable_share(0), 0); + aby3_ctx()->template gen_random(*ret->mutable_share(1), 1); + ret->share(0)->add(ret->share(1), tmp[0].get()); + tmp[0]->negative(tmp[0].get()); + aby3_ctx()->network()->template send(0, *(tmp[0].get())); + } else if (party() == 0) { + aby3_ctx()->network()->template recv(1, *(tmp[1].get())); + // dummy gen random, for prng sync + aby3_ctx()->template gen_random(*ret->mutable_share(1), 1); + } else { // party == 2 + aby3_ctx()->template gen_random(*ret->mutable_share(0), 0); + } + + bt.ppa(this, &bt, sizeof(T) * 8); + + TensorAdapter* dest = nullptr; + if (party() == 0) { + dest = ret->mutable_share(0); + } + + bt.reveal_to_one(0, dest); + + if (party() == 0) { + aby3_ctx()->network()->template recv(1, *(ret->mutable_share(1))); + aby3_ctx()->network()->template send(2, *(ret->mutable_share(0))); + } else if (party() == 1) { + aby3_ctx()->network()->template send(0, *(ret->mutable_share(0))); + } else { // party == 2 + aby3_ctx()->network()->template recv(0, *(ret->mutable_share(1))); + } } -template -template -void BooleanTensor::mul(const TensorAdapter *rhs, - FixedPointTensor *ret, +template +template +void BooleanTensor::mul(const TensorAdapter* rhs, + FixedPointTensor* ret, size_t rhs_party) const { - // ot sender - size_t idx0 = rhs_party; - - size_t idx1 = (rhs_party + 1) % 3; - - size_t idx2 = (rhs_party + 2) % 3; - - auto tmp0 = tensor_factory()->template create(ret->shape()); - auto tmp1 = tensor_factory()->template create(ret->shape()); - - TensorAdapter *tmp[2] = {tmp0.get(), tmp1.get()}; - - TensorAdapter *null_arg[2] = {nullptr, nullptr}; - - if (party() == idx0) { - // use ret as buffer - TensorAdapter *m[2] = {ret->mutable_share(0), ret->mutable_share(1)}; - - aby3_ctx()->template gen_zero_sharing_arithmetic(*tmp[0]); - - // m0 = a * (b0 ^ b1) + s0 - // m1 = a * (1 ^ b0 ^ b1) + s0 - share(0)->bitwise_xor(share(1), m[0]); - std::transform(m[0]->data(), m[0]->data() + m[0]->numel(), m[1]->data(), - [](T in) { return 1 ^ in; }); - - m[0]->mul(rhs, m[0]); - m[1]->mul(rhs, m[1]); - - m[0]->add(tmp[0], m[0]); - m[1]->add(tmp[0], m[1]); - - aby3_ctx()->template ot(idx0, idx1, idx2, null_arg[0], - const_cast **>(m), tmp, - null_arg[0]); - - // ret0 = s2 - // ret1 = s1 - aby3_ctx()->network()->template recv(idx2, *(ret->mutable_share(0))); - aby3_ctx()->network()->template recv(idx1, *(ret->mutable_share(1))); - - } else if (party() == idx1) { - // ret0 = s1 - aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0))); - // ret1 = a * b + s0 - aby3_ctx()->template ot( - idx0, idx1, idx2, share(1), - const_cast **>(null_arg), tmp, - ret->mutable_share(1)); - aby3_ctx()->network()->template send(idx0, *(ret->share(0))); - aby3_ctx()->network()->template send(idx2, *(ret->share(1))); - } else if (party() == idx2) { - // ret0 = a * b + s0 - aby3_ctx()->template 
gen_zero_sharing_arithmetic(*(ret->mutable_share(1))); - // ret1 = s2 - aby3_ctx()->template ot( - idx0, idx1, idx2, share(0), - const_cast **>(null_arg), tmp, - null_arg[0]); - - aby3_ctx()->network()->template send(idx0, *(ret->share(1))); - - aby3_ctx()->network()->template recv(idx1, *(ret->mutable_share(0))); - } + // ot sender + size_t idx0 = rhs_party; + + size_t idx1 = (rhs_party + 1) % 3; + + size_t idx2 = (rhs_party + 2) % 3; + + auto tmp0 = tensor_factory()->template create(ret->shape()); + auto tmp1 = tensor_factory()->template create(ret->shape()); + + TensorAdapter* tmp[2] = {tmp0.get(), tmp1.get()}; + + TensorAdapter* null_arg[2] = {nullptr, nullptr}; + + if (party() == idx0) { + // use ret as buffer + TensorAdapter* m[2] = {ret->mutable_share(0), ret->mutable_share(1)}; + + aby3_ctx()->template gen_zero_sharing_arithmetic(*tmp[0]); + + // m0 = a * (b0 ^ b1) + s0 + // m1 = a * (1 ^ b0 ^ b1) + s0 + share(0)->bitwise_xor(share(1), m[0]); + std::transform(m[0]->data(), m[0]->data() + m[0]->numel(), m[0]->data(), + [](T in) { return 1 & in; }); + std::transform(m[0]->data(), m[0]->data() + m[0]->numel(), m[1]->data(), + [](T in) { return 1 ^ in; }); + + m[0]->mul(rhs, m[0]); + m[1]->mul(rhs, m[1]); + + m[0]->add(tmp[0], m[0]); + m[1]->add(tmp[0], m[1]); + + aby3_ctx()->template ot(idx0, idx1, idx2, null_arg[0], + const_cast**>(m), + tmp, null_arg[0]); + + // ret0 = s2 + // ret1 = s1 + aby3_ctx()->network()->template recv(idx2, *(ret->mutable_share(0))); + aby3_ctx()->network()->template recv(idx1, *(ret->mutable_share(1))); + + } else if (party() == idx1) { + // ret0 = s1 + aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0))); + // ret1 = a * b + s0 + aby3_ctx()->template ot(idx0, idx1, idx2, share(1), + const_cast**>(null_arg), + tmp, ret->mutable_share(1)); + aby3_ctx()->network()->template send(idx0, *(ret->share(0))); + aby3_ctx()->network()->template send(idx2, *(ret->share(1))); + } else if (party() == idx2) { + // ret0 = a * b + s0 + aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(1))); + // ret1 = s2 + aby3_ctx()->template ot(idx0, idx1, idx2, share(0), + const_cast**>(null_arg), + tmp, null_arg[0]); + + aby3_ctx()->network()->template send(idx0, *(ret->share(1))); + + aby3_ctx()->network()->template recv(idx1, *(ret->mutable_share(0))); + } } -template -template -void BooleanTensor::mul(const FixedPointTensor *rhs, - FixedPointTensor *ret) const { - auto tmp0 = tensor_factory()->template create(ret->shape()); - auto tmp1 = tensor_factory()->template create(ret->shape()); - auto tmp2 = tensor_factory()->template create(ret->shape()); - - FixedPointTensor tmp(tmp0.get(), tmp1.get()); - - if (party() == 0) { - mul(nullptr, ret, 1); - mul(rhs->share(0), &tmp, 0); - ret->add(&tmp, ret); - - } else if (party() == 1) { - rhs->share(0)->add(rhs->share(1), tmp2.get()); - mul(tmp2.get(), ret, 1); - mul(nullptr, &tmp, 0); - ret->add(&tmp, ret); - - } else { // party() == 2 - mul(nullptr, ret, 1); - mul(nullptr, &tmp, 0); - ret->add(&tmp, ret); - } +template +template +void BooleanTensor::mul(const FixedPointTensor* rhs, + FixedPointTensor* ret) const { + std::vector>> tmp; + + for (int i = 0; i < 4; ++i) { + tmp.emplace_back( + tensor_factory()->template create(ret->shape())); + } + + FixedPointTensor tmp0(tmp[0].get(), tmp[1].get()); + FixedPointTensor tmp1(tmp[2].get(), tmp[3].get()); + + if (party() == 0) { + mul(nullptr, &tmp0, 1); + mul(rhs->share(0), &tmp1, 0); + } else if (party() == 1) { + rhs->share(0)->add(rhs->share(1), 
tmp[2].get()); + mul(tmp[2].get(), &tmp0, 1); + mul(nullptr, &tmp1, 0); + } else { // party() == 2 + mul(nullptr, &tmp0, 1); + mul(nullptr, &tmp1, 0); + } + tmp0.add(&tmp1, ret); +} +template +void BooleanTensor::onehot_from_cmp() { + // cmp is done slice by slice + // suppose that shape = [k, m, n, ...] + // shape of all slices and tmp tensors = [1, m, n] + auto shape_ = shape(); + size_t len = shape_[0]; + shape_[0] = 1; + std::vector>> tmp; + + for (int i = 0; i < 4; ++i) { + tmp.emplace_back( + tensor_factory()->template create(shape_)); + } + + tmp.emplace_back(tensor_factory()->template create()); + tmp.emplace_back(tensor_factory()->template create()); + + BooleanTensor found(tmp[0].get(), tmp[1].get()); + + assign_to_tensor(tmp[0].get(), T(0)); + assign_to_tensor(tmp[1].get(), T(0)); + + BooleanTensor not_found(tmp[2].get(), tmp[3].get()); + + // res[i] = !found & input[i] + // found = found 1 res[i] + // to find last 1, we search backward + for (size_t i = len; i > 0; --i) { + share(0)->slice(i - 1, i, tmp[4].get()); + share(1)->slice(i - 1, i, tmp[5].get()); + BooleanTensor cmp_i(tmp[4].get(), tmp[5].get()); + found.bitwise_not(¬_found); + not_found.bitwise_and(&cmp_i, &cmp_i); + cmp_i.bitwise_or(&found, &found); + } } } // namespace aby3 diff --git a/core/privc3/boolean_tensor_test.cc b/core/privc3/boolean_tensor_test.cc index c44d6cd0e2d92a44df4e407427aba9f63b09a304..984fbb5a81e45d6dd1eb8113fecbbd9425319f9b 100644 --- a/core/privc3/boolean_tensor_test.cc +++ b/core/privc3/boolean_tensor_test.cc @@ -1214,9 +1214,9 @@ TEST_F(BooleanTensorTest, abmul_test) { gen1(), gen1(), gen1()}; // lhs = 1 - sl[0]->data()[0] = 1; - sl[1]->data()[0] = 0; - sl[2]->data()[0] = 0; + sl[0]->data()[0] = -1; + sl[1]->data()[0] = -3; + sl[2]->data()[0] = 3; BTensor b0(sl[0].get(), sl[1].get()); BTensor b1(sl[1].get(), sl[2].get()); @@ -1273,9 +1273,9 @@ TEST_F(BooleanTensorTest, abmul2_test) { gen1(), gen1(), gen1()}; // lhs = 1 - sl[0]->data()[0] = 1; - sl[1]->data()[0] = 0; - sl[2]->data()[0] = 0; + sl[0]->data()[0] = -3; + sl[1]->data()[0] = -1; + sl[2]->data()[0] = 3; // rhs = 12 = 3 + 4 + 5 sr[0]->data()[0] = 3; @@ -1330,4 +1330,197 @@ TEST_F(BooleanTensorTest, abmul2_test) { } EXPECT_EQ(1 * 12, p->data()[0]); } + +TEST_F(BooleanTensorTest, abmul3_test) { + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sr[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + + // lhs = 0 + sl[0]->data()[0] = 373964488827046757; + sl[1]->data()[0] = -2697357730885869060; + sl[2]->data()[0] = -2332413979122373991; + + // rhs = -1 + sr[0]->data()[0] = 8388121746490115866; + sr[1]->data()[0] = 5851959018403668595; + sr[2]->data()[0] = 4206663308815767154; + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + FTensor fr0(sr[0].get(), sr[1].get()); + FTensor fr1(sr[1].get(), sr[2].get()); + FTensor fr2(sr[2].get(), sr[0].get()); + + FTensor fout0(sout[0].get(), sout[1].get()); + FTensor fout1(sout[2].get(), sout[3].get()); + FTensor fout2(sout[4].get(), sout[5].get()); + + auto p = gen1(); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.mul(&fr0, &fout0); + fout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.mul(&fr1, &fout1); + fout1.reveal_to_one(0, nullptr); 
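// only party 0 supplies an output buffer in these tests: reveal_to_one(0, ...)
// writes the revealed plaintext on party 0 and does not touch the argument on
// the other parties, so they pass nullptr.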
+ }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.mul(&fr2, &fout2); + fout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(0, p->data()[0]); +} + +TEST_F(BooleanTensorTest, abmul4_test) { + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sr[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + + // lhs = 1 + sl[0]->data()[0] = 373964488827046757; + sl[1]->data()[0] = -2697357730885869060; + sl[2]->data()[0] = -2332413979122373992; + + // rhs = -1 + sr[0]->data()[0] = 8388121746490115866; + sr[1]->data()[0] = 5851959018403668595; + sr[2]->data()[0] = 4206663308815767154; + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + FTensor fr0(sr[0].get(), sr[1].get()); + FTensor fr1(sr[1].get(), sr[2].get()); + FTensor fr2(sr[2].get(), sr[0].get()); + + FTensor fout0(sout[0].get(), sout[1].get()); + FTensor fout1(sout[2].get(), sout[3].get()); + FTensor fout2(sout[4].get(), sout[5].get()); + + auto p = gen1(); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.mul(&fr0, &fout0); + fout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.mul(&fr1, &fout1); + fout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.mul(&fr2, &fout2); + fout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(-1, p->data()[0]); +} + +TEST_F(BooleanTensorTest, onehot_from_cmp_test) { + std::vector shape = {4, 1}; + std::shared_ptr> sout[6] = + { gen(shape), gen(shape), gen(shape), gen(shape), gen(shape), gen(shape)}; + + for (auto& ptr: sout) { + assign_to_tensor(ptr.get(), 0l); + } + + sout[0].get()->data()[0] = 1; + sout[0].get()->data()[2] = 1; + + sout[5].get()->data()[0] = 1; + sout[5].get()->data()[2] = 1; + + // input plaintext [1010] + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + auto p = gen(shape); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bout0.onehot_from_cmp(); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bout1.onehot_from_cmp(); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bout2.onehot_from_cmp(); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(0, p->data()[0]); + EXPECT_EQ(0, p->data()[1]); + EXPECT_EQ(1, p->data()[2]); + EXPECT_EQ(0, p->data()[3]); +} } // namespace aby3 diff --git a/core/privc3/circuit_context.h b/core/privc3/circuit_context.h index 26b7a7a8fd4ad54e625110ce63b265a962dd34ac..ed75e31d8f1770537d7ca7ddde401e31de12246e 100644 --- a/core/privc3/circuit_context.h +++ b/core/privc3/circuit_context.h @@ -13,7 +13,6 @@ // limitations under the License. 
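A reading aid for the CircuitContext hunk that follows: init() seeds _prng[0] with the party's own seed and, after the seed exchange, _prng[1] with the seed received from the next party, so the three parties' outputs of gen_zero_sharing_arithmetic() sum to zero and those of gen_zero_sharing_boolean() XOR to zero. The minimal standalone sketch below reproduces that cancellation; std::mt19937_64 and the hard-coded seeds are stand-ins for the library's PseudorandomNumberGenerator and the real seed exchange, used only for illustration.

#include <cstdint>
#include <iostream>
#include <random>

int main() {
    std::uint64_t seed[3] = {11, 22, 33};   // illustrative seeds s_0, s_1, s_2
    std::mt19937_64 own[3];                 // stand-in for _prng[0] (own seed)
    std::mt19937_64 next[3];                // stand-in for _prng[1] (next party's seed)
    for (int i = 0; i < 3; ++i) {
        own[i].seed(seed[i]);
        next[i].seed(seed[(i + 1) % 3]);    // seed obtained from the next party
    }
    std::uint64_t sum = 0;                  // arithmetic zero sharing: r_i - r_{i+1}
    std::uint64_t xor_v = 0;                // boolean zero sharing: r_i ^ r_{i+1}
    for (int i = 0; i < 3; ++i) {
        sum += own[i]() - next[i]();
        xor_v ^= own[i]() ^ next[i]();
    }
    std::cout << sum << " " << xor_v << std::endl;  // prints "0 0"
    return 0;
}

This telescoping cancellation is what the code relies on when it mixes gen_zero_sharing_boolean() output into reshared values (for example in bitwise_and and a2b): the masks cancel across the three parties without any extra communication.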
#pragma once -#include #include #include @@ -26,165 +25,185 @@ using AbstractNetwork = paddle::mpc::AbstractNetwork; class CircuitContext { public: - CircuitContext(size_t party, std::shared_ptr network, - const block &seed = g_zero_block, - const block &seed2 = g_zero_block) { - init(party, network, seed, seed2); - } + CircuitContext(size_t party, + std::shared_ptr network, + const block& seed = g_zero_block, + const block& seed2 = g_zero_block) { + init(party, network, seed, seed2); + } + + CircuitContext(const CircuitContext& other) = delete; + + CircuitContext& operator=(const CircuitContext& other) = delete; + + void init(size_t party, + std::shared_ptr network, + block seed, + block seed2) { + set_party(party); + set_network(network); + + if (equals(seed, g_zero_block)) { + seed = block_from_dev_urandom(); + } + + if (equals(seed2, g_zero_block)) { + seed2 = block_from_dev_urandom(); + } + set_random_seed(seed, 0); + // seed2 is private + set_random_seed(seed2, 2); + + // 3 for 3-party computation + size_t party_pre = (this->party() - 1 + 3) % 3; + size_t party_next = (this->party() + 1) % 3; + + if (party == 1) { + block recv_seed = this->network()->template recv(party_next); + this->network()->template send(party_pre, seed); + seed = recv_seed; + } else { + this->network()->template send(party_pre, seed); + seed = this->network()->template recv(party_next); + } + + set_random_seed(seed, 1); + } + + void set_party(size_t party) { + if (party >= 3) { + // exception handling + } + _party = party; + } + + void set_network(std::shared_ptr network) { + _network = network; + } - CircuitContext(const CircuitContext &other) = delete; + AbstractNetwork* network() { + return _network.get(); + } + + void set_random_seed(const block& seed, size_t idx) { + if (idx >= 3) { + // exception handling + } + _prng[idx].set_seed(seed); + } + + size_t party() const { + return _party; + } - CircuitContext &operator=(const CircuitContext &other) = delete; + size_t pre_party() const { + return (_party + 3 - 1) % 3; + } - void init(size_t party, std::shared_ptr network, block seed, - block seed2) { - set_party(party); - set_network(network); + size_t next_party() const { + return (_party + 1) % 3; + } - if (equals(seed, g_zero_block)) { - seed = block_from_dev_urandom(); + template + T gen_random(bool next) { + return _prng[next].get(); } - if (equals(seed2, g_zero_block)) { - seed2 = block_from_dev_urandom(); - } - set_random_seed(seed, 0); - // seed2 is private - set_random_seed(seed2, 2); - - // 3 for 3-party computation - size_t party_pre = (this->party() - 1 + 3) % 3; - size_t party_next = (this->party() + 1) % 3; - - if (party == 1) { - block recv_seed = this->network()->template recv(party_next); - this->network()->template send(party_pre, seed); - seed = recv_seed; - } else { - this->network()->template send(party_pre, seed); - seed = this->network()->template recv(party_next); - } - - set_random_seed(seed, 1); - } - - void set_party(size_t party) { - if (party >= 3) { - // exception handling - } - _party = party; - } - - void set_network(std::shared_ptr network) { - _network = network; - } - - AbstractNetwork *network() { return _network.get(); } - - void set_random_seed(const block &seed, size_t idx) { - if (idx >= 3) { - // exception handling - } - _prng[idx].set_seed(seed); - } - - size_t party() const { return _party; } - - size_t pre_party() const { return (_party + 3 - 1) % 3; } - - size_t next_party() const { return (_party + 1) % 3; } - - template T gen_random(bool next) { return 
_prng[next].get(); } - - template class Tensor> - void gen_random(Tensor &tensor, bool next) { - std::for_each( - tensor.data(), tensor.data() + tensor.numel(), - [this, next](T &val) { val = this->template gen_random(next); }); - } - - template T gen_random_private() { return _prng[2].get(); } - - template class Tensor> - void gen_random_private(Tensor &tensor) { - std::for_each( - tensor.data(), tensor.data() + tensor.numel(), - [this](T &val) { val = this->template gen_random_private(); }); - } - - template T gen_zero_sharing_arithmetic() { - return _prng[0].get() - _prng[1].get(); - } - - template class Tensor> - void gen_zero_sharing_arithmetic(Tensor &tensor) { - std::for_each(tensor.data(), tensor.data() + tensor.numel(), - [this](T &val) { - val = this->template gen_zero_sharing_arithmetic(); - }); - } - - template T gen_zero_sharing_boolean() { - return _prng[0].get() ^ _prng[1].get(); - } - - template class Tensor> - void gen_zero_sharing_boolean(Tensor &tensor) { - std::for_each( - tensor.data(), tensor.data() + tensor.numel(), - [this](T &val) { val = this->template gen_zero_sharing_boolean(); }); - } - - template class Tensor> - void ot(size_t sender, size_t receiver, size_t helper, - const Tensor *choice, const Tensor *m[2], Tensor *buffer[2], - Tensor *ret) { - // TODO: check tensor shape equals - const size_t numel = buffer[0]->numel(); - if (party() == sender) { - bool common = helper == next_party(); - this->template gen_random(*buffer[0], common); - this->template gen_random(*buffer[1], common); - for (size_t i = 0; i < numel; ++i) { - buffer[0]->data()[i] ^= m[0]->data()[i]; - buffer[1]->data()[i] ^= m[1]->data()[i]; - } - network()->template send(receiver, *buffer[0]); - network()->template send(receiver, *buffer[1]); - - } else if (party() == helper) { - bool common = sender == next_party(); - - this->template gen_random(*buffer[0], common); - this->template gen_random(*buffer[1], common); - - for (size_t i = 0; i < numel; ++i) { - // TODO: check if choice is one bit - buffer[0]->data()[i] = - choice->data()[i] ? 
buffer[1]->data()[i] : buffer[0]->data()[i]; - } - network()->template send(receiver, *buffer[0]); - } else if (party() == receiver) { - network()->template recv(sender, *buffer[0]); - network()->template recv(sender, *buffer[1]); - network()->template recv(helper, *ret); - size_t i = 0; - std::for_each(ret->data(), ret->data() + numel, - [&buffer, &i, choice, ret](T &in) { - // TODO: check if choice is one bit - bool c = choice->data()[i]; - in ^= buffer[c]->data()[i]; - ++i; - }); - } - } + template class Tensor> + void gen_random(Tensor& tensor, bool next) { + std::for_each(tensor.data(), tensor.data() + tensor.numel(), + [this, next](T& val) { + val = this->template gen_random(next); + }); + } + + template + T gen_random_private() { + return _prng[2].get(); + } + + template class Tensor> + void gen_random_private(Tensor& tensor) { + std::for_each(tensor.data(), tensor.data() + tensor.numel(), + [this](T& val) { + val = this->template gen_random_private(); + }); + } + + template + T gen_zero_sharing_arithmetic() { + return _prng[0].get() - _prng[1].get(); + } + + template class Tensor> + void gen_zero_sharing_arithmetic(Tensor& tensor) { + std::for_each(tensor.data(), tensor.data() + tensor.numel(), + [this](T& val) { + val = this->template gen_zero_sharing_arithmetic(); + }); + } + + template + T gen_zero_sharing_boolean() { + return _prng[0].get() ^ _prng[1].get(); + } + + template class Tensor> + void gen_zero_sharing_boolean(Tensor& tensor) { + std::for_each(tensor.data(), tensor.data() + tensor.numel(), + [this](T& val) { + val = this->template gen_zero_sharing_boolean(); + }); + } + + template class Tensor> + void ot(size_t sender, size_t receiver, size_t helper, + const Tensor* choice, const Tensor* m[2], + Tensor* buffer[2], Tensor* ret) { + // TODO: check tensor shape equals + const size_t numel = buffer[0]->numel(); + if (party() == sender) { + bool common = helper == next_party(); + this->template gen_random(*buffer[0], common); + this->template gen_random(*buffer[1], common); + for (size_t i = 0; i < numel; ++i) { + buffer[0]->data()[i] ^= m[0]->data()[i]; + buffer[1]->data()[i] ^= m[1]->data()[i]; + } + network()->template send(receiver, *buffer[0]); + network()->template send(receiver, *buffer[1]); + + } else if (party() == helper) { + bool common = sender == next_party(); + + this->template gen_random(*buffer[0], common); + this->template gen_random(*buffer[1], common); + + for (size_t i = 0; i < numel; ++i) { + buffer[0]->data()[i] = choice->data()[i] & 1 ? 
+ buffer[1]->data()[i] : buffer[0]->data()[i]; + } + network()->template send(receiver, *buffer[0]); + } else if (party() == receiver) { + network()->template recv(sender, *buffer[0]); + network()->template recv(sender, *buffer[1]); + network()->template recv(helper, *ret); + size_t i = 0; + std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) { + bool c = choice->data()[i] & 1; + in ^= buffer[c]->data()[i]; + ++i;} + ); + } + } private: - size_t _party; + size_t _party; + + std::shared_ptr _network; - std::shared_ptr _network; + PseudorandomNumberGenerator _prng[3]; - PseudorandomNumberGenerator _prng[3]; }; } // namespace aby3 diff --git a/core/privc3/fixedpoint_tensor.h b/core/privc3/fixedpoint_tensor.h index b3ca55f93a6115ebbe3266e343af8297c74b9a7b..35e21b6e3e48550fe97eeb64bc6f0d9ed9f6a747 100644 --- a/core/privc3/fixedpoint_tensor.h +++ b/core/privc3/fixedpoint_tensor.h @@ -16,164 +16,243 @@ #include -#include "boolean_tensor.h" #include "circuit_context.h" -#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "paddle_tensor.h" +#include "boolean_tensor.h" +#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" namespace aby3 { -template class FixedPointTensor { +template +class FixedPointTensor { public: - explicit FixedPointTensor(TensorAdapter *share_tensor[2]); + explicit FixedPointTensor(TensorAdapter* share_tensor[2]); + + explicit FixedPointTensor(TensorAdapter* share_tensor_0, + TensorAdapter* share_tensor_1); - explicit FixedPointTensor(TensorAdapter *share_tensor_0, - TensorAdapter *share_tensor_1); + ~FixedPointTensor() {}; - ~FixedPointTensor(){}; + //get mutable shape of tensor + TensorAdapter* mutable_share(size_t idx); - // get mutable shape of tensor - TensorAdapter *mutable_share(size_t idx); + const TensorAdapter* share(size_t idx) const; - const TensorAdapter *share(size_t idx) const; + size_t numel() const { + return _share[0]->numel(); + } + + // reveal fixedpointtensor to one party + void reveal_to_one(size_t party, TensorAdapter* ret) const; + + // reveal fixedpointtensor to all parties + void reveal(TensorAdapter* ret) const; - size_t numel() const { return _share[0]->numel(); } + const std::vector shape() const; - // reveal fixedpointtensor to one party - void reveal_to_one(size_t party, TensorAdapter *ret) const; + //convert TensorAdapter to shares + static void share(const TensorAdapter* input, + TensorAdapter* output_shares[3], + block seed = g_zero_block); - // reveal fixedpointtensor to all parties - void reveal(TensorAdapter *ret) const; + // element-wise add with FixedPointTensor + void add(const FixedPointTensor* rhs, FixedPointTensor* ret) const; - const std::vector shape() const; + // element-wise add with TensorAdapter - // convert TensorAdapter to shares - static void share(const TensorAdapter *input, - TensorAdapter *output_shares[3], - block seed = g_zero_block); + void add(const TensorAdapter* rhs, FixedPointTensor* ret) const; - // element-wise add with FixedPointTensor - void add(const FixedPointTensor *rhs, FixedPointTensor *ret) const; + // element-wise sub with FixedPointTensor + void sub(const FixedPointTensor* rhs, FixedPointTensor* ret) const; - // element-wise add with TensorAdapter + // element-wise sub with TensorAdapter + void sub(const TensorAdapter* rhs, FixedPointTensor* ret) const; - void add(const TensorAdapter *rhs, FixedPointTensor *ret) const; + // negative + void negative(FixedPointTensor* ret) const; - // element-wise sub with FixedPointTensor - void sub(const FixedPointTensor 
*rhs, FixedPointTensor *ret) const; + // element-wise mul with FixedPointTensor using truncate1 + void mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const; - // element-wise sub with TensorAdapter - void sub(const TensorAdapter *rhs, FixedPointTensor *ret) const; + // element-wise mul with TensorAdapter + void mul(const TensorAdapter* rhs, FixedPointTensor* ret) const; - // negative - void negative(FixedPointTensor *ret) const; + // div by TensorAdapter + void div(const TensorAdapter* rhs, FixedPointTensor* ret) const; - // element-wise mul with FixedPointTensor using truncate1 - void mul(const FixedPointTensor *rhs, FixedPointTensor *ret) const; + // div by FixedPointedTensor + // TODO@yqy : not surport operator rhs <= 0 now + void div(const FixedPointTensor* rhs, FixedPointTensor* ret, + size_t iter = 16, double x0 = pow(2, -15)) const; - // element-wise mul with TensorAdapter - void mul(const TensorAdapter *rhs, FixedPointTensor *ret) const; + // long div by boolean circuit + // res_int_len: estimated bit len of the integer part of result + void long_div(const FixedPointTensor* rhs, + FixedPointTensor* ret, size_t res_int_len = 20) const; - // div by TensorAdapter - void div(const TensorAdapter *rhs, FixedPointTensor *ret) const; + void inverse_square_root(FixedPointTensor* ret, + size_t iter = 16, double x0 = 0x1p-10) const; - // element-wise mul, use trunc2 - void mul2(const FixedPointTensor *rhs, FixedPointTensor *ret) const; + // dot_mul + template class CTensor, + size_t... N1> + void dot_mul(const CTensor* rhs, FixedPointTensor* ret) const; - // dot_mul - template