Unverified commit 84c1587a, authored by Dong Daxiang, committed by GitHub

Merge pull request #61 from qjing666/new_version

Add fl_mpc code and restructure the framework
repos:
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
sha: v1.0.1
hooks:
- id: remove-crlf
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0
hooks:
- id: check-added-large-files
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer
- repo: local
hooks:
- id: clang-format-with-version-check
name: clang-format
description: Format files with ClangFormat.
entry: bash ./tools/codestyle/clang_format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: local
hooks:
- id: cpplint-cpp-source
name: cpplint
description: Check C++ code style using cpplint.py.
entry: bash ./tools/codestyle/cpplint_pre_commit.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$
- repo: local
hooks:
- id: pylint-doc-string
name: pylint
description: Check python docstring style using docstring_checker.
entry: bash ./tools/codestyle/pylint_pre_commit.hook
language: system
files: \.(py)$
- repo: local
hooks:
- id: copyright_checker
name: copyright_checker
entry: python ./tools/codestyle/copyright.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$
cmake_minimum_required(VERSION 3.13)
project(PaddleEncrypted)
add_compile_options(-msse4.2 -maes -fPIC -DPADDLE_WITH_MKLDNN)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 11)
if(UNIX AND NOT APPLE)
set(LINUX TRUE)
endif()
if (WIN32)
set(CMAKE_FIND_LIBRARY_SUFFIXES .dll)
elseif (APPLE)
set(CMAKE_FIND_LIBRARY_SUFFIXES .dylib)
set(CMAKE_FIND_LIBRARY_PREFIXES lib)
elseif (LINUX)
set(CMAKE_FIND_LIBRARY_SUFFIXES .so)
set(CMAKE_FIND_LIBRARY_PREFIXES lib)
endif()
if (NOT PYTHON_EXECUTABLE)
set(PYTHON_EXECUTABLE python3)
endif()
find_program(PYTHON ${PYTHON_EXECUTABLE})
if (NOT PYTHON)
message(FATAL_ERROR "${PYTHON_EXECUTABLE} not found")
endif()
execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_version)"
RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE)
if (NOT ret)
if (NOT ${paddle_version} STREQUAL "1.6.3")
message(FATAL_ERROR "Paddle installation of 1.6.3 is required but ${paddle_version} is found")
endif()
else()
message(FATAL_ERROR "Could not get paddle version.")
endif()
execute_process(COMMAND ${PYTHON} -c "import paddle; print(paddle.sysconfig.get_include())"
OUTPUT_VARIABLE PADDLE_INCLUDE OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${PYTHON} -c "import paddle; print(paddle.sysconfig.get_lib())"
OUTPUT_VARIABLE PADDLE_LIB OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${PYTHON} -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())"
OUTPUT_VARIABLE PYTHON_SITE_PACKAGES OUTPUT_STRIP_TRAILING_WHITESPACE)
find_library(FLUID_LIB NAMES paddle_framework PATHS ${PADDLE_LIB})
if (NOT FLUID_LIB)
message(FATAL_ERROR "paddle_framework library is not found in ${PADDLE_LIB}")
endif()
option(WITH_TESTING "Compile with unit testing" ON)
option(WITH_PSI "Compile with psi lib" ON)
########################### the project build part ###############################
message(STATUS "Using paddlepaddle installation of ${paddle_version}")
message(STATUS "paddlepaddle include directory: ${PADDLE_INCLUDE}")
message(STATUS "paddlepaddle libraries directory: ${PADDLE_LIB}")
message(STATUS "python libraries directory: ${PYTHON_SITE_PACKAGES}")
include(third_party)
include(generic)
include_directories(.)
include_directories(${PADDLE_INCLUDE})
include_directories(${PADDLE_INCLUDE}/third_party)
add_subdirectory(core/privc3)
add_subdirectory(core/paddlefl_mpc/mpc_protocol)
add_subdirectory(core/paddlefl_mpc/operators)
add_subdirectory(core/paddlefl_mpc/data_utils)
if (WITH_TESTING)
add_subdirectory(core/testing)
endif()
if (WITH_PSI)
add_subdirectory(core/psi)
endif()
add_library(fluid_framework SHARED IMPORTED GLOBAL)
set_property(TARGET fluid_framework PROPERTY IMPORTED_LOCATION ${FLUID_LIB})
# generate dynamic .so lib
add_library(paddle_enc SHARED
$<TARGET_OBJECTS:privc3_o>
$<TARGET_OBJECTS:mpc_protocol_o>
$<TARGET_OBJECTS:mpc_ops_o>)
target_link_libraries(paddle_enc fluid_framework)
target_link_libraries(paddle_enc gloo)
target_link_libraries(paddle_enc hiredis)
set(CMAKE_SKIP_INSTALL_RPATH TRUE)
set(PADDLE_ENCRYPTED_LIB_PATH "${CMAKE_SOURCE_DIR}/python/paddle_fl/mpc/libs")
install(DIRECTORY "${THIRD_PARTY_PATH}/install/gloo/lib/"
DESTINATION ${PADDLE_ENCRYPTED_LIB_PATH}/third_party)
install(DIRECTORY "${THIRD_PARTY_PATH}/install/hiredis/lib/"
DESTINATION ${PADDLE_ENCRYPTED_LIB_PATH}/third_party)
install(DIRECTORY "${THIRD_PARTY_PATH}/install/openssl/lib/"
DESTINATION ${PADDLE_ENCRYPTED_LIB_PATH}/third_party/openssl)
install(TARGETS paddle_enc mpc_data_utils
LIBRARY DESTINATION ${PADDLE_ENCRYPTED_LIB_PATH})
if (WITH_PSI)
install(TARGETS psi LIBRARY DESTINATION ${PADDLE_ENCRYPTED_LIB_PATH})
endif()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(GLOO_PROJECT "extern_gloo")
IF((NOT DEFINED GLOO_VER) OR (NOT DEFINED GLOO_URL))
MESSAGE(STATUS "use pre defined download url")
SET(GLOO_VER "master" CACHE STRING "" FORCE)
SET(GLOO_NAME "gloo" CACHE STRING "" FORCE)
SET(GLOO_URL "https://paddlefl.bj.bcebos.com/gloo.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "GLOO_NAME: ${GLOO_NAME}, GLOO_URL: ${GLOO_URL}")
SET(GLOO_SOURCE_DIR "${THIRD_PARTY_PATH}/gloo")
SET(GLOO_DOWNLOAD_DIR "${GLOO_SOURCE_DIR}/src/${GLOO_PROJECT}")
SET(GLOO_DST_DIR "gloo")
SET(GLOO_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(GLOO_INSTALL_DIR ${GLOO_INSTALL_ROOT}/${GLOO_DST_DIR})
SET(GLOO_ROOT ${GLOO_INSTALL_DIR})
SET(GLOO_INC_DIR ${GLOO_ROOT}/include)
SET(GLOO_LIB_DIR ${GLOO_ROOT}/lib)
SET(GLOO_LIB ${GLOO_LIB_DIR}/libgloo.a)
#SET(GLOO_IOMP_LIB ${GLOO_LIB_DIR}/libiomp5.so) #todo what is this
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${GLOO_ROOT}/lib")
INCLUDE_DIRECTORIES(${GLOO_INC_DIR})
FILE(WRITE ${GLOO_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(GLOO)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${GLOO_NAME}/include ${GLOO_NAME}/lib \n"
" DESTINATION ${GLOO_DST_DIR})\n")
ExternalProject_Add(
${GLOO_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${GLOO_SOURCE_DIR}
DOWNLOAD_DIR ${GLOO_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${GLOO_URL} -c -q -O ${GLOO_NAME}.tar.gz
&& tar zxvf ${GLOO_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${GLOO_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOO_INSTALL_ROOT}
)
ADD_LIBRARY(gloo SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET gloo PROPERTY IMPORTED_LOCATION ${GLOO_LIB})
ADD_DEPENDENCIES(gloo ${GLOO_PROJECT})
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(WITH_TESTING)
ENABLE_TESTING()
ENDIF()
INCLUDE(GNUInstallDirs)
INCLUDE(ExternalProject)
SET(GTEST_PREFIX_DIR ${THIRD_PARTY_PATH}/gtest)
SET(GTEST_SOURCE_DIR ${THIRD_PARTY_PATH}/gtest/src/extern_gtest)
SET(GTEST_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gtest)
SET(GTEST_INCLUDE_DIR "${GTEST_INSTALL_DIR}/include" CACHE PATH "gtest include directory." FORCE)
set(GTEST_REPOSITORY https://github.com/google/googletest.git)
set(GTEST_TAG release-1.8.1)
INCLUDE_DIRECTORIES(${GTEST_INCLUDE_DIR})
IF(WIN32)
set(GTEST_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/gtest.lib" CACHE FILEPATH "gtest libraries." FORCE)
set(GTEST_MAIN_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/gtest_main.lib" CACHE FILEPATH "gtest main libraries." FORCE)
ELSE(WIN32)
set(GTEST_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgtest.a" CACHE FILEPATH "gtest libraries." FORCE)
set(GTEST_MAIN_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE)
ENDIF(WIN32)
IF(WITH_MKLML)
# wait for mklml downloading completed
SET(GTEST_DEPENDS ${MKLML_PROJECT})
ENDIF()
cache_third_party(extern_gtest
REPOSITORY ${GTEST_REPOSITORY}
TAG ${GTEST_TAG}
DIR GTEST_SOURCE_DIR)
ExternalProject_Add(
extern_gtest
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
"${GTEST_DOWNLOAD_CMD}"
DEPENDS ${GTEST_DEPENDS}
PREFIX ${GTEST_PREFIX_DIR}
SOURCE_DIR ${GTEST_SOURCE_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_GMOCK=ON
-Dgtest_disable_pthreads=ON
-Dgtest_force_shared_crt=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GTEST_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
ADD_LIBRARY(gtest STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gtest PROPERTY IMPORTED_LOCATION ${GTEST_LIBRARIES})
ADD_DEPENDENCIES(gtest extern_gtest)
ADD_LIBRARY(gtest_main STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gtest_main PROPERTY IMPORTED_LOCATION ${GTEST_MAIN_LIBRARIES})
ADD_DEPENDENCIES(gtest_main extern_gtest)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(HIREDIS_PREFIX_DIR ${THIRD_PARTY_PATH}/hiredis)
SET(HIREDIS_SOURCE_DIR ${THIRD_PARTY_PATH}/hiredis/src/extern_hiredis)
SET(HIREDIS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/hiredis)
SET(HIREDIS_INCLUDE_DIR ${HIREDIS_INSTALL_DIR}/include)
SET(HIREDIS_LIBRARY ${HIREDIS_INSTALL_DIR}/lib/libhiredis.a)
SET(HIREDIS_REPOSITORY https://github.com/redis/hiredis.git)
SET(HIREDIS_TAG v0.13.3)
cache_third_party(extern_hiredis
REPOSITORY ${HIREDIS_REPOSITORY}
TAG ${HIREDIS_TAG}
DIR HIREDIS_SOURCE_DIR)
INCLUDE_DIRECTORIES(${HIREDIS_INCLUDE_DIR})
INCLUDE_DIRECTORIES(${HIREDIS_INCLUDE_DIR}/hiredis)
include(ProcessorCount)
ProcessorCount(NUM_OF_PROCESSOR)
ExternalProject_Add(
extern_hiredis
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
"${HIREDIS_DOWNLOAD_CMD}"
DEPENDS ${HIREDIS_DEPENDS}
CONFIGURE_COMMAND ""#${HIREDIS_CONFIGURE_COMMAND}
PREFIX ${HIREDIS_PREFIX_DIR}
SOURCE_DIR ${HIREDIS_SOURCE_DIR}
#UPDATE_COMMAND ""
BUILD_COMMAND CC=${CMAKE_C_COMPILER} CXX=${CMAKE_CXX_COMPILER}
CFLAGS=${CMAKE_C_FLAGS} DEBUG=${CMAKE_C_FLAGS_DEBUG}
make -j ${NUM_OF_PROCESSOR}
INSTALL_COMMAND PREFIX=${HIREDIS_INSTALL_DIR} INCLUDE_PATH=include/hiredis
LIBRARY_PATH=lib make install
BUILD_IN_SOURCE 1
)
ADD_LIBRARY(hiredis SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET hiredis PROPERTY IMPORTED_LOCATION ${HIREDIS_LIBRARY})
ADD_DEPENDENCIES(hiredis extern_hiredis)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
set(OPENSSL_SOURCES_DIR ${THIRD_PARTY_PATH}/openssl)
set(OPENSSL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openssl)
set(OPENSSL_INCLUDE_DIR "${OPENSSL_INSTALL_DIR}/include")
set(OPENSSL_NAME "openssl")
include(ProcessorCount)
ProcessorCount(NUM_OF_PROCESSOR)
if((NOT DEFINED OPENSSL_URL) OR (NOT DEFINED OPENSSL_VER))
message(STATUS "use pre defined download url")
set(OPENSSL_URL "https://paddlefl.bj.bcebos.com/openssl-1.0.2u.tar.gz" CACHE STRING "" FORCE)
set(OPENSSL_VER "openssl-1.0.2u" CACHE STRING "" FORCE)
endif()
ExternalProject_Add(
extern_openssl
PREFIX ${OPENSSL_SOURCES_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${OPENSSL_URL} -c -q -O ${OPENSSL_NAME}.tar.gz
&& tar -xvf ${OPENSSL_NAME}.tar.gz
SOURCE_DIR ${OPENSSL_SOURCES_DIR}/src/${OPENSSL_VER}
CONFIGURE_COMMAND ./config shared --openssldir=${OPENSSL_INSTALL_DIR} -lrt -Wl,--no-as-needed
BUILD_COMMAND make depend -j ${NUM_OF_PROCESSOR} &&
make build_libcrypto -j ${NUM_OF_PROCESSOR} &&
make build_apps -j ${NUM_OF_PROCESSOR}
INSTALL_COMMAND make install_sw
BUILD_IN_SOURCE 1
)
set(OPENSSL_CRYPTO_LIBRARY "${OPENSSL_INSTALL_DIR}/lib/libcrypto.so")
add_library(crypto SHARED IMPORTED GLOBAL)
set_property(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})
add_dependencies(crypto extern_openssl)
include_directories(${OPENSSL_INCLUDE_DIR})
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
set(PYBIND_PREFIX_DIR ${THIRD_PARTY_PATH}/pybind)
set(PYBIND_SOURCE_DIR ${THIRD_PARTY_PATH}/pybind/src/extern_pybind)
SET(PYBIND_REPOSITORY https://github.com/pybind/pybind11.git)
SET(PYBIND_TAG v2.2.4)
cache_third_party(extern_pybind
REPOSITORY ${PYBIND_REPOSITORY}
TAG ${PYBIND_TAG}
DIR PYBIND_SOURCE_DIR)
set(PYBIND_INCLUDE_DIR ${PYBIND_SOURCE_DIR}/include)
include_directories(${PYBIND_INCLUDE_DIR})
ExternalProject_Add(
extern_pybind
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
"${PYBIND_DOWNLOAD_CMD}"
PREFIX ${PYBIND_PREFIX_DIR}
SOURCE_DIR ${PYBIND_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/pybind_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_pybind = \"${dummyfile}\";")
add_library(pybind STATIC ${dummyfile})
else()
add_library(pybind INTERFACE)
endif()
add_dependencies(pybind extern_pybind)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Create a target named "third_party", which compiles external dependencies on all platforms (Windows/Linux/Mac)
include(CMakeParseArguments)
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
"A path setting third party libraries download & build directories.")
set(THIRD_PARTY_CACHE_PATH "${CMAKE_SOURCE_DIR}" CACHE STRING
"A path cache third party source code to avoid repeated download.")
set(THIRD_PARTY_BUILD_TYPE Release)
# Cache function to avoid repeatedly downloading third_party code.
# This function has 4 parameters, URL / REPOSITORY / TAG / DIR:
# 1. URL: specify the download url of the 3rd party
# 2. REPOSITORY: specify the git repository of the 3rd party
# 3. TAG: specify the git tag/branch/commit ID of the 3rd party
# 4. DIR: overwrite the original SOURCE_DIR when caching the directory
#
# The function returns 1 PARENT_SCOPE variable:
# - ${TARGET}_DOWNLOAD_CMD: Simply place "${TARGET}_DOWNLOAD_CMD" in ExternalProject_Add,
# and you no longer need to set any download steps in ExternalProject_Add.
# For example:
# Cache_third_party(${TARGET}
# REPOSITORY ${TARGET_REPOSITORY}
# TAG ${TARGET_TAG}
# DIR ${TARGET_SOURCE_DIR})
FUNCTION(cache_third_party TARGET)
SET(options "")
SET(oneValueArgs URL REPOSITORY TAG DIR)
SET(multiValueArgs "")
cmake_parse_arguments(cache_third_party "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
STRING(REPLACE "extern_" "" TARGET_NAME ${TARGET})
STRING(REGEX REPLACE "[0-9]+" "" TARGET_NAME ${TARGET_NAME})
STRING(TOUPPER ${TARGET_NAME} TARGET_NAME)
IF(cache_third_party_REPOSITORY)
SET(${TARGET_NAME}_DOWNLOAD_CMD
GIT_REPOSITORY ${cache_third_party_REPOSITORY})
IF(cache_third_party_TAG)
LIST(APPEND ${TARGET_NAME}_DOWNLOAD_CMD
GIT_TAG ${cache_third_party_TAG})
ENDIF()
ELSEIF(cache_third_party_URL)
SET(${TARGET_NAME}_DOWNLOAD_CMD
URL ${cache_third_party_URL})
ELSE()
MESSAGE(FATAL_ERROR "Download link (Git repo or URL) must be specified for cache!")
ENDIF()
IF(WITH_TP_CACHE)
IF(NOT cache_third_party_DIR)
MESSAGE(FATAL_ERROR "Please input the ${TARGET_NAME}_SOURCE_DIR for overwriting when -DWITH_TP_CACHE=ON")
ENDIF()
# Generate and verify cache dir for third_party source code
SET(cache_third_party_REPOSITORY ${cache_third_party_REPOSITORY} ${cache_third_party_URL})
IF(cache_third_party_REPOSITORY AND cache_third_party_TAG)
STRING(MD5 HASH_REPO ${cache_third_party_REPOSITORY})
STRING(MD5 HASH_GIT ${cache_third_party_TAG})
STRING(SUBSTRING ${HASH_REPO} 0 8 HASH_REPO)
STRING(SUBSTRING ${HASH_GIT} 0 8 HASH_GIT)
STRING(CONCAT HASH ${HASH_REPO} ${HASH_GIT})
# overwrite the original SOURCE_DIR when cache directory
SET(${cache_third_party_DIR} ${THIRD_PARTY_CACHE_PATH}/third_party/${TARGET}_${HASH})
ELSEIF(cache_third_party_REPOSITORY)
STRING(MD5 HASH_REPO ${cache_third_party_REPOSITORY})
STRING(SUBSTRING ${HASH_REPO} 0 16 HASH)
# overwrite the original SOURCE_DIR when cache directory
SET(${cache_third_party_DIR} ${THIRD_PARTY_CACHE_PATH}/third_party/${TARGET}_${HASH})
ENDIF()
IF(EXISTS ${${cache_third_party_DIR}})
# judge whether the cache dir is empty
FILE(GLOB files ${${cache_third_party_DIR}}/*)
LIST(LENGTH files files_len)
IF(files_len GREATER 0)
list(APPEND ${TARGET_NAME}_DOWNLOAD_CMD DOWNLOAD_COMMAND "")
ENDIF()
SET(${cache_third_party_DIR} ${${cache_third_party_DIR}} PARENT_SCOPE)
ENDIF()
ENDIF()
# Pass ${TARGET_NAME}_DOWNLOAD_CMD to parent scope, the double quotation marks can't be removed
SET(${TARGET_NAME}_DOWNLOAD_CMD "${${TARGET_NAME}_DOWNLOAD_CMD}" PARENT_SCOPE)
ENDFUNCTION()
MACRO(UNSET_VAR VAR_NAME)
UNSET(${VAR_NAME} CACHE)
UNSET(${VAR_NAME})
ENDMACRO()
# Correct flags on different platforms (Windows/Mac) and print warning messages
if (APPLE)
if(WITH_MKL)
MESSAGE(WARNING
"Mac is not supported with MKL in Paddle yet. Force WITH_MKL=OFF.")
set(WITH_MKL OFF CACHE STRING "Disable MKL for building on mac" FORCE)
endif()
endif()
if(WIN32 OR APPLE)
MESSAGE(STATUS "Disable XBYAK in Windows and MacOS")
SET(WITH_XBYAK OFF CACHE STRING "Disable XBYAK in Windows and MacOS" FORCE)
if(WITH_LIBXSMM)
MESSAGE(WARNING
"Windows, Mac are not supported with libxsmm in Paddle yet."
"Force WITH_LIBXSMM=OFF")
SET(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM in Windows and MacOS" FORCE)
endif()
if(WITH_NGRAPH)
MESSAGE(WARNING
"Windows or Mac is not supported with nGraph in Paddle yet."
"Force WITH_NGRAPH=OFF")
SET(WITH_NGRAPH OFF CACHE STRING "Disable nGraph in Windows and MacOS" FORCE)
endif()
if(WITH_BOX_PS)
MESSAGE(WARNING
"Windows or Mac is not supported with BOX_PS in Paddle yet."
"Force WITH_BOX_PS=OFF")
SET(WITH_BOX_PS OFF CACHE STRING "Disable BOX_PS package in Windows and MacOS" FORCE)
endif()
if(WITH_PSLIB)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB in Paddle yet."
"Force WITH_PSLIB=OFF")
SET(WITH_PSLIB OFF CACHE STRING "Disable PSLIB package in Windows and MacOS" FORCE)
endif()
if(WITH_LIBMCT)
MESSAGE(WARNING
"Windows or Mac is not supported with LIBMCT in Paddle yet."
"Force WITH_LIBMCT=OFF")
SET(WITH_LIBMCT OFF CACHE STRING "Disable LIBMCT package in Windows and MacOS" FORCE)
endif()
if(WITH_PSLIB_BRPC)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB_BRPC in Paddle yet."
"Force WITH_PSLIB_BRPC=OFF")
SET(WITH_PSLIB_BRPC OFF CACHE STRING "Disable PSLIB_BRPC package in Windows and MacOS" FORCE)
endif()
endif()
set(WITH_MKLML ${WITH_MKL})
if(NOT DEFINED WITH_MKLDNN)
if(WITH_MKL AND AVX2_FOUND)
set(WITH_MKLDNN ON)
else()
message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN")
set(WITH_MKLDNN OFF)
endif()
endif()
if(WIN32 OR APPLE OR NOT WITH_GPU OR ON_INFER)
set(WITH_DGC OFF)
endif()
if(${CMAKE_VERSION} VERSION_GREATER "3.5.2")
set(SHALLOW_CLONE "GIT_SHALLOW TRUE") # adds --depth=1 arg to git clone of External_Projects
endif()
########################### include third_party according to flags ###############################
#include(external/zlib) # download, build, install zlib
#include(external/gflags) # download, build, install gflags
#include(external/boost) # download boost
#include(external/eigen) # download eigen3
#include(external/threadpool)# download threadpool
#include(external/dlpack) # download dlpack
#include(external/xxhash) # download, build, install xxhash
#include(external/warpctc) # download, build, install warpctc
set(third_party_deps)
#list(APPEND third_party_deps extern_eigen3 extern_gflags extern_boost)
#list(APPEND third_party_deps extern_zlib extern_dlpack extern_warpctc extern_threadpool)
# if(WITH_AMD_GPU)
# include(external/rocprim) # download, build, install rocprim
# list(APPEND third_party_deps extern_rocprim)
# endif()
#include(cblas) # find first, then download, build, install openblas
#if(${CBLAS_PROVIDER} STREQUAL MKLML)
# list(APPEND third_party_deps extern_mklml)
#elseif(${CBLAS_PROVIDER} STREQUAL EXTERN_OPENBLAS)
# list(APPEND third_party_deps extern_openblas)
#endif()
if(NOT WIN32 AND NOT APPLE)
include(external/gloo)
list(APPEND third_party_deps extern_gloo)
endif()
if(NOT WIN32 AND NOT APPLE)
include(external/hiredis)
list(APPEND third_party_deps extern_hiredis)
endif()
# if(WITH_MKLDNN)
# include(external/mkldnn) # download, build, install mkldnn
# list(APPEND third_party_deps extern_mkldnn)
# endif()
#include(external/protobuf) # find first, then download, build, install protobuf
# if(NOT PROTOBUF_FOUND OR WIN32)
# list(APPEND third_party_deps extern_protobuf)
# endif()
if(NOT WIN32 AND NOT APPLE)
include(external/pybind11)
list(APPEND third_party_deps extern_pybind)
endif()
# if(WITH_PYTHON)
# include(external/python) # find python and python_module
# include(external/pybind11) # download pybind11
# list(APPEND third_party_deps extern_pybind)
# endif()
IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
include(external/gtest) # download, build, install gtest
list(APPEND third_party_deps extern_gtest)
ENDIF()
IF(WITH_PSI)
include(external/openssl) # download, build, install openssl
list(APPEND third_party_deps extern_openssl)
ENDIF()
# if(WITH_GPU)
# include(external/cub) # download cub
# list(APPEND third_party_deps extern_cub)
# endif(WITH_GPU)
# if(WITH_PSLIB)
# include(external/pslib) # download, build, install pslib
# list(APPEND third_party_deps extern_pslib)
# if(WITH_LIBMCT)
# include(external/libmct) # download, build, install libmct
# list(APPEND third_party_deps extern_libxsmm)
# endif()
# if(WITH_PSLIB_BRPC)
# include(external/pslib_brpc) # download, build, install pslib_brpc
# list(APPEND third_party_deps extern_pslib_brpc)
# endif()
# endif(WITH_PSLIB)
# if(WITH_BOX_PS)
# include(external/box_ps)
# list(APPEND third_party_deps extern_box_ps)
# endif(WITH_BOX_PS)
# if(WITH_DISTRIBUTE)
# if(WITH_GRPC)
# list(APPEND third_party_deps extern_grpc)
# else()
# list(APPEND third_party_deps extern_leveldb)
# list(APPEND third_party_deps extern_brpc)
# endif()
# endif()
# if(WITH_NGRAPH)
# if(WITH_MKLDNN)
# include(external/ngraph) # download, build, install nGraph
# list(APPEND third_party_deps extern_ngraph)
# else()
# MESSAGE(WARNING
# "nGraph needs mkl-dnn to be enabled."
# "Force WITH_NGRAPH=OFF")
# SET(WITH_NGRAPH OFF CACHE STRING "Disable nGraph if mkl-dnn is disabled" FORCE)
# endif()
# endif()
# if(WITH_XBYAK)
# include(external/xbyak) # download, build, install xbyak
# list(APPEND third_party_deps extern_xbyak)
# endif()
# if(WITH_LIBXSMM)
# include(external/libxsmm) # download, build, install libxsmm
# list(APPEND third_party_deps extern_libxsmm)
# endif()
# if(WITH_DGC)
# message(STATUS "add dgc lib.")
# include(external/dgc) # download, build, install dgc
# add_definitions(-DPADDLE_WITH_DGC)
# list(APPEND third_party_deps extern_dgc)
# endif()
# if (WITH_LITE)
# include(external/lite)
# endif (WITH_LITE)
add_custom_target(third_party DEPENDS ${third_party_deps})
## PE - Paddle Encrypted
Paddle Encrypted is a framework for privacy-preserving deep learning based on PaddlePaddle. It follows the same running mechanism and programming paradigm as PaddlePaddle, while using secure multi-party computation (MPC) to enable secure training and prediction.
With Paddle Encrypted, it is easy to train models or conduct prediction over encrypted data, just as on PaddlePaddle, without the need for cryptography expertise. Furthermore, the rich industry-oriented models and algorithms built on PaddlePaddle can be smoothly migrated to secure versions on Paddle Encrypted with little effort.
As a key product of PaddleFL, Paddle Encrypted intrinsically supports federated learning well, including horizontal, vertical and transfer learning scenarios. It provides both provable security (semantic security) and competitive performance.
Please see the installation and examples below, or visit the documentation to learn more about the technical details.
## Design Overview
![img](http://icode.baidu.com/path/to/iamge)
Paddle Encrypted implements secure training and inference tasks based on the underlying MPC protocol ABY3[], in which participants can be classified into the roles of Input Party (IP), Computing Party (CP) and Result Party (RP).
Input Parties (e.g., the training data/model owners) encrypt and distribute data or models to Computing Parties. Computing Parties (e.g., VMs on the cloud) conduct training or inference tasks based on specific MPC protocols, being restricted to see only the encrypted data or models, which guarantees data privacy. When the computation is completed, one or more Result Parties (e.g., data owners or a specified third party) receive the encrypted results from Computing Parties and reconstruct the plaintext results. Roles can overlap, e.g., a data owner can also act as a computing party.
A full training or inference process in Paddle Encrypted consists mainly of three phases: data preparation, training/inference, and result reconstruction.
#### Data preparation
##### Private data alignment
Paddle Encrypted enables data owners (IPs) to find records with identical keys (like UUID) without revealing private data to each other. This is especially useful in vertical learning cases, where segmented features with the same keys need to be identified and aligned across all owners in a private manner before training. Using the OT-based PSI (Private Set Intersection) algorithm[], PE can perform private alignment at a speed of up to 60k records per second.
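As a rough illustration, the two-party PSI interface exported by the `mpc_data_utils` extension (see the pybind module in `core/paddlefl_mpc/data_utils` later in this repository) could be driven from Python as in the following sketch; the host, port and key sets are placeholders.
```python
# A minimal sketch of two-party PSI, assuming the mpc_data_utils extension
# built by this repository is importable; host/port/keys are placeholders.
import mpc_data_utils as mdu

# On party A (the sender): listen on a port and feed in the local key set.
# send_psi returns 0 on success.
alice_keys = {"uuid-0001", "uuid-0002", "uuid-0003"}
ret = mdu.send_psi(34567, alice_keys)
assert ret == 0

# On party B (the receiver): connect to party A and obtain the intersection
# of the two key sets; records outside the intersection are not revealed.
bob_keys = {"uuid-0002", "uuid-0003", "uuid-0004"}
common_keys = mdu.recv_psi("192.168.1.10", 34567, bob_keys)
```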
##### Encryption and distribution
In Paddle Encrypted, data and models from IPs are encrypted using Secret-Sharing[] and then sent to CPs, either via direct transmission or through distributed storage like HDFS. Each CP can only obtain one share of each piece of data, and thus is unable to recover the original value under the semi-honest model[].
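For intuition, the `share` helper in the `mpc_data_utils` bindings (defined later in this repository) splits a single plaintext value into three fixed-point shares, one per CP; a minimal sketch with an illustrative value:
```python
# A minimal sketch of secret-sharing one plaintext value into three shares
# with the mpc_data_utils bindings; the value is illustrative.
import mpc_data_utils as mdu

plaintext = 3.14
shares = mdu.share(plaintext)   # numpy array holding one share per Computing Party
assert len(shares) == 3
# Each CP receives exactly one share (e.g., via files or HDFS), so no single
# party can recover the plaintext on its own.
```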
#### Training/inference
![img](http://icode.baidu.com/path/to/iamge)
As in PaddlePaddle, a training or inference job can be separated into the compile-time phase and the run-time phase:
##### Compile time
* **MPC environment specification**: a user needs to choose an MPC protocol and configure the network settings. In the current version, PE provides only the "ABY3" protocol. More protocol implementations will be provided in the future.
* **User-defined job program**: a user can define the machine learning model structure and the training strategies (or inference task) in a PE program, using the secure operators.
##### Run time
A PE program is exactly a PaddlePaddle program, and will be executed as a normal PaddlePaddle program. For example, at run time a PE program is transpiled into a ProgramDesc, and then passed to and run by the Executor. The main concepts in the run-time phase are as follows:
* **Computing nodes**: a computing node is an entity corresponding to a Computing Party. In real deployment, it can be a bare-metal machine, a cloud VM, a Docker container or even a process. PE requires exactly three computing nodes in each run, which is determined by the underlying ABY3 protocol. A PE program will be deployed and run in parallel on all three computing nodes.
* **Operators using MPC**: PE provides typical machine learning operators in `paddle.fluid_encrypted` over encrypted data. Such operators are implemented on top of the PaddlePaddle framework, based on MPC protocols like ABY3. Like other PaddlePaddle operators, at run time, instances of PE operators are created and run in order by the Executor (see [] for details).
#### Result reconstruction
Upon completion of the secure training (or inference) job, the models (or prediction results) will be output by CPs in encrypted form. Result Parties can collect the encrypted results, decrypt them using the tools in PE, and deliver the plaintext results to users.
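Correspondingly, a Result Party can recombine the three shares it collects with the `reveal` helper from the `mpc_data_utils` bindings; a minimal sketch:
```python
# A minimal sketch of result reconstruction: combining the three shares
# collected from the CPs back into a plaintext value.
import mpc_data_utils as mdu

shares = mdu.share(3.14)        # stands in for the shares gathered from the CPs
plaintext = mdu.reveal(shares)  # reconstruct the plaintext result
print(plaintext)                # approximately 3.14 (fixed-point encoding)
```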
## Compilation and Installation
#### Environment preparation
* CentOS 6 or CentOS 7 (64 bit)
* Python 2.7.15+ / 3.5.1+ / 3.6 / 3.7 (64 bit)
* pip or pip3 9.0.1+ (64 bit)
* PaddlePaddle release 1.6.3
* Redis 5.0.8 (64 bit)
#### Clone the source code, compile and install
Fetch the source code and check out the stable release:
```sh
git clone https://repo/site
cd /path/to/paddle_mpc
# Checkout stable release
git checkout [stable-release]
mkdir build && cd build
```
Execute the compile commands, where `PYTHON_EXECUTABLE` is the path to the Python binary where PaddlePaddle is installed, and `PYTHON_INCLUDE_DIRS` is the corresponding Python include directory. You can get `PYTHON_INCLUDE_DIRS` via the following command:
```sh
${PYTHON_EXECUTABLE} -c "from distutils.sysconfig import get_python_inc;print(get_python_inc())"
```
Then pass the include directory to the following command and build:
```sh
cmake ../ -DPYTHON_EXECUTABLE=${python} -DPYTHON_INCLUDE_DIRS=${python_include_dir}
make -j$(nproc)
```
Install the package:
```sh
make install
cd /path/to/paddle_mpc/python && ${PYTHON_EXECUTABLE} setup.py sdist bdist_wheel && pip or pip3 install dist/***.whl -U
```
Validate the installation by running `python` or `python3`, then run `import paddle_encrypted as pe` and `pe.version()`. The installation succeeds if you see `Paddle Encrypted Version: 1.0.0`.
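A minimal check, using the module and version string quoted above:
```python
# Run in an interactive python/python3 session after installation.
import paddle_encrypted as pe
pe.version()   # expected: Paddle Encrypted Version: 1.0.0
```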
## Example
#### Build your model
In Paddle Encrypted, you can build models as you do in PaddlePaddle, but using variables and operators over encrypted data. First, prepare a training script like the example below. It is worth noting that the operators and variables are created using the `paddle.fluid_encrypted` package.
```python
# An example to build an LR model, named train.py (USE THE HOUSE PRICE CASE)
import sys
import paddle.paddle_encrypted as paddle_enc
import paddle.fluid as fluid
import numpy
# read role from command line
role, addr, port = sys.argv[1], sys.argv[2], sys.argv[3]
# init the MPC environment
paddle_enc.init("aby3", (int)role, net_server_addr=addr, net_server_port=(int)port)
# define encrypted variables
image = paddle_enc.data(name='image', shape=[None, 784], dtype='int64')
label = paddle_enc.data(name='label', shape=[None, 1], dtype='int64')
# define a secure training network
hidden = paddle_enc.layers.fc(input=image, size=100, act='relu')
prediction = paddle_enc.layers.fc(input=hidden, size=10, act='softmax')
cost = paddle_enc.layers.square_error_cost(input=prediction, label=label)
loss = paddle_enc.layers.mean(cost)
sgd = paddle_enc.optimizer.SGD(learning_rate=0.001)
sgd.minimize(loss)
# Place the training on CPU
exe = fluid.Executor(place=fluid.CPUPlace())
# use random numbers to simulate encrypted data, and start training
x = numpy.random.random(size=(128, 2, 784)).astype('int64')
y = numpy.random.random(size=(128, 2, 1)).astype('int64')
loss_data, = exe.run(feed={'image': x, 'label': y},
fetch_list=[loss.name])
```
#### Execution and results
To run the MPC training, we need to deploy the training processes on multiple machines (i.e., three machines in the current version), and use a discovery service to let them find each other. We use Redis as the discovery service here.
1. Start a Redis service, and keep the service address:
```sh
redis-server --port ${port}
```
2. Deploy the above `train.py` on three machines, and run with different role settings (from 0 to 2):
```sh
# run python code
# on machine1:
python train.py 0 ${redis_addr} ${port}
# on machine2:
python train.py 1 ${redis_addr} ${port}
# on machine3
python train.py 2 ${redis_addr} ${port}
```
Then the training process will start and the underlying MPC-based operators will be executed to complete the secure training.
## Benchmark Task
Results to be added:

| DataSet/Task | Training Methods | Result |
| --- | --- | --- |
## Ongoing and Future Work
- more features
## Reference
[1].
add_compile_options(-msse4.2 -maes)
set(PYBIND_SRCS
"./data_utils.cc"
)
if (NOT PYTHON_INCLUDE_DIRS)
find_package(PythonLibs REQUIRED)
endif()
include_directories(${PYTHON_INCLUDE_DIRS})
add_library(mpc_data_utils MODULE ${PYBIND_SRCS})
target_link_libraries(mpc_data_utils PRIVATE pybind)
target_link_libraries(mpc_data_utils PRIVATE privc3)
target_link_libraries(mpc_data_utils PRIVATE psi)
set_target_properties(mpc_data_utils PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}")
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <atomic>
#include <set>
#include <string>
#include <vector>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "core/paddlefl_mpc/mpc_protocol/aby3_operators.h"
#include "core/privc3/fixedpoint_util.h"
#include "core/psi/psi_api.h"
namespace py = pybind11;
namespace aby3 {
// split plaintext into three shares.
template <typename T, size_t N> py::array_t<T> share(double input) {
const size_t share_num = 3;
auto shares = py::array_t<T>(share_num);
py::buffer_info shares_buf = shares.request();
T *shares_buf_ptr = (T *)shares_buf.ptr;
T *ret_ptr[share_num];
for (size_t i = 0; i < share_num; ++i) {
ret_ptr[i] = &shares_buf_ptr[i];
}
FixedPointUtil<T, N>::share(input, ret_ptr);
return shares;
}
// combine three shares to reveal plaintext.
template <typename T, size_t N> double reveal(py::array_t<T> shares) {
const size_t share_num = 3;
py::buffer_info shares_buf = shares.request();
T *shares_buf_ptr = (T *)shares_buf.ptr;
T *ret[share_num];
for (size_t idx = 0; idx < share_num; ++idx) {
ret[idx] = &shares_buf_ptr[idx];
}
double result = FixedPointUtil<T, N>::reveal(ret);
return result;
}
// call psi_send
int send_psi(int port, const std::set<std::string> &input) {
std::atomic<int> prog(0);
return psi::psi_send(port, input, &prog);
}
// call psi_recv
std::vector<std::string> recv_psi(const std::string &remote_ip, int port,
const std::set<std::string> &input) {
std::vector<std::string> output;
std::atomic<int> prog(0);
int ret = psi::psi_recv(remote_ip, port, input, &output, &prog);
if (ret != 0) {
output.clear();
return output;
}
return output;
}
PYBIND11_MODULE(mpc_data_utils, m) {
// optional module docstring
m.doc() = "pybind11 paddle-mpc plugin: data_utils (share, reveal, psi)";
m.def("share", &share<long long, paddle::mpc::ABY3_SCALING_FACTOR>,
"split plaintext into three shares.");
m.def("reveal", &reveal<long long, paddle::mpc::ABY3_SCALING_FACTOR>,
"combine three shares to reveal plaintext.");
m.def("send_psi", &send_psi, "Send input in two party PSI.");
m.def("recv_psi", &recv_psi,
"Send input and return PSI result as output in two party PSI.");
}
} // namespace aby3
add_compile_options(-msse4.2 -maes)
set(PROTO_SRCS
"./aby3_protocol.cc"
"./mesh_network.cc"
"./mpc_config_parameters.cc"
"./context_holder.cc"
"./mpc_instance.cc"
"./mpc_protocol_factory.cc"
)
add_library(mpc_protocol_o OBJECT ${PROTO_SRCS})
add_dependencies(mpc_protocol_o fluid_framework gloo hiredis)
add_library(mpc_protocol STATIC $<TARGET_OBJECTS:mpc_protocol_o>)
target_link_libraries(mpc_protocol fluid_framework gloo hiredis privc3)
cc_test(mesh_network_test SRCS mesh_network_test.cc DEPS mpc_protocol)
cc_test(mpc_protocol_test SRCS mpc_protocol_test.cc DEPS mpc_protocol)
cc_test(mpc_instance_test SRCS mpc_instance_test.cc DEPS mpc_protocol)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <vector>
namespace paddle {
namespace mpc {
class AbstractNetwork {
public:
AbstractNetwork() = default;
virtual ~AbstractNetwork() = default;
virtual void send(size_t party, const void *data, size_t size) = 0;
virtual void recv(size_t party, void *data, size_t size) = 0;
virtual void broadcast(const void *data, size_t size) {
for (size_t i = 0; i < party_num(); ++i) {
if (i == party_id()) {
continue;
}
send(i, data, size);
}
}
virtual void gather(void *data[], size_t size) {
for (size_t i = 0; i < party_num(); ++i) {
if (i == party_id()) {
continue;
}
recv(i, data[i], size);
}
}
template <typename T> void send(size_t party, const T &data) {
send(party, &data, sizeof(T));
}
template <typename T, template <typename> class Tensor>
void send(size_t party, const Tensor<T> &tensor) {
send(party, tensor.data(), sizeof(T) * tensor.numel());
}
template <typename T> T recv(size_t party) {
T ret;
recv(party, &ret, sizeof(T));
return ret;
}
template <typename T, template <typename> class Tensor>
Tensor<T> &recv(size_t party, Tensor<T> &tensor) {
recv(party, tensor.data(), sizeof(T) * tensor.numel());
return tensor;
}
template <typename T> void broadcast(const T &data) {
broadcast(&data, sizeof(T));
}
template <typename T> std::vector<T> gather() {
std::vector<T> ret(party_num());
for (size_t i = 0; i < party_num(); ++i) {
if (i == party_id()) {
continue;
}
recv(i, &ret[i], sizeof(T));
}
return ret;
}
template <typename T> void send(size_t party, const T *begin, const T *end) {
send(party, begin, (end - begin) * sizeof(T));
}
template <typename T> T *recv(size_t party, T *begin, T *end) {
recv(party, begin, (end - begin) * sizeof(T));
return begin;
}
template <typename T>
void broadcast(size_t party, const T *begin, const T *end) {
broadcast(begin, (end - begin) * sizeof(T));
}
template <typename T> void gather(T *begin[], T *end[]) {
gather(begin, sizeof(T) * (end[0] - begin[0]));
}
virtual size_t party_id() const = 0;
virtual size_t party_num() const = 0;
};
} // namespace mpc
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description: implementations of each virtual op according to ABY3 protocol
#pragma once
#include <utility>
#include "context_holder.h"
#include "mpc_operators.h"
#include "paddle/fluid/framework/tensor.h"
#include "core/privc3/boolean_tensor.h"
#include "core/privc3/circuit_context.h"
#include "core/privc3/fixedpoint_tensor.h"
#include "core/privc3/paddle_tensor.h"
namespace paddle {
namespace mpc {
using paddle::framework::Tensor;
using aby3::CircuitContext;
// TODO: decide scaling factor
const size_t ABY3_SCALING_FACTOR = 16;
using FixedTensor = aby3::FixedPointTensor<int64_t, ABY3_SCALING_FACTOR>;
using BoolTensor = aby3::BooleanTensor<int64_t>;
using PaddleTensor = aby3::PaddleTensor<int64_t>;
class Aby3OperatorsImpl : public MpcOperators {
public:
void add(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto rhs_tuple = from_tensor(rhs);
auto out_tuple = from_tensor(out);
auto lhs_ = std::get<0>(lhs_tuple).get();
auto rhs_ = std::get<0>(rhs_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
lhs_->add(rhs_, out_);
}
// TODO: override
void sub(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto rhs_tuple = from_tensor(rhs);
auto out_tuple = from_tensor(out);
auto lhs_ = std::get<0>(lhs_tuple).get();
auto rhs_ = std::get<0>(rhs_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
lhs_->sub(rhs_, out_);
}
void neg(const Tensor *op, Tensor *out) override {
auto op_tuple = from_tensor(op);
auto out_tuple = from_tensor(out);
auto op_ = std::get<0>(op_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
op_->negative(out_);
}
void sum(const Tensor *op, Tensor *out) override {
auto op_tuple = from_tensor(op);
auto out_tuple = from_tensor(out);
auto op_ = std::get<0>(op_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
op_->sum(out_);
}
void mul(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto rhs_tuple = from_tensor(rhs);
auto out_tuple = from_tensor(out);
auto lhs_ = std::get<0>(lhs_tuple).get();
auto rhs_ = std::get<0>(rhs_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
lhs_->mul(rhs_, out_);
}
void matmul(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto rhs_tuple = from_tensor(rhs);
auto out_tuple = from_tensor(out);
auto lhs_ = std::get<0>(lhs_tuple).get();
auto rhs_ = std::get<0>(rhs_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
lhs_->mat_mul(rhs_, out_);
}
void scale(const Tensor *lhs, const double factor, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto out_tuple = from_tensor(out);
auto lhs_ = std::get<0>(lhs_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
PaddleTensor scale_tensor(ContextHolder::device_ctx());
scale_tensor.from_float_point_scalar(factor, lhs_->shape(),
ABY3_SCALING_FACTOR);
lhs_->mul(&scale_tensor, out_);
}
void relu(const Tensor *op, Tensor *out) override {
auto op_tuple = from_tensor(op);
auto out_tuple = from_tensor(out);
auto op_ = std::get<0>(op_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
op_->relu(out_);
}
void softmax(const Tensor *op, Tensor *out) override {
auto op_tuple = from_tensor(op);
auto out_tuple = from_tensor(out);
auto op_ = std::get<0>(op_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
op_->softmax(out_);
}
void gt(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto lhs_ = std::get<0>(lhs_tuple).get();
PaddleTensor rhs_(ContextHolder::device_ctx());
rhs_.from_float_point_type<float>(*rhs, ABY3_SCALING_FACTOR);
PaddleTensor out_(ContextHolder::device_ctx(), *out);
auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape());
auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape());
BoolTensor bool_out(tmp0.get(), tmp1.get());
lhs_->gt(&rhs_, &bool_out);
bool_out.reveal(&out_);
}
void geq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
lt(lhs, rhs, out);
std::transform(out->data<int64_t>(), out->data<int64_t>() + out->numel(),
out->data<int64_t>(), [](int64_t b) { return 1 - b; });
}
void lt(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto lhs_ = std::get<0>(lhs_tuple).get();
PaddleTensor rhs_(ContextHolder::device_ctx(), *rhs);
rhs_.from_float_point_type<float>(*rhs, ABY3_SCALING_FACTOR);
PaddleTensor out_(ContextHolder::device_ctx(), *out);
auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape());
auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape());
BoolTensor bool_out(tmp0.get(), tmp1.get());
lhs_->lt(&rhs_, &bool_out);
bool_out.reveal(&out_);
}
void leq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
gt(lhs, rhs, out);
std::transform(out->data<int64_t>(), out->data<int64_t>() + out->numel(),
out->data<int64_t>(), [](int64_t b) { return 1 - b; });
}
void eq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto lhs_ = std::get<0>(lhs_tuple).get();
PaddleTensor rhs_(ContextHolder::device_ctx(), *rhs);
rhs_.from_float_point_type<float>(*rhs, ABY3_SCALING_FACTOR);
PaddleTensor out_(ContextHolder::device_ctx(), *out);
auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape());
auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(rhs_.shape());
BoolTensor bool_out(tmp0.get(), tmp1.get());
lhs_->eq(&rhs_, &bool_out);
bool_out.reveal(&out_);
}
void neq(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
eq(lhs, rhs, out);
std::transform(out->data<int64_t>(), out->data<int64_t>() + out->numel(),
out->data<int64_t>(), [](int64_t b) { return 1 - b; });
}
void relu_grad(const Tensor *y, const Tensor *dy, Tensor *dx,
float point = 0.0f) override {
auto y_tuple = from_tensor(y);
auto y_ = std::get<0>(y_tuple).get();
PaddleTensor point_(ContextHolder::device_ctx());
point_.from_float_point_scalar<float>(point, y_->shape(),
ABY3_SCALING_FACTOR);
auto tmp0 = ContextHolder::tensor_factory()->create_int64_t(y_->shape());
auto tmp1 = ContextHolder::tensor_factory()->create_int64_t(y_->shape());
BoolTensor bool_out(tmp0.get(), tmp1.get());
y_->gt(&point_, &bool_out);
auto out_tuple = from_tensor(dx);
auto out_ = std::get<0>(out_tuple).get();
auto dy_tuple = from_tensor(dy);
auto dy_ = std::get<0>(dy_tuple).get();
bool_out.mul(dy_, out_);
}
private:
std::tuple<std::shared_ptr<FixedTensor>, std::shared_ptr<PaddleTensor>,
std::shared_ptr<PaddleTensor>>
from_tensor(const Tensor *t) {
PADDLE_ENFORCE_EQ(t->dims()[0], 2);
auto pt0 = std::make_shared<PaddleTensor>(ContextHolder::device_ctx(),
t->Slice(0, 1));
auto pt1 = std::make_shared<PaddleTensor>(ContextHolder::device_ctx(),
t->Slice(1, 2));
aby3::TensorAdapter<int64_t> *pt_array[2] = {pt0.get(), pt1.get()};
auto ft = std::make_shared<FixedTensor>(pt_array);
return std::make_tuple(ft, pt0, pt1);
}
};
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// an ABY3 protocol impl, including combination of operator, network and circuit
// context
#include "aby3_protocol.h"
#include "gloo/rendezvous/redis_store.h"
#include "mpc_protocol_factory.h"
namespace paddle {
namespace mpc {
void Aby3Protocol::init_with_store(
const MpcConfig &config, std::shared_ptr<gloo::rendezvous::Store> store) {
if (_is_initialized) {
return;
}
PADDLE_ENFORCE_NOT_NULL(store);
// read role, address and other info
auto role = config.get_int(Aby3Config::ROLE);
PADDLE_ENFORCE_LT(role, 3, "Input role should be less than party_size(3).");
auto local_addr =
config.get(Aby3Config::LOCAL_ADDR, Aby3Config::LOCAL_ADDR_DEFAULT);
auto net_server_addr = config.get(Aby3Config::NET_SERVER_ADDR,
Aby3Config::NET_SERVER_ADDR_DEFAULT);
auto net_server_port = config.get_int(Aby3Config::NET_SERVER_PORT,
Aby3Config::NET_SERVER_PORT_DEFAULT);
auto mesh_net = std::make_shared<MeshNetwork>(
role, local_addr, 3 /* netsize */, "Paddle-mpc" /* key-prefix in store*/,
store);
mesh_net->init();
_network = std::move(mesh_net);
_circuit_ctx = std::make_shared<CircuitContext>(role, _network);
_operators = std::make_shared<Aby3OperatorsImpl>();
_is_initialized = true;
}
std::shared_ptr<MpcOperators> Aby3Protocol::mpc_operators() {
PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR);
return _operators;
}
std::shared_ptr<AbstractNetwork> Aby3Protocol::network() {
PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR);
return _network;
}
std::shared_ptr<CircuitContext> Aby3Protocol::mpc_context() {
PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR);
return _circuit_ctx;
}
void Aby3Protocol::init(const MpcConfig &config) {
if (_is_initialized) {
return;
}
auto server_addr = config.get(Aby3Config::NET_SERVER_ADDR,
Aby3Config::NET_SERVER_ADDR_DEFAULT);
auto server_port = config.get_int(Aby3Config::NET_SERVER_PORT,
Aby3Config::NET_SERVER_PORT_DEFAULT);
auto gloo_store =
std::make_shared<gloo::rendezvous::RedisStore>(server_addr, server_port);
init_with_store(config, gloo_store);
}
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// an ABY3 protocol impl, including combination of operator, network and circuit
// context
#pragma once
#include "abstract_network.h"
#include "aby3_operators.h"
#include "gloo/rendezvous/hash_store.h"
#include "mesh_network.h"
#include "mpc_operators.h"
#include "mpc_protocol.h"
#include "core/privc3/circuit_context.h"
namespace paddle {
namespace mpc {
using CircuitContext = aby3::CircuitContext;
class Aby3Protocol : public MpcProtocol {
public:
Aby3Protocol() : MpcProtocol("aby3") {}
// virtual ~Aby3Protocol() = default;
void init(const MpcConfig &config) override;
// for test purpose
void init_with_store(const MpcConfig &config,
std::shared_ptr<gloo::rendezvous::Store> store) override;
std::shared_ptr<MpcOperators> mpc_operators() override;
std::shared_ptr<AbstractNetwork> network() override;
std::shared_ptr<CircuitContext> mpc_context() override;
private:
bool _is_initialized = false;
const std::string PROT_INIT_ERR = "The protocol is not yet initialized.";
std::shared_ptr<MpcOperators> _operators;
std::shared_ptr<AbstractNetwork> _network;
std::shared_ptr<CircuitContext> _circuit_ctx;
};
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// a public access context holder, mostly used in the MPC computation where the
// paddle execution and mpc circuit contexts are needed. The corresponding
// contexts in the environment where the operator is executed will be stored
// and accessed here, and they are thread local.
#include "context_holder.h"
namespace paddle {
namespace mpc {
thread_local std::shared_ptr<CircuitContext> ContextHolder::current_mpc_ctx;
thread_local const ExecutionContext *ContextHolder::current_exec_ctx;
thread_local std::shared_ptr<aby3::TensorAdapterFactory>
ContextHolder::_s_current_tensor_factory;
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// a public access context holder, mostly used in the MPC computation where the
// paddle execution and mpc circuit contexts are needed. The corresponding
// contexts in the environment where the operator is executed will be stored
// and accessed here, and they are thread local.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "core/privc3/circuit_context.h"
#include "core/privc3/paddle_tensor.h"
namespace paddle {
namespace mpc {
using CircuitContext = aby3::CircuitContext;
using ExecutionContext = paddle::framework::ExecutionContext;
class ContextHolder {
public:
template <typename Operation>
static void run_with_context(const ExecutionContext *exec_ctx,
std::shared_ptr<CircuitContext> mpc_ctx,
Operation op) {
// set new ctxs
auto old_mpc_ctx = current_mpc_ctx;
current_mpc_ctx = mpc_ctx;
auto old_exec_ctx = current_exec_ctx;
current_exec_ctx = exec_ctx;
auto old_factory = _s_current_tensor_factory;
_s_current_tensor_factory = nullptr;
tensor_factory();
// run the op
op();
// restore ctxs
current_mpc_ctx = old_mpc_ctx;
current_exec_ctx = old_exec_ctx;
_s_current_tensor_factory = old_factory;
}
static std::shared_ptr<CircuitContext> mpc_ctx() { return current_mpc_ctx; }
static const ExecutionContext *exec_ctx() { return current_exec_ctx; }
static const paddle::platform::DeviceContext *device_ctx() {
return &current_exec_ctx->device_context();
}
static std::shared_ptr<aby3::TensorAdapterFactory> tensor_factory() {
if (!_s_current_tensor_factory) {
_s_current_tensor_factory =
std::make_shared<aby3::PaddleTensorFactory>(device_ctx());
}
return _s_current_tensor_factory;
}
private:
thread_local static std::shared_ptr<CircuitContext> current_mpc_ctx;
thread_local static const ExecutionContext *current_exec_ctx;
thread_local static std::shared_ptr<aby3::TensorAdapterFactory>
_s_current_tensor_factory;
};
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mesh_network.h"
#include "gloo/common/string.h"
#include "gloo/rendezvous/prefix_store.h"
#include "gloo/transport/device.h"
#include "gloo/transport/tcp/device.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace mpc {
void MeshNetwork::init() {
if (_is_initialized) {
return;
}
auto context =
std::make_shared<gloo::rendezvous::Context>(_party_id, _net_size);
auto dev = gloo::transport::tcp::CreateDevice(_local_addr.c_str());
auto prefix_store = gloo::rendezvous::PrefixStore(_store_prefix, *_store);
context->connectFullMesh(prefix_store, dev);
_rendezvous_ctx = std::move(context);
_is_initialized = true;
}
void MeshNetwork::send(size_t party, const void *data, size_t size) {
PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE(_is_initialized);
auto unbounded_buf =
_rendezvous_ctx->createUnboundBuffer(const_cast<void *>(data), size);
unbounded_buf->send(party, 0UL /*slot*/);
unbounded_buf->waitSend();
}
void MeshNetwork::recv(size_t party, void *data, size_t size) {
PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE(_is_initialized);
auto unbounded_buf = _rendezvous_ctx->createUnboundBuffer(data, size);
unbounded_buf->recv(party, 0UL /*slot*/);
unbounded_buf->waitRecv();
}
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "gloo/rendezvous/context.h"
#include "gloo/rendezvous/hash_store.h"
#include "abstract_network.h"
namespace paddle {
namespace mpc {
// A fully-connected network based on the underlying GLOO toolkit, with a
// network size of 3
class MeshNetwork : public paddle::mpc::AbstractNetwork {
public:
// a ctor taking the explicit network size and store as parameters
// prefix: the prefix of keys in the store to differentiate different runs
// example:
// auto store = std::make_shared<gloo::rendezvous::HashStore>();
// (in each thread:)
// paddle::mpc::MeshNetwork net(0, "127.0.0.1", 3, "test_prefix", store);
// net.init();
// net.send(1, data, sizeof(data))
//
MeshNetwork(const size_t party_id, const std::string &local_addr,
const size_t net_size, const std::string &prefix,
std::shared_ptr<gloo::rendezvous::Store> store)
: _party_id(party_id), _local_addr(local_addr), _net_size(net_size),
_store_prefix(prefix), _store(std::move(store)),
_is_initialized(false) {}
virtual ~MeshNetwork() = default;
void send(size_t party, const void *data, size_t size) override;
void recv(size_t party, void *data, size_t size) override;
size_t party_id() const override { return _party_id; };
size_t party_num() const override { return _net_size; };
// must be called before use
void init();
private:
const size_t _party_id;
const size_t _net_size;
const std::string _local_addr;
const std::string _store_prefix;
std::shared_ptr<gloo::rendezvous::Store> _store;
std::shared_ptr<gloo::rendezvous::Context> _rendezvous_ctx;
bool _is_initialized;
};
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
#include <thread>
#include "gtest/gtest.h"
namespace paddle {
namespace mpc {
class NetworkTest : public ::testing::Test {
public:
std::string _addr;
std::string _prefix;
std::shared_ptr<gloo::rendezvous::HashStore> _store;
MeshNetwork _n0;
MeshNetwork _n1;
AbstractNetwork *_p0;
AbstractNetwork *_p1;
NetworkTest()
: _addr("127.0.0.1"), _prefix("test_prefix"),
_store(std::make_shared<gloo::rendezvous::HashStore>()),
_n0(0, _addr, 2, _prefix, _store), _n1(1, _addr, 2, _prefix, _store),
_p0(&_n0), _p1(&_n1) {}
void SetUp() {
std::thread t0([this]() { _n0.init(); });
std::thread t1([this]() { _n1.init(); });
t0.join();
t1.join();
}
};
TEST_F(NetworkTest, basic_test) {
int buf[2] = {0, 1};
std::thread t0([this, &buf]() {
_p0->template send(1, buf[0]);
buf[0] = _p0->template recv<int>(1);
});
std::thread t1([this, &buf]() {
int to_send = buf[1];
buf[1] = _p1->template recv<int>(0);
_p1->template send(0, to_send);
});
t0.join();
t1.join();
EXPECT_EQ(1, buf[0]);
EXPECT_EQ(0, buf[1]);
}
} // namespace mpc
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// mpc protocol configuration
#pragma once
#include <string>
#include <unordered_map>
namespace paddle {
namespace mpc {
class MpcConfig {
public:
MpcConfig() {}
MpcConfig(const MpcConfig &config) = default;
int get_int(const std::string &key, int default_val = 0) const {
auto got = _prop_map.find(key);
if (got != _prop_map.end()) {
auto ret = got->second;
return std::stoi(ret);
}
return default_val;
}
// get the value for the specified key; default_val (an empty string by
// default) is returned if the key is absent
std::string get(const std::string &key,
const std::string &default_val = std::string()) const {
auto got = _prop_map.find(key);
if (got != _prop_map.end()) {
return got->second;
}
return default_val;
}
// set the config item. If an item with the same key exists, it will be
// overwritten.
MpcConfig &set(const std::string &key, const std::string &value) {
_prop_map[key] = value;
return *this;
}
MpcConfig &set_int(const std::string &key, const int value) {
return set(key, std::to_string(value));
}
private:
std::unordered_map<std::string, std::string> _prop_map;
};
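// A minimal usage sketch of MpcConfig (illustrative; the keys are arbitrary):
//
//   MpcConfig config;
//   config.set("local.address", "127.0.0.1").set_int("role", 1);
//   config.get("local.address");      // returns "127.0.0.1"
//   config.get_int("role");           // returns 1
//   config.get_int("port", 6379);     // key absent, returns the default 6379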
class Aby3Config : public MpcConfig {
public:
// predefined parameters for aby3 protocol configuration
static const std::string ROLE;
static const std::string LOCAL_ADDR;
static const std::string NET_SERVER_ADDR;
static const std::string NET_SERVER_PORT;
// default values
static const std::string LOCAL_ADDR_DEFAULT;
static const std::string NET_SERVER_ADDR_DEFAULT;
static const int NET_SERVER_PORT_DEFAULT;
};
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// mpc protocol configuration
#include "mpc_config.h"
namespace paddle {
namespace mpc {
// aby3 protocol parameters and defaults
const std::string Aby3Config::ROLE("role");
const std::string Aby3Config::LOCAL_ADDR("local.address");
const std::string Aby3Config::NET_SERVER_ADDR("net_server.address");
const std::string Aby3Config::NET_SERVER_PORT("net_server.port");
const std::string Aby3Config::LOCAL_ADDR_DEFAULT("localhost");
const std::string Aby3Config::NET_SERVER_ADDR_DEFAULT("localhost");
const int Aby3Config::NET_SERVER_PORT_DEFAULT =
6379; // default redis server port
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_instance.h"
namespace paddle {
namespace mpc {
thread_local std::once_flag MpcInstance::_s_init_flag;
thread_local std::shared_ptr<MpcInstance> MpcInstance::_s_mpc_instance(nullptr);
thread_local std::shared_ptr<MpcProtocol> MpcInstance::_s_mpc_protocol(nullptr);
} // namespace mpc
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description: A paddle_encrypted executor for running a mpc program
#pragma once
#include "gloo/rendezvous/hash_store.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_protocol.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_protocol_factory.h"
#include <memory>
#include <mutex>
namespace paddle {
namespace mpc {
class MpcInstance {
private:
MpcInstance(const std::string &protocol_name, const MpcConfig &config)
: _protocol_name(protocol_name), _mpc_config(config) {}
// for test purpose
void prepare_mpc_protocol_with_store(
std::shared_ptr<gloo::rendezvous::Store> store) {
_s_mpc_protocol = MpcProtocolFactory::build(_protocol_name);
PADDLE_ENFORCE_NOT_NULL(_s_mpc_protocol, "Unrecognized mpc protocol: %s",
_protocol_name);
_s_mpc_protocol->init_with_store(_mpc_config, store);
}
void prepare_mpc_protocol() {
_s_mpc_protocol = MpcProtocolFactory::build(_protocol_name);
PADDLE_ENFORCE_NOT_NULL(_s_mpc_protocol, "Unrecognized mpc protocol: %s",
_protocol_name);
_s_mpc_protocol->init(_mpc_config);
}
static void init_mpc(const std::string &protocol_name,
const MpcConfig &mpc_config) {
_s_mpc_instance.reset(new MpcInstance(protocol_name, mpc_config));
_s_mpc_instance->prepare_mpc_protocol();
}
// for test purpose
static void
init_mpc_with_store(const std::string &protocol_name,
const MpcConfig &mpc_config,
std::shared_ptr<gloo::rendezvous::Store> store) {
_s_mpc_instance.reset(new MpcInstance(protocol_name, mpc_config));
_s_mpc_instance->prepare_mpc_protocol_with_store(store);
}
public:
static std::shared_ptr<MpcInstance>
init_instance(const std::string &protocol_name, const MpcConfig &mpc_config) {
std::call_once(_s_init_flag, &MpcInstance::init_mpc, protocol_name,
mpc_config);
return _s_mpc_instance;
}
// for test purpose
static std::shared_ptr<MpcInstance>
init_instance_with_store(const std::string &protocol_name,
const MpcConfig &mpc_config,
std::shared_ptr<gloo::rendezvous::Store> store) {
std::call_once(_s_init_flag, &MpcInstance::init_mpc_with_store,
protocol_name, mpc_config, store);
return _s_mpc_instance;
}
static std::shared_ptr<MpcInstance> mpc_instance() {
PADDLE_ENFORCE_NOT_NULL(_s_mpc_instance,
"Mpc instance is not initialized!");
return _s_mpc_instance;
}
static std::shared_ptr<MpcProtocol> mpc_protocol() {
PADDLE_ENFORCE_NOT_NULL(_s_mpc_protocol, "MpcProtocol is null.");
return _s_mpc_protocol;
}
private:
static thread_local std::once_flag _s_init_flag;
const std::string _protocol_name;
MpcConfig _mpc_config;
static thread_local std::shared_ptr<MpcInstance> _s_mpc_instance;
static thread_local std::shared_ptr<MpcProtocol> _s_mpc_protocol;
};
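// A minimal usage sketch (illustrative; the config keys follow Aby3Config):
//
//   MpcConfig config;
//   config.set_int(Aby3Config::ROLE, 0);
//   auto instance = MpcInstance::init_instance("aby3", config);
//   auto protocol = MpcInstance::mpc_protocol();  // the initialized protocol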
} // namespace mpc
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_instance.h"
#include "mpc_config.h"
#include "gtest/gtest.h"
#include <thread>
#include "aby3_protocol.h"
#include "mpc_protocol_factory.h"
#include "core/privc3/circuit_context.h"
#include "gtest/gtest.h"
namespace paddle {
namespace mpc {
using namespace std;
TEST(MpcInstanceTest, InitInstance) {
using paddle::platform::EnforceNotMet;
EXPECT_THROW(MpcInstance::mpc_instance(), EnforceNotMet);
auto gloo_store = std::make_shared<gloo::rendezvous::HashStore>();
std::shared_ptr<std::thread> threads[3];
for (int idx = 0; idx < 3; ++idx) {
threads[idx] = std::make_shared<std::thread>([gloo_store, idx]() {
const std::string protocol_name("aby3");
MpcConfig aby3_config;
aby3_config.set_int(Aby3Config::ROLE, idx);
auto mpc_instance = MpcInstance::init_instance_with_store(
protocol_name, aby3_config, gloo_store);
ASSERT_NE(MpcInstance::mpc_instance(), nullptr);
EXPECT_EQ(MpcInstance::mpc_instance(), mpc_instance);
EXPECT_EQ(mpc_instance, MpcInstance::init_instance_with_store(
protocol_name, aby3_config, gloo_store));
EXPECT_EQ(mpc_instance->mpc_protocol()->name(), "aby3");
});
}
EXPECT_THROW(MpcInstance::mpc_instance(), EnforceNotMet);
for (auto thread : threads) {
thread->join();
}
}
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// abstract mpc operation interface
#pragma once
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace mpc {
using paddle::framework::Tensor;
class MpcOperators {
public:
virtual void add(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void sub(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void neg(const Tensor *op, Tensor *out) = 0;
virtual void sum(const Tensor *op, Tensor *out) = 0;
virtual void mul(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void matmul(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void scale(const Tensor *lhs, const double factor, Tensor *out) = 0;
virtual void relu(const Tensor *op, Tensor *out) = 0;
virtual void softmax(const Tensor *op, Tensor *out) = 0;
virtual void gt(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void geq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void lt(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void leq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void eq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void neq(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
virtual void relu_grad(const Tensor *y, const Tensor *dy, Tensor *dx,
const float point) = 0;
};
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// mpc protocol base class
#pragma once
#include "abstract_network.h"
#include "gloo/rendezvous/hash_store.h"
#include "mpc_config.h"
#include "mpc_operators.h"
#include "core/privc3/circuit_context.h"
namespace paddle {
namespace mpc {
class MpcProtocol {
public:
MpcProtocol(const std::string &name) : _name(name){};
virtual ~MpcProtocol() = default;
virtual std::string name() const { return _name; }
virtual void init(const MpcConfig &config) = 0;
// for test purpose
virtual void
init_with_store(const MpcConfig &config,
std::shared_ptr<gloo::rendezvous::Store> store) = 0;
virtual std::shared_ptr<MpcOperators> mpc_operators() = 0;
virtual std::shared_ptr<AbstractNetwork> network() = 0;
virtual std::shared_ptr<aby3::CircuitContext> mpc_context() = 0;
private:
const std::string _name;
};
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_protocol_factory.h"
#include "aby3_protocol.h"
namespace paddle {
namespace mpc {
void MpcProtocolFactory::register_protocol() {
if (!_is_initialized) {
_creator_map.insert({"aby3", std::make_shared<Aby3Protocol>});
}
_is_initialized = true;
}
std::shared_ptr<MpcProtocol>
MpcProtocolFactory::build(const std::string &name) {
if (!_is_initialized) {
register_protocol();
}
auto where = _creator_map.find(to_lowercase(name));
if (where == _creator_map.end()) {
return nullptr;
}
return where->second();
}
MpcProtocolFactory::CreatorMap MpcProtocolFactory::_creator_map;
bool MpcProtocolFactory::_is_initialized = false;
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
// a factory class that provides pre-defined mpc protocol instances by name
#pragma once
#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include "mpc_protocol.h"
namespace paddle {
namespace mpc {
class MpcProtocolFactory {
public:
using Creator = std::function<std::shared_ptr<MpcProtocol>()>;
using CreatorMap = std::unordered_map<std::string, Creator>;
MpcProtocolFactory() = delete;
static void register_protocol();
static std::shared_ptr<MpcProtocol> build(const std::string &name);
private:
static bool _is_initialized;
static CreatorMap _creator_map;
static inline std::string to_lowercase(const std::string &str) {
std::string orig_str(str);
std::transform(orig_str.begin(), orig_str.end(), orig_str.begin(),
::tolower);
return orig_str;
}
};
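// A minimal usage sketch (illustrative):
//
//   auto protocol = MpcProtocolFactory::build("aby3");  // lookup is case-insensitive
//   // build() returns nullptr for an unknown name; a non-null protocol still
//   // needs protocol->init(config) before it can be used.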
} // mpc
} // paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread>
#include "aby3_protocol.h"
#include "mpc_config.h"
#include "mpc_protocol_factory.h"
#include "core/privc3/circuit_context.h"
#include "gtest/gtest.h"
namespace paddle {
namespace mpc {
// tests for mpc_protocol and its configuration facilities
TEST(MpcProtocolTest, FindProtocol) {
const std::string aby3_name("aby3");
// empty name
auto illegal = MpcProtocolFactory::build("");
EXPECT_EQ(illegal, nullptr);
// find aby3 with lower case name
auto aby3_lower = MpcProtocolFactory::build("aby3");
ASSERT_NE(aby3_lower, nullptr);
EXPECT_EQ(aby3_lower->name(), aby3_name);
// find aby3 with mixed lower and upper case name
auto aby3_upper = MpcProtocolFactory::build("ABy3");
ASSERT_NE(aby3_upper, nullptr);
EXPECT_EQ(aby3_upper->name(), aby3_name);
// find unknown protocol
auto unknown = MpcProtocolFactory::build("foo");
EXPECT_EQ(unknown, nullptr);
}
TEST(MpcProtocolTest, ProtocolInit) {
using paddle::platform::EnforceNotMet;
auto mpc = MpcProtocolFactory::build("aby3");
ASSERT_NE(mpc, nullptr);
// not yet initialized
EXPECT_THROW(mpc->mpc_context(), EnforceNotMet);
EXPECT_THROW(mpc->mpc_operators(), EnforceNotMet);
EXPECT_THROW(mpc->network(), EnforceNotMet);
// try initialize
auto aby3 = std::dynamic_pointer_cast<Aby3Protocol>(mpc);
MpcConfig config;
// null store
EXPECT_THROW(aby3->init_with_store(config, nullptr), EnforceNotMet);
auto gloo_store = std::make_shared<gloo::rendezvous::HashStore>();
std::shared_ptr<std::thread> threads[3];
for (int idx = 0; idx < 3; ++idx) {
threads[idx] = std::make_shared<std::thread>([gloo_store, idx]() {
auto proto = std::make_shared<Aby3Protocol>();
ASSERT_NE(proto, nullptr);
MpcConfig aby3_config;
aby3_config.set_int(Aby3Config::ROLE, idx);
proto->init_with_store(aby3_config, gloo_store);
ASSERT_NE(proto->network(), nullptr);
EXPECT_EQ(proto->network()->party_id(), idx);
EXPECT_EQ(proto->network()->party_num(), 3);
ASSERT_NE(proto->mpc_context(), nullptr);
EXPECT_EQ(proto->mpc_context()->next_party(), (idx + 1) % 3);
EXPECT_EQ(proto->mpc_context()->pre_party(), (idx + 2) % 3);
EXPECT_NE(proto->mpc_operators(), nullptr);
});
}
for (auto thread : threads) {
thread->join();
}
}
TEST(MpcConfigTest, ConfigSetAndGet) {
MpcConfig config;
const std::string EMPTY_STR;
const int ZERO = 0;
// non-exist key leads to default string
EXPECT_EQ(config.get("foo"), EMPTY_STR);
// non-exist key leads to default int
EXPECT_EQ(config.get_int("bar"), ZERO);
// non-exist key leads to specified default str
const std::string DEF_STR("default");
EXPECT_EQ(config.get("foo1", DEF_STR), DEF_STR);
// non-exist key leads to specified default int
const int ONE = 1;
EXPECT_EQ(config.get_int("foo2", ONE), ONE);
const std::string KEY_STR("key1");
const std::string KEY_INT("key2");
const std::string VALUE_STR("value1");
const int VALUE_INT = 2;
config.set(KEY_STR, VALUE_STR).set_int(KEY_INT, VALUE_INT);
// expected results
EXPECT_EQ(config.get(KEY_STR), VALUE_STR);
EXPECT_EQ(config.get_int(KEY_INT), VALUE_INT);
// get wrong int
EXPECT_THROW(config.get_int(KEY_STR), std::invalid_argument);
// override existing key
const std::string VALUE_STR2("value2");
config.set(KEY_STR, VALUE_STR2);
EXPECT_EQ(config.get(KEY_STR), VALUE_STR2);
EXPECT_NE(config.get(KEY_STR), VALUE_STR);
}
} // mpc
} // paddle
add_compile_options(-msse4.2 -maes)
aux_source_directory(. DIR_SRCS)
add_library(mpc_ops_o OBJECT ${DIR_SRCS})
add_dependencies(mpc_ops_o fluid_framework gloo)
add_library(mpc_ops STATIC $<TARGET_OBJECTS:mpc_ops_o>)
target_link_libraries(mpc_ops fluid_framework gloo)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MpcCompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcCompareOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
platform::errors::NotFound(
"Input(Y) of MpcCompareOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcCompareOp should not be null."));
auto dim_x = ctx->GetInputDim("X");
auto dim_y = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(),
"The size of dim_y should not be greater than dim_x's.");
ctx->ShareDim("Y", /*->*/ "Out");
ctx->ShareLoD("Y", /*->*/ "Out");
}
framework::OpKernelType
GetExpectedKernelType(const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class MpcCompareOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of MpcCompareOp.");
AddInput("Y", "(Tensor), The second input tensor of MpcCompareOp.");
AddOutput("Out", "(Tensor), The output tensor of MpcCompareOp.");
AddComment(R"DOC(
MPC Compare Operator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_than, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_greater_than,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MpcGreaterThanFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_equal, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_greater_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MpcGreaterEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_less_than, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_less_than, ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MpcLessThanFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_less_equal, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_less_equal, ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MpcLessEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_equal, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_equal, ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MpcEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_not_equal, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_not_equal, ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MpcNotEqualFunctor>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include <math.h>
#include <type_traits>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
struct MpcGreaterThanFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->gt(
in_x_t, in_y_t, out_t);
}
};
struct MpcGreaterEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->geq(
in_x_t, in_y_t, out_t);
}
};
struct MpcLessThanFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->lt(
in_x_t, in_y_t, out_t);
}
};
struct MpcLessEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->leq(
in_x_t, in_y_t, out_t);
}
};
struct MpcEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->eq(
in_x_t, in_y_t, out_t);
}
};
struct MpcNotEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neq(
in_x_t, in_y_t, out_t);
}
};
template <typename DeviceContext, typename T, typename Functor>
class MpcCompareOpKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
auto *out_t = ctx.Output<framework::LoDTensor>("Out");
auto out = out_t->mutable_data<T>(ctx.GetPlace());
Functor().Run(in_x_t, in_y_t, out_t);
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_elementwise_add_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MpcElementwiseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound(
"Input(Y) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_GE(
ctx->GetInputDim("X").size(), ctx->GetInputDim("Y").size(),
platform::errors::InvalidArgument(
"The dimensions of X should be greater than the dimensions of Y. "
"But received the dimensions of X is [%s], the dimensions of Y is "
"[%s]",
ctx->GetInputDim("X"), ctx->GetInputDim("Y")));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcElementwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), The first input tensor of mpc elementwise add op.");
AddInput("Y",
"(Tensor), The second input tensor of mpc elementwise add op.");
AddOutput("Out", "(Tensor), The output tensor of mpc elementwise add op.");
AddAttr<int>("axis",
"(int, default -1). If X.dimension != Y.dimension,"
"Y.dimension must be a subsequence of x.dimension. And axis "
"is the start dimension index "
"for broadcasting Y onto X. ")
.SetDefault(-1)
.EqualGreaterThan(-1);
AddComment(R"DOC(
MPC elementwise add Operator.
)DOC");
}
};
class MpcElementwiseAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
}
}
};
template <typename T>
class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_elementwise_add_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_elementwise_add, ops::MpcElementwiseAddOp,
ops::MpcElementwiseAddOpMaker,
ops::MpcElementwiseAddOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_elementwise_add_grad, ops::MpcElementwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_elementwise_add,
ops::MpcElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_elementwise_add_grad,
ops::MpcElementwiseAddGradKernel<
paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This op is different from elementwise_add in PaddlePaddle.
// We mainly consider the case where the dimensions of X are equal to the
// dimensions of Y; a restricted broadcast of Y onto X is also handled in the
// kernel below.
#pragma once
#include "mpc_op.h"
#include "paddle/fluid/platform/transform.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// paddle/fluid/operators/elementwise/elementwise_op_function.h
template <typename T, typename DeviceContext> class RowwiseTransformIterator;
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
T *, T &> {
public:
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
return *this;
}
RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
while (n-- > 0) {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
}
return *this;
}
bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T &operator*() { return ptr_[i_]; }
private:
const T *ptr_;
int i_;
int64_t n_;
};
template <typename T> struct AddFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
};
struct GetMidDims {
inline HOSTDEVICE void operator()(const framework::DDim &x_dims,
const framework::DDim &y_dims,
const int axis, int *pre, int *n,
int *post) {
*pre = 1;
*n = 1;
*post = 1;
for (int i = 1; i < axis + 1; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 1; i < y_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
"Broadcast dimension mismatch.");
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
}
};
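// Illustrative example (assuming the per-share slices produced by the kernel
// below): with x_dims = [1, 3, 4], y_dims = [1, 4] and axis = 1, GetMidDims
// yields pre = 3, n = 4 and post = 1, so Y is broadcast row-wise over X.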
const size_t SHARE_NUM = 2;
template <typename DeviceContext, typename T>
class MpcElementwiseAddKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<LoDTensor>("X");
auto *in_y_t = ctx.Input<LoDTensor>("Y");
auto *out_t = ctx.Output<LoDTensor>("Out");
int axis = ctx.Attr<int>("axis");
auto out = out_t->mutable_data<T>(ctx.GetPlace());
if (in_x_t->dims() == in_y_t->dims()) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(
in_x_t, in_y_t, out_t);
} else {
Tensor in_x_t_slice;
Tensor in_y_t_slice;
Tensor out_t_slice;
for (size_t i = 0; i < SHARE_NUM; ++i) {
in_x_t_slice = in_x_t->Slice(i, i + 1);
in_y_t_slice = in_y_t->Slice(i, i + 1);
out_t_slice = out_t->Slice(i, i + 1);
auto x_dims = in_x_t_slice.dims();
auto y_dims = in_y_t_slice.dims();
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims.size())");
int pre, n, post;
GetMidDims get_mid_dims;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
PADDLE_ENFORCE_EQ(
post, 1, "post should be equal to 1, but received post is [%s]", post);
auto x_ = in_x_t_slice.data<T>();
auto y_ = in_y_t_slice.data<T>();
auto out_ = out_t_slice.data<T>();
auto nx_ = in_x_t_slice.numel();
paddle::platform::Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), x_, x_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(y_, n), out_,
AddFunctor<T>());
}
}
}
};
template <typename DeviceContext, typename T>
class MpcElementwiseAddGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto dout_data = dout->data<T>();
if (dx) {
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < dout->numel(); i++) {
dx_data[i] = dout_data[i];
}
}
if (dy) {
auto dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (in_x_t->dims().size() == in_y_t->dims().size()) {
for (size_t i = 0; i < dout->numel(); i++) {
dy_data[i] = dout_data[i];
}
} else {
auto x_dims = in_x_t->dims();
auto y_dims = in_y_t->dims();
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims.size())");
int pre, n, post;
GetMidDims get_mid_dims;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
PADDLE_ENFORCE_EQ(
post, 1, "post should be equal to 1, but received post is [%s]", post);
for (size_t i = 0; i < SHARE_NUM; ++i) {
int y_offset = i * n;
for (size_t j = 0; j < pre; ++j) {
for (size_t k = 0; k < n; ++k) {
int out_offset = i * pre * n + j * n + k;
if (0 == j) {
dy_data[k + y_offset] = dout_data[out_offset];
} else {
dy_data[k + y_offset] += dout_data[out_offset];
}
}
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_elementwise_sub_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class MpcElementwiseSubOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound(
"Input(Y) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("X"), ctx->GetInputDim("Y"),
platform::errors::InvalidArgument(
"The dimensions of X should be equal with the dimensions of Y. "
"But received the dimensions of X is [%s], the dimensions of Y is "
"[%s]",
ctx->GetInputDim("X"), ctx->GetInputDim("Y")));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcElementwiseSubOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), The first input tensor of mpc elementwise sub op.");
AddInput("Y",
"(Tensor), The second input tensor of mpc elementwise sub op.");
AddOutput("Out", "(Tensor), The output tensor of mpc elementwise sub op.");
AddComment(R"DOC(
MPC elementwise sub Operator.
)DOC");
}
};
class MpcElementwiseSubGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
}
}
};
template <typename T>
class MpcElementwiseSubGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_elementwise_sub_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_elementwise_sub, ops::MpcElementwiseSubOp,
ops::MpcElementwiseSubOpMaker,
ops::MpcElementwiseSubGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_elementwise_sub_grad, ops::MpcElementwiseSubGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_elementwise_sub,
ops::MpcElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_elementwise_sub_grad,
ops::MpcElementwiseSubGradKernel<
paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This op is different from elementwise_sub in PaddlePaddle.
// We only consider the case where the dimensions of X are equal to the
// dimensions of Y.
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class MpcElementwiseSubKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X");
auto *in_y_t = ctx.Input<Tensor>("Y");
auto *out_t = ctx.Output<Tensor>("Out");
auto out = out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(
in_x_t, in_y_t, out_t);
}
};
template <typename DeviceContext, typename T>
class MpcElementwiseSubGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
VLOG(3) << "******** MpcElementwiseSubGradKernel: ";
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto dout_data = dout->data<T>();
if (dx) {
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < dout->numel(); i++) {
dx_data[i] = dout_data[i];
}
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg(
dout, dy);
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description: mpc init operator, which initializes the mpc protocol instance.
#include "paddle/fluid/framework/op_registry.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
using mpc::MpcConfig;
using mpc::Aby3Config;
class MpcInitOp : public framework::OperatorBase {
public:
MpcInitOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto protocol_name = Attr<std::string>("protocol_name");
auto role = Attr<int>("role");
auto local_addr = Attr<std::string>("local_addr");
auto net_server_addr = Attr<std::string>("net_server_addr");
auto net_server_port = Attr<int>("net_server_port");
MpcConfig _mpc_config;
_mpc_config.set_int(Aby3Config::ROLE, role);
_mpc_config.set(Aby3Config::LOCAL_ADDR, local_addr);
_mpc_config.set(Aby3Config::NET_SERVER_ADDR, net_server_addr);
_mpc_config.set_int(Aby3Config::NET_SERVER_PORT, net_server_port);
mpc::MpcInstance::init_instance(protocol_name, _mpc_config);
}
};
class MpcInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
MPC Init Operator.
)DOC");
AddAttr<std::string>("protocol_name", "(string , default aby3)"
"protocol name")
.SetDefault({"aby3"});
AddAttr<int>("role", "trainer role.").SetDefault(0);
AddAttr<std::string>("local_addr", "(string, default localhost)"
"local addr")
.SetDefault({"localhost"});
AddAttr<std::string>("net_server_addr", "(string, default localhost)"
"net server addr")
.SetDefault({"localhost"});
AddAttr<int>("net_server_port", "net server port, default to 6539.")
.SetDefault(6539);
}
};
class MpcInitOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_init, ops::MpcInitOp, ops::MpcInitOpMaker,
ops::MpcInitOpShapeInference);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_mean_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MpcMeanOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcMeanOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcMeanOp should not be null."));
ctx->SetOutputDim("Out", {2, 1});
}
};
class MpcMeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc mean op.");
AddOutput("Out", "(Tensor), The output tensor of mpc mean op.");
AddComment(R"DOC(
MPC mean Operator calculates the mean of all elements in X.
)DOC");
}
};
class MpcMeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>
GetInputOutputWithSameType() const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class MpcMeanGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
};
template <typename T>
class MpcMeanOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_mean_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp, ops::MpcMeanOpMaker,
ops::MpcMeanOpInferVarType,
ops::MpcMeanOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mean_grad, ops::MpcMeanGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_mean, ops::MpcMeanKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_mean_grad,
ops::MpcMeanGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class MpcMeanKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X");
auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace());
double scale = 1.0 / (in_x_t->numel() / 2.0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum(
in_x_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(
out_t, scale, out_t);
}
};
template <typename DeviceContext, typename T>
class MpcMeanGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(dout->numel() == 2,
"numel of MpcMean Gradient should be 2.");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dout_data = dout->data<T>();
if (dx) {
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < dx->numel() / 2; ++i) {
dx_data[i] = dout_data[0];
}
for (size_t i = dx->numel() / 2; i < dx->numel(); ++i) {
dx_data[i] = dout_data[1];
}
double scale_factor = 1.0 / (dx->numel() / 2);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(
dx, scale_factor, dx);
}
}
};
} // namespace operators
} // namespace paddle
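The kernels above follow the convention used throughout these operators: dimension 0 of an MPC tensor holds the two secret shares of each value, so the logical element count is numel() / 2, the forward mean scales the share-wise sum by 1 / (numel / 2), and the backward pass broadcasts each of the two output-gradient shares over its half of the input gradient before applying the same scale. Below is a minimal plaintext sketch of that arithmetic on a plain std::vector with hypothetical sizes; it is an illustration only, not part of the library.
// Plaintext illustration of MpcMeanKernel / MpcMeanGradKernel (hypothetical sizes).
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // "X" with shape [2, 3]: two shares of three logical elements each.
  std::vector<double> x = {1, 2, 3,   // share 0
                           4, 5, 6};  // share 1
  std::size_t logical = x.size() / 2;
  double scale = 1.0 / logical;       // matches 1.0 / (numel() / 2.0)

  // Forward: share-wise sum, then multiply by the scale.
  double out[2] = {0.0, 0.0};
  for (std::size_t i = 0; i < logical; ++i) {
    out[0] += x[i];
    out[1] += x[logical + i];
  }
  out[0] *= scale;
  out[1] *= scale;

  // Backward: fill each half of dX with the matching dOut share, then scale.
  double dout[2] = {1.0, 1.0};        // hypothetical upstream gradient shares
  std::vector<double> dx(x.size());
  for (std::size_t i = 0; i < logical; ++i) dx[i] = dout[0] * scale;
  for (std::size_t i = logical; i < x.size(); ++i) dx[i] = dout[1] * scale;

  std::cout << out[0] << " " << out[1] << " " << dx[0] << std::endl;  // 2 5 0.333...
  return 0;
}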
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_mul_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MpcMulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of Mpc MulOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MpcMulOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcMulOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
VLOG(3) << "mpc mul operator x.shape=" << x_dims << " y.shape=" << y_dims
<< " x_num_col_dims=" << x_num_col_dims
<< " y_num_col_dims=" << y_num_col_dims;
PADDLE_ENFORCE_NE(framework::product(y_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable Y(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("Y").front()));
PADDLE_ENFORCE_GT(
x_dims.size(), x_num_col_dims,
platform::errors::InvalidArgument(
"The input tensor X's dimensions of MpcMulOp "
"should be larger than x_num_col_dims. But received X's "
"dimensions = %d, X's shape = [%s], x_num_col_dims = %d.",
x_dims.size(), x_dims, x_num_col_dims));
PADDLE_ENFORCE_GT(
y_dims.size(), y_num_col_dims,
platform::errors::InvalidArgument(
"The input tensor Y's dimensions of MpcMulOp "
"should be larger than y_num_col_dims. But received Y's "
"dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.",
y_dims.size(), y_dims, y_num_col_dims));
int x_mat_width = 1;
int y_mat_height = 1;
for (size_t i = x_num_col_dims + 1; i < x_dims.size(); i++) {
x_mat_width *= x_dims[i];
}
for (size_t i = 1; i <= y_num_col_dims; i++) {
y_mat_height *= y_dims[i];
}
PADDLE_ENFORCE_EQ(
x_mat_width, y_mat_height,
platform::errors::InvalidArgument(
"After flatten the input tensor X and Y to 2-D dimensions "
"matrix X1 and Y1, the matrix X1's width must be equal with matrix "
"Y1's height. But received X's shape = [%s], X1's "
"width = %s; Y's shape = [%s], Y1's height = %s.",
x_dims, x_mat_width, y_dims, y_mat_height));
std::vector<int64_t> output_dims;
output_dims.reserve(static_cast<size_t>(1 + x_num_col_dims + y_dims.size() -
y_num_col_dims));
for (int i = 0; i <= x_num_col_dims; ++i) { // i=0, batch_size (share id)
output_dims.push_back(x_dims[i]);
}
for (int i = y_num_col_dims + 1; i < y_dims.size(); ++i) {
output_dims.push_back(y_dims[i]);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc mul op.");
AddInput("Y", "(Tensor), The second input tensor of mpc mul op.");
AddOutput("Out", "(Tensor), The output tensor of mpc mul op.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<int>(
"x_num_col_dims",
R"DOC((int, default 1), The mul_op can take tensors with more than two
dimensions as its inputs. If the input $X$ is a tensor with more
than two dimensions, $X$ will be flattened into a two-dimensional
matrix first. The flattening rule is: the first `num_col_dims`
will be flattened to form the first dimension of the final matrix
(the height of the matrix), and the rest `rank(X) - num_col_dims`
dimensions are flattened to form the second dimension of the final
matrix (the width of the matrix). As a result, height of the
flattened matrix is equal to the product of $X$'s first
`x_num_col_dims` dimensions' sizes, and width of the flattened
matrix is equal to the product of $X$'s last `rank(x) - num_col_dims`
dimensions' size. For example, suppose $X$ is a 5-dimensional
tensor with the shape [2, 3, 4, 5, 6], and `x_num_col_dims` = 3.
Thus, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] =
[24, 30].
)DOC")
.SetDefault(1)
.EqualGreaterThan(1);
AddAttr<int>(
"y_num_col_dims",
R"DOC((int, default 1), The mul_op can take tensors with more than two,
dimensions as its inputs. If the input $Y$ is a tensor with more
than two dimensions, $Y$ will be flattened into a two-dimensional
matrix first. The attribute `y_num_col_dims` determines how $Y$ is
flattened. See comments of `x_num_col_dims` for more details.
)DOC")
.SetDefault(1)
.EqualGreaterThan(1);
AddAttr<float>(
"scale_x",
"scale_x to be used for int8 mul input data x. scale_x has the"
"same purpose as scale_in in OPs that support quantization."
"Only to be used with MKL-DNN INT8")
.SetDefault(1.0f);
AddAttr<std::vector<float>>(
"scale_y",
"scale_y to be used for int8 mul input data y. scale_y has the"
"same purpose as scale_weights in OPs that support quantization."
"Only to be used with MKL-DNN INT8")
.SetDefault({1.0f});
AddAttr<float>("scale_out", "scale_out to be used for int8 output data."
"Only used with MKL-DNN INT8")
.SetDefault(1.0f);
AddAttr<bool>(
"force_fp32_output",
"(bool, default false) Force quantize kernel output FP32, only "
"used in quantized MKL-DNN.")
.SetDefault(false);
AddComment(R"DOC(
MPC mul Operator.
)DOC");
}
};
class MpcMulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>
GetInputOutputWithSameType() const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class MpcMulGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
template <typename T>
class MpcMulOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_mul_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp, ops::MpcMulOpMaker,
ops::MpcMulOpInferVarType,
ops::MpcMulOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mul_grad, ops::MpcMulGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_mul, ops::MpcMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_mul_grad,
ops::MpcMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
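To make the shape inference above concrete: with X of shape [2, 4, 6] and Y of shape [2, 6, 5] (dimension 0 being the share dimension), and x_num_col_dims = y_num_col_dims = 1, X flattens per share to a 4 x 6 matrix, Y flattens to 6 x 5, the width/height check passes (6 == 6), and Out gets shape [2, 4, 5]. A standalone sketch of that output-shape rule, using plain vectors instead of framework::DDim and hypothetical shapes:
// Standalone sketch of MpcMulOp::InferShape's output-shape rule (hypothetical shapes).
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> MpcMulOutDims(const std::vector<int64_t> &x_dims,
                                   const std::vector<int64_t> &y_dims,
                                   int x_num_col_dims, int y_num_col_dims) {
  // Dimension 0 is the share dimension and is always kept.
  std::vector<int64_t> out;
  for (int i = 0; i <= x_num_col_dims; ++i) out.push_back(x_dims[i]);
  for (int i = y_num_col_dims + 1; i < static_cast<int>(y_dims.size()); ++i)
    out.push_back(y_dims[i]);
  return out;
}

int main() {
  auto out = MpcMulOutDims({2, 4, 6}, {2, 6, 5}, /*x_num_col_dims=*/1,
                           /*y_num_col_dims=*/1);
  for (auto d : out) std::cout << d << " ";  // prints: 2 4 5
  std::cout << std::endl;
  return 0;
}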
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class MpcMulKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Y");
auto *out = ctx.Output<Tensor>("Out");
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto x_dims = x->dims();
auto y_dims = y->dims();
int x_mat_width = 1;
int x_mat_height = 1;
int y_mat_width = 1;
int y_mat_height = 1;
for (size_t i = 1; i < x_dims.size(); i++) {
if (i <= x_num_col_dims) {
x_mat_width *= x_dims[i];
} else {
x_mat_height *= x_dims[i];
}
}
for (size_t i = 1; i < y_dims.size(); i++) {
if (i <= y_num_col_dims) {
y_mat_width *= y_dims[i];
} else {
y_mat_height *= y_dims[i];
}
}
Tensor x_matrix;
Tensor y_matrix;
x_matrix.ShareDataWith(*x);
y_matrix.ShareDataWith(*y);
if (x_dims.size() > 3) {
x_matrix.Resize({2, x_mat_width, x_mat_height});
}
if (y_dims.size() > 3) {
y_matrix.Resize({2, y_mat_width, y_mat_height});
}
out->mutable_data<T>(ctx.GetPlace());
auto out_dim = out->dims();
if (out_dim.size() > 3) {
out->Resize({2, x_mat_width, y_mat_height});
}
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&x_matrix, &y_matrix, out);
if (out_dim.size() > 3) {
out->Resize(out_dim);
}
}
};
template <typename DeviceContext, typename T>
class MpcMulGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto x_dims = x->dims();
auto y_dims = y->dims();
auto dout_dims = dout->dims();
int x_mat_width = 1;
int x_mat_height = 1;
int y_mat_width = 1;
int y_mat_height = 1;
for (size_t i = 1; i < x_dims.size(); i++) {
if (i <= x_num_col_dims) {
x_mat_width *= x_dims[i];
} else {
x_mat_height *= x_dims[i];
}
}
for (size_t i = 1; i < y_dims.size(); i++) {
if (i <= y_num_col_dims) {
y_mat_width *= y_dims[i];
} else {
y_mat_height *= y_dims[i];
}
}
Tensor x_matrix;
Tensor y_matrix;
Tensor dout_matrix;
x_matrix.ShareDataWith(*x);
y_matrix.ShareDataWith(*y);
dout_matrix.ShareDataWith(*dout);
if (x_dims.size() > 3) {
x_matrix.Resize({2, x_mat_width, x_mat_height});
}
if (y_dims.size() > 3) {
y_matrix.Resize({2, y_mat_width, y_mat_height});
}
if (dout_dims.size() > 3) {
dout_matrix.Resize({2, x_mat_width, y_mat_height});
}
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
Tensor x_matrix_trans;
Tensor y_matrix_trans;
x_matrix_trans.mutable_data<T>(x->dims(), ctx.GetPlace());
y_matrix_trans.mutable_data<T>(y->dims(), ctx.GetPlace());
if (x_dims.size() >= 3) {
x_matrix_trans.Resize({2, x_mat_height, x_mat_width});
}
if (y_dims.size() >= 3) {
y_matrix_trans.Resize({2, y_mat_height, y_mat_width});
}
auto &dev_ctx = ctx.template device_context<DeviceContext>();
const int Rank = 3;
Eigen::array<int, Rank> permute;
permute[0] = 0;
permute[1] = 2;
permute[2] = 1;
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
// Record the original gradient shape before flattening to the
// {2, rows, cols} view expected by the share-wise matmul.
auto dx_dim = dx->dims();
if (dx_dim.size() > 3) {
dx->Resize({2, x_mat_width, x_mat_height});
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(y_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(y_matrix_trans);
auto *dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&dout_matrix, &y_matrix_trans, dx);
if (dx_dim.size() > 3) {
dx->Resize(dx_dim);
}
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
// Record the original gradient shape before flattening, mirroring dx above.
auto dy_dim = dy->dims();
if (dy_dim.size() > 3) {
dy->Resize({2, y_mat_width, y_mat_height});
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(x_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(x_matrix_trans);
auto *dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&x_matrix_trans, &dout_matrix, dy);
if (dy_dim.size() > 3) {
dy->Resize(dy_dim);
}
}
}
};
} // namespace operators
} // namespace paddle
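The gradient kernel above transposes the last two axes of each operand share-wise with an Eigen shuffle whose permutation is {0, 2, 1}: axis 0 (the share axis) stays put and the row and column axes swap. A self-contained sketch of that shuffle on a plain Eigen tensor, using row-major layout as framework::EigenTensor does and hypothetical sizes:
// Illustration of the {0, 2, 1} shuffle used to build y' and x' above.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // A tensor laid out as [share, rows, cols].
  Eigen::Tensor<long long, 3, Eigen::RowMajor> x(2, 2, 3);
  for (Eigen::Index i = 0; i < x.size(); ++i) {
    x.data()[i] = i;
  }
  Eigen::array<int, 3> permute = {0, 2, 1};  // keep share axis, swap rows/cols
  Eigen::Tensor<long long, 3, Eigen::RowMajor> xt = x.shuffle(permute);
  // xt has dimensions [2, 3, 2] and xt(s, j, i) == x(s, i, j).
  std::cout << xt.dimension(0) << " " << xt.dimension(1) << " "
            << xt.dimension(2) << std::endl;             // 2 3 2
  std::cout << (xt(0, 2, 1) == x(0, 1, 2)) << std::endl;  // 1
  return 0;
}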
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Description:
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/privc3/circuit_context.h"
namespace paddle {
namespace operators {
template <typename T> class MpcOpKernel : public framework::OpKernelBase {
public:
using ELEMENT_TYPE = T;
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(),
"Mpc protocol is not yet initialized in executor");
std::shared_ptr<aby3::CircuitContext> mpc_ctx(
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx,
[&] { ComputeImpl(ctx); });
}
virtual void ComputeImpl(const framework::ExecutionContext &ctx) const = 0;
};
} // namespace operators
} // namespace paddle
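A concrete kernel only has to override ComputeImpl; the protocol-initialization check and the context switch stay in Compute above. As a minimal sketch of that pattern (the kernel below is hypothetical and not part of the library; it assumes the same includes and namespaces as mpc_mean_op.h, and sum() is the share-wise reduction already used by MpcMeanKernel):
// Hypothetical example kernel: secure sum reduction of "X" into "Out".
template <typename DeviceContext, typename T>
class MpcSumKernel : public MpcOpKernel<T> {
public:
  void ComputeImpl(const framework::ExecutionContext &ctx) const override {
    auto *in = ctx.Input<framework::Tensor>("X");
    auto *out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    // Runs inside the MPC context installed by MpcOpKernel::Compute above.
    mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum(
        in, out);
  }
};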
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpc_relu_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
// forward op definition
class MpcReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims);
}
};
// forward input & output definition
class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op");
AddComment(R"DOC(
Mpc Relu Operator.
)DOC");
}
};
// backward op definition
class MpcReluGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
};
// backward type, input & output definition
template <typename T>
class MpcReluGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<T> Apply() const override {
auto *op = new T();
op->SetType("mpc_relu_grad");
op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return std::unique_ptr<T>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(mpc_relu, ops::MpcReluOp, ops::MpcReluOpMaker,
ops::MpcReluGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_relu_grad, ops::MpcReluGradOp);
REGISTER_OP_CPU_KERNEL(mpc_relu, ops::MpcReluKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_relu_grad, ops::MpcReluGradKernel<CPU, int64_t>);
ABY3 implementation.
../psi/aes.cc
../psi/aes.h
../psi/prng.cc
../psi/prng.h
../psi/rand_utils.cc
../psi/rand_utils.h