...
 
Commits (24)

* update dockerfile and add platform_machine to requirements.txt (#562) (fuxingbit <zjning95@126.com>, 2023-07-12T16:49:54+08:00)
  https://gitcode.net/primihub/primihub/-/commit/97b32059fce02eb6df123e9ebe6593ba2081702d
* support training VFL logistic regression using CKKS (#559) (Xuefeng Xu <xuxf100@qq.com>, 2023-07-13T16:08:01+08:00)
  https://gitcode.net/primihub/primihub/-/commit/259bbbb1bb2c616ea2adc9181087f73e389f52d0
  - add l2 regularization
  - print metrics
  - update ckks params
  - fix aarch64
* run kkrt in arm (#565) (phoenix20162016 <cuibo20062006@163.com>, 2023-07-13T16:47:07+08:00)
  https://gitcode.net/primihub/primihub/-/commit/ea477bf54e7368719abfb1806be6f3e58fe1c154
* save pir result add column name (#558) (phoenix20162016 <cuibo20062006@163.com>, 2023-07-13T18:37:50+08:00)
  https://gitcode.net/primihub/primihub/-/commit/7a48dfa019d01b9b03de3874d790608c2d135f12
* update command link and Makefile (#568) (fuxingbit <zjning95@126.com>, 2023-07-13T19:15:49+08:00)
  https://gitcode.net/primihub/primihub/-/commit/791d525756b0c86bfb48ad04146564483455bc32
  - update command link
  - workflow install make on mac amd64
* check if data_set is empty before submit task (#570) (Xuefeng Xu <xuxf100@qq.com>, 2023-07-14T11:06:40+08:00)
  https://gitcode.net/primihub/primihub/-/commit/0f268ff9273619400e03b43bbe852edcc575f9fc
* fix compile error (#571) (phoenix20162016 <cuibo20062006@163.com>, 2023-07-14T11:29:34+08:00)
  https://gitcode.net/primihub/primihub/-/commit/7b9dde2d4ba883b7de0fc4d97e6cb126d4e9f7d4
* use less ciphertext multiplication in HFL Paillier (#564) (Xuefeng Xu <xuxf100@qq.com>, 2023-07-24T11:45:59+08:00)
  https://gitcode.net/primihub/primihub/-/commit/1397a7f097a29b8769c57b728a33b7078822386f
* Migrate ABY3 (#574) (phoenix20162016 <cuibo20062006@163.com>, 2023-07-24T12:51:26+08:00)
  https://gitcode.net/primihub/primihub/-/commit/dd735d73c252f3d618eeae59240ceebaa5c2edc9
  - move ABY3 to new repo: https://github.com/primihub/aby3/tree/bazel_branch
  - move cryptoTools to new repo: https://github.com/primihub/cryptoTools/tree/bazel_branch
* PSI kkrt16 enabled on macos (#576) (phoenix20162016 <cuibo20062006@163.com>, 2023-07-24T21:58:22+08:00)
  https://gitcode.net/primihub/primihub/-/commit/9c75c147cb5c94fc0383027f515c956971731fb5
  - upgrade relic to 0.6.0
  - enable kkrt16 on macos
* add BOM header to csv file (#577) (phoenix20162016 <cuibo20062006@163.com>, 2023-07-25T13:19:01+08:00)
  https://gitcode.net/primihub/primihub/-/commit/ca7867586607e04891eeb2772a496feae5d20587
* generate keyword PIR db offline, and using generated db on online task (#579) (phoenix20162016 <cuibo20062006@163.com>, 2023-07-27T09:40:14+08:00)
  https://gitcode.net/primihub/primihub/-/commit/aa37321139b6952954ddb17cde57d396df969365
* support multiclass VFL CKKS logistic regression (#578) (Xuefeng Xu <xuxf100@qq.com>, 2023-07-28T10:15:37+08:00)
  https://gitcode.net/primihub/primihub/-/commit/a470b3f3e5bc776a2f59a0c2737314247bdbc8c7
* split comm interface and implement for server (phoenix20162016 <cuibo20062006@163.com>, 2023-07-28T18:03:18+08:00)
  https://gitcode.net/primihub/primihub/-/commit/38abb8087474572b0c208b35e5f9fc05e686aaae
* Merge pull request #581 from phoenix20162016/split_node_impl (phoenix20162016 <cuibo20062006@163.com>, 2023-07-28T18:14:43+08:00)
  https://gitcode.net/primihub/primihub/-/commit/7a532336093e5e4de8d7df42cbb060aa21444c90
* fix bug when using cuda device (#582) (Xuefeng Xu <xuxf100@qq.com>, 2023-07-29T12:46:34+08:00)
  https://gitcode.net/primihub/primihub/-/commit/fc648516183f71615cfced428a158bdf949d198f
* split comm interface and implement for server (#580) (phoenix20162016 <cuibo20062006@163.com>, 2023-07-30T13:41:11+08:00)
  https://gitcode.net/primihub/primihub/-/commit/326cfe333653eaae6371d88ff79f1879bdb2a0d1
* optimize read meta for csv file (#584) (phoenix20162016 <cuibo20062006@163.com>, 2023-08-02T21:10:47+08:00)
  https://gitcode.net/primihub/primihub/-/commit/240b8564eb18b02b20613c7dfec7233162009f6c
* fix bug for post process psi data (phoenix20162016 <cuibo20062006@163.com>, 2023-08-03T16:09:05+08:00)
  https://gitcode.net/primihub/primihub/-/commit/2dccdef5acd87f0f342ae83f5a984b737657a5ab
* Merge pull request #585 from phoenix20162016/fix_post_process (phoenix20162016 <cuibo20062006@163.com>, 2023-08-03T16:11:37+08:00)
  https://gitcode.net/primihub/primihub/-/commit/9366fb0b61e4461453b7b8cda3929944a7608f4f
* check bom (phoenix20162016 <cuibo20062006@163.com>, 2023-08-03T16:59:54+08:00)
  https://gitcode.net/primihub/primihub/-/commit/79047f7c32f6002daad995aa1394cf47e44e68e8
* Merge pull request #586 from phoenix20162016/fix_post_process (phoenix20162016 <cuibo20062006@163.com>, 2023-08-03T17:01:19+08:00)
  https://gitcode.net/primihub/primihub/-/commit/f799518204ecb34ac6eaa0b03e02e5e69cbdf94a
* check bom (phoenix20162016 <cuibo20062006@163.com>, 2023-08-03T17:14:10+08:00)
  https://gitcode.net/primihub/primihub/-/commit/7544a2de4b6f04d5c7c8bd60fc0b5225f5f23531
* Merge pull request #587 from phoenix20162016/fix_post_process (phoenix20162016 <cuibo20062006@163.com>, 2023-08-03T17:15:08+08:00)
  https://gitcode.net/primihub/primihub/-/commit/f8b8d3159e209037bd146a726065d9553a44b6f5
......@@ -7,6 +7,11 @@ build:linux_x86_64 --copt=-DENABLE_SSE
build:linux_x86_64 --define cpu=amd64
build:linux_x86_64 --define cpu_arch=x86_64
build:linux_x86_64 --define microsoft-apsi=true
build:linux_x86_64 --copt=-maes
build:linux_x86_64 --copt=-msse2
build:linux_x86_64 --copt=-msse3
build:linux_x86_64 --copt=-msse4.1
build:linux_x86_64 --copt=-mpclmul
#build:linux_x86_64 --define enable_mysql_driver=true
build:linux_aarch64 --cxxopt=-std=c++17
......@@ -33,6 +38,11 @@ build:linux_asan --copt -g
build:linux_asan --copt -fno-omit-frame-pointer
build:linux_asan --linkopt -fsanitize=address
build:linux_asan --linkopt -static-libasan
build:linux_asan --copt=-maes
build:linux_asan --copt=-msse2
build:linux_asan --copt=-msse3
build:linux_asan --copt=-msse4.1
build:linux_asan --copt=-mpclmul
build:linux_asan --define enable_mysql_driver=true
......
......@@ -23,17 +23,17 @@ jobs:
bash pre_build.sh
bazel build --config=linux_`arch` \
--cxxopt=-DMPC_SOCKET_CHANNEL \
//test/primihub/algorithm:logistic_test \
//test/primihub/algorithm:maxpool_test \
//test/primihub/algorithm:falcon_lenet_test \
//test/primihub/common/type:common_test \
//test/primihub/util:network_test
//test/primihub/algorithm:logistic_test
# //test/primihub/algorithm:maxpool_test \
# //test/primihub/algorithm:falcon_lenet_test \
# //test/primihub/common/type:common_test \
# //test/primihub/util:network_test
./bazel-bin/test/primihub/algorithm/logistic_test
./bazel-bin/test/primihub/algorithm/maxpool_test
./bazel-bin/test/primihub/algorithm/falcon_lenet_test
./bazel-bin/test/primihub/common/type/common_test
./bazel-bin/test/primihub/util/network_test
#./bazel-bin/test/primihub/algorithm/maxpool_test
#./bazel-bin/test/primihub/algorithm/falcon_lenet_test
#./bazel-bin/test/primihub/common/type/common_test
#./bazel-bin/test/primihub/util/network_test
# bazel test --test_output=all --config=linux test_opt_paillier_c2py
# bazel test --test_output=all --config=linux test_opt_paillier_pack_c2py
......@@ -45,26 +45,24 @@ jobs:
- name: bazel build
run: |
# cc_binary
bazel build --config=linux_`arch` //:node \
//:py_main \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//src/primihub/pybind_warpper:linkcontext
make
build-on-ubuntu-arm64:
runs-on: [self-hosted, Linux, ARM64]
steps:
- uses: actions/checkout@v3
- name: bazel test
run: |
bash pre_build.sh
bazel build --config=linux_`arch` \
--cxxopt=-DMPC_SOCKET_CHANNEL \
//test/primihub/algorithm:logistic_test
./bazel-bin/test/primihub/algorithm/logistic_test
- name: bazel build
run: |
# cc_binary
bash pre_build.sh
bazel build --config=linux_`arch` //:node \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:linkcontext \
//src/primihub/pybind_warpper:opt_paillier_c2py
make
build-on-mac-amd64:
......@@ -76,14 +74,16 @@ jobs:
- name: set PYTHON LINK_OPTS
run: ./pre_build.sh
shell: bash
- name: Install make
run: brew install make
- name: bazel test
run: |
mv -f WORKSPACE_GITHUB WORKSPACE
bazel build --config=darwin_x86_64 \
--cxxopt=-DMPC_SOCKET_CHANNEL \
//test/primihub/algorithm:logistic_test \
//test/primihub/common/type:common_test \
//test/primihub/util:network_test
//test/primihub/algorithm:logistic_test
# //test/primihub/common/type:common_test \
# //test/primihub/util:network_test
./bazel-bin/test/primihub/algorithm/logistic_test
#./bazel-bin/test/primihub/common/type/common_test
......@@ -95,9 +95,5 @@ jobs:
- name: bazel build
run: |
# cc_binary
bazel build --config=darwin_x86_64 //:node \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//src/primihub/pybind_warpper:linkcontext
make
......@@ -17,13 +17,8 @@ jobs:
run: |
# cc_binary
bash pre_build.sh
bazel build --config=linux_`arch` --define enable_mysql_driver=true //:node \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//src/primihub/protos:worker_py_pb2_grpc \
//src/primihub/protos:service_py_pb2_grpc \
//src/primihub/pybind_warpper:linkcontext
make mysql=y protos=y
#copy generated py pb to python dir
cp -f bazel-bin/src/primihub/pybind_warpper/linkcontext.so python
cp -f bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so python
......@@ -32,6 +27,8 @@ jobs:
tar zcf primihub-linux-amd64.tar.gz bazel-bin/cli \
bazel-bin/node \
primihub-cli \
primihub-node \
bazel-bin/py_main \
bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so \
bazel-bin/src/primihub/pybind_warpper/linkcontext.so \
......@@ -64,14 +61,8 @@ jobs:
run: |
# cc_binary
bash pre_build.sh
bazel build --config=linux_`arch` //:node \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//src/primihub/pybind_warpper:linkcontext \
//src/primihub/protos:worker_py_pb2_grpc \
//src/primihub/protos:service_py_pb2_grpc
make protos=y
cp -f bazel-bin/src/primihub/pybind_warpper/linkcontext.so python
cp -f bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so python
cp -f bazel-bin/src/primihub/protos/*.py \
......@@ -79,6 +70,8 @@ jobs:
tar zcf primihub-linux-arm64.tar.gz bazel-bin/cli \
bazel-bin/node \
primihub-cli \
primihub-node \
bazel-bin/py_main \
bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so \
bazel-bin/src/primihub/pybind_warpper/linkcontext.so \
......@@ -109,6 +102,8 @@ jobs:
- uses: actions/checkout@v3
- name: Setup bazelisk
uses: bazelbuild/setup-bazelisk@v2
- name: Install make
run: brew install make
- name: set PYTHON LINK_OPTS
run: ./pre_build.sh
shell: bash
......@@ -116,14 +111,7 @@ jobs:
- name: bazel build
run: |
# cc_binary
bazel build --config=darwin_x86_64 //:node \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//src/primihub/pybind_warpper:linkcontext
# fix in future
#//src/primihub/protos:worker_py_pb2_grpc \
#//src/primihub/protos:service_py_pb2_grpc
make
cp -f bazel-bin/src/primihub/pybind_warpper/linkcontext.so python
cp -f bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so python
......@@ -132,6 +120,8 @@ jobs:
tar zcf primihub-darwin-amd64.tar.gz bazel-bin/cli \
bazel-bin/node \
primihub-cli \
primihub-node \
bazel-bin/py_main \
bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so \
bazel-bin/src/primihub/pybind_warpper/linkcontext.so \
......@@ -158,13 +148,7 @@ jobs:
run: |
# cc_binary
bash pre_build.sh
bazel build --config=darwin_arm64 //:node \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//src/primihub/pybind_warpper:linkcontext \
//src/primihub/protos:worker_py_pb2_grpc \
//src/primihub/protos:service_py_pb2_grpc
make protos=y
cp -f bazel-bin/src/primihub/pybind_warpper/linkcontext.so python
cp -f bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so python
......@@ -173,6 +157,8 @@ jobs:
tar zcf primihub-darwin-arm64.tar.gz bazel-bin/cli \
bazel-bin/node \
primihub-cli \
primihub-node \
bazel-bin/py_main \
bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so \
bazel-bin/src/primihub/pybind_warpper/linkcontext.so \
......
......@@ -33,13 +33,11 @@ ADD . /src
# Bazel build primihub-node & primihub-cli & paillier shared library
RUN bash pre_build.sh \
&& mv -f WORKSPACE_GITHUB WORKSPACE \
&& bazel build --config=linux_`arch` //:node \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//src/primihub/pybind_warpper:linkcontext \
&& make mysql=y \
&& tar zcf bazel-bin.tar.gz bazel-bin/cli \
bazel-bin/node \
primihub-cli \
primihub-node \
bazel-bin/py_main \
bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so \
bazel-bin/src/primihub/pybind_warpper/linkcontext.so \
......@@ -64,9 +62,7 @@ WORKDIR /app
# Copy opt_paillier_c2py.so and linkcontext.so to /app/python so that setup.py can find them.
RUN tar zxf /opt/bazel-bin.tar.gz \
&& mkdir log \
&& ln -s bazel-bin/node primihub-node \
&& ln -s bazel-bin/cli primihub-cli
&& mkdir log
WORKDIR /app/python
......
# Call this Dockerfile via the build_local.sh script
FROM ubuntu:20.04
FROM primihub/primihub-base
ENV LANG c.UTF-8
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
RUN apt-get update \
&& apt-get install -y python3 python3-dev libgmp-dev python3-pip libzmq5 tzdata libmysqlclient-dev \
&& ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& rm -rf /var/lib/apt/lists/*
# Change WorkDir to /app
WORKDIR /app
ADD bazel-bin.tar.gz ./
COPY src/primihub/protos/ src/primihub/protos/
# Make symlink to primihub-node & primihub-cli
RUN mkdir log \
&& ln -s bazel-bin/node primihub-node \
&& ln -s bazel-bin/cli primihub-cli
WORKDIR /app/python
RUN python3 -m pip install --upgrade pip \
&& python3 -m pip install -r requirements.txt -i https://pypi.douban.com/simple/ \
&& python3 setup.py develop \
&& rm -rf /root/.cache/pip/
WORKDIR /app
# ENV PYTHONPATH=/usr/lib/python3.9/site-packages/:$TARGET_PATH
RUN mkdir log \
&& cd python \
&& python3 setup.py develop
# gRPC server port
EXPOSE 50050
......@@ -4,7 +4,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install python3 and GCC openmp (Depends with cryptFlow2 library)
RUN apt-get update \
&& apt-get install -y python3 python3-dev libgmp-dev python3-pip tzdata libmysqlclient-dev \
&& apt-get install -y python3 python3-dev libgmp-dev python3-pip tzdata wget libmysqlclient-dev \
&& ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& rm -rf /var/lib/apt/lists/*
......@@ -13,16 +13,17 @@ WORKDIR /app
COPY primihub-linux-amd64.tar.gz primihub-linux-arm64.tar.gz /opt/
COPY src/primihub/protos/ src/primihub/protos/
RUN ARCH=`arch | sed s/aarch64/arm64/ | sed s/x86_64/amd64/` \
&& tar zxf /opt/primihub-linux-${ARCH}.tar.gz \
&& mkdir log \
&& ln -s bazel-bin/node primihub-node \
&& ln -s bazel-bin/cli primihub-cli
RUN tar zxf /opt/primihub-linux-$(dpkg --print-architecture).tar.gz \
&& mkdir log
WORKDIR /app/python
RUN python3 -m pip install --upgrade pip \
&& python3 -m pip install -r requirements.txt \
&& if [ "$(dpkg --print-architecture)" = "arm64" ]; then \
wget https://primihub.oss-cn-beijing.aliyuncs.com/dev/tenseal-0.3.14-cp38-cp38-linux_aarch64.whl \
&& pip install tenseal-0.3.14-cp38-cp38-linux_aarch64.whl; \
fi \
&& python3 setup.py install \
&& rm -rf /root/.cache/pip/
......
BUILD_FLAG ?=
TARGET := //:node \
//:cli \
//src/primihub/cli:reg_cli \
//src/primihub/pybind_warpper:linkcontext \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//:py_main
//:cli \
//src/primihub/cli:reg_cli \
//src/primihub/pybind_warpper:linkcontext \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//:py_main
ifeq ($(mysql), y)
BUILD_FLAG += --define enable_mysql_driver=true
endif
ifeq ($(protos), y)
TARGET += //src/primihub/protos:worker_py_pb2_grpc //src/primihub/protos:service_py_pb2_grpc
endif
release:
bazel build --config=PLATFORM_HARDWARE ${TARGET}
bazel build --config=PLATFORM_HARDWARE $(BUILD_FLAG) ${TARGET}
rm -f primihub-cli
ln -s -f bazel-bin/cli primihub-cli
rm -f primihub-node
ln -s -f bazel-bin/node primihub-node
#linux_x86_64:
# bazel build --config=linux_x86_64 ${TARGET}
......
......@@ -11,6 +11,15 @@ filegroup(
)
"""
http_archive(
name = "platforms",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/platforms/releases/download/0.0.6/platforms-0.0.6.tar.gz",
"https://github.com/bazelbuild/platforms/releases/download/0.0.6/platforms-0.0.6.tar.gz",
],
sha256 = "5308fc1d8865406a49427ba24a9ab53087f17f5266a7aabbfc28823f3916e1ca",
)
http_archive(
name = "rules_foreign_cc",
sha256 = "484fc0e14856b9f7434072bc2662488b3fe84d7798a5b7c92a1feb1a0fa8d088",
......@@ -201,13 +210,13 @@ http_archive(
http_archive(
name = "bzip2",
build_file = "//bazel:bzip2.BUILD",
sha256 = "ab5a03176ee106d3f0fa90e381da478ddae405918153cca248e682cd0c4a2269",
strip_prefix = "bzip2-1.0.8",
urls = [
"https://primihub.oss-cn-beijing.aliyuncs.com/bzip2-1.0.8.tar.gz"
],
name = "bzip2",
build_file = "//bazel:bzip2.BUILD",
sha256 = "ab5a03176ee106d3f0fa90e381da478ddae405918153cca248e682cd0c4a2269",
strip_prefix = "bzip2-1.0.8",
urls = [
"https://primihub.oss-cn-beijing.aliyuncs.com/bzip2-1.0.8.tar.gz"
],
)
......@@ -428,27 +437,34 @@ http_archive(
)
# libPSI start
new_git_repository(
git_repository(
name = "osu_libpsi",
build_file = "//bazel:BUILD.libpsi",
commit = "4c5d5a3e49533c8547dcd4869e6a9842b6ce5b90",
#branch = "bazel_branch",
commit = "d56340b515989832983ac84132fef1a03184a9bc",
remote = "https://gitee.com/primihub/libPSI.git",
)
# libote
new_git_repository(
name = "osu_libote",
build_file = "//bazel:libOTe.BUILD",
commit = "f455eb7bf83034ebca6cab42e3aea9d9b33f8102",
remote = "https://gitee.com/primihub/libOTe.git",
)
new_git_repository(
git_repository(
name = "ladnir_cryptoTools",
build_file = "//bazel:cryptoTools.BUILD",
commit = "d52e05e2e803006256ddb66f48a0d51080f4b285",
# branch = "bazel_branch",
commit = "53026be7bf1f12cb572c3d8ef9c7ee1a21742360",
remote = "https://gitee.com/primihub/cryptoTools.git",
# shallow_since = "1591047380 -0700",
)
git_repository(
name = "com_github_ladnir_aby3",
# branch = "bazel_branch",
commit = "60a5e0f0d35d610792f3ad668fe2e684c9b4d76b",
remote = "https://gitee.com/primihub/aby3.git",
)
# libote
git_repository(
name = "osu_libote",
# branch = "bazel_branch",
commit = "d797f316d94b2931505f1862515b6e161a24cacb",
remote = "https://gitee.com/primihub/libOTe.git",
)
new_git_repository(
......@@ -459,12 +475,13 @@ new_git_repository(
# shallow_since = "1591047380 -0700",
)
new_git_repository(
name = "toolkit_relic",
build_file = "//bazel:BUILD.relic",
remote = "https://gitee.com/orzmzp/relic.git",
commit = "3f616ad64c3e63039277b8c90915607b6a2c504c",
shallow_since = "1581106153 -0800",
remote = "https://gitee.com/primihub/relic.git",
# tag 0.6.0
commit = "d7dcb22846e32172bb94111823bd3358ec9a49aa",
)
http_archive(
......@@ -477,14 +494,6 @@ http_archive(
strip_prefix = "hiredis-392de5d7f97353485df1237872cb682842e8d83f"
)
# libote
# http_archive(
# name = "osu_libote",
# build_file = "//external:libOTe.BUILD",
# #sha256 = "6f021f24136eb177af38af3bf5d53b3592a1fe1e71d1c098318488a85b0afc3a",
# strip_prefix = "libOTe-master",
# urls = ["https://github.com/osu-crypto/libOTe/archive/refs/heads/master.zip"],
# )
# cryptoTools
# http_archive(
......@@ -596,12 +605,14 @@ git_repository(
remote = "https://gitee.com/primihub/cpp-base64.git",
)
# for libp2p
new_git_repository(
name = "com_github_microsoft_gsl_v2_0_0",
build_file = "//bazel:BUILD.gsl_v_2_0_0",
tag = "v2.0.0",
remote = "https://gitee.com/mirrors_microsoft/GSL.git",
http_archive(
name = "com_github_microsoft_gsl",
build_file = "//bazel:BUILD.gsl",
sha256 = "f0e32cb10654fea91ad56bde89170d78cfbf4363ee0b01d8f097de2ba49f6ce9",
strip_prefix = "GSL-4.0.0",
urls = [
"https://primihub.oss-cn-beijing.aliyuncs.com/tools/GSL-4.0.0.tar.gz",
],
)
#python include
......
......@@ -12,6 +12,14 @@ filegroup(
)
"""
http_archive(
name = "platforms",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/platforms/releases/download/0.0.6/platforms-0.0.6.tar.gz",
"https://github.com/bazelbuild/platforms/releases/download/0.0.6/platforms-0.0.6.tar.gz",
],
sha256 = "5308fc1d8865406a49427ba24a9ab53087f17f5266a7aabbfc28823f3916e1ca",
)
http_archive(
name = "rules_foreign_cc",
......@@ -441,30 +449,39 @@ http_archive(
urls = ["https://github.com/google/sparsehash/archive/master.zip"],
)
# libPSI start
new_git_repository(
name = "osu_libpsi",
build_file = "//bazel:BUILD.libpsi",
commit = "4c5d5a3e49533c8547dcd4869e6a9842b6ce5b90",
remote = "https://github.com/primihub/libPSI.git",
git_repository(
name = "osu_libpsi",
#branch = "bazel_branch",
commit = "d56340b515989832983ac84132fef1a03184a9bc",
remote = "https://github.com/primihub/libPSI.git",
)
# libote
new_git_repository(
name = "osu_libote",
build_file = "//bazel:libOTe.BUILD",
commit = "f455eb7bf83034ebca6cab42e3aea9d9b33f8102",
remote = "https://github.com/primihub/libOTe.git",
git_repository(
name = "ladnir_cryptoTools",
# branch = "bazel_branch",
commit = "53026be7bf1f12cb572c3d8ef9c7ee1a21742360",
remote = "https://github.com/primihub/cryptoTools.git",
)
new_git_repository(
name = "ladnir_cryptoTools",
build_file = "//bazel:cryptoTools.BUILD",
commit = "d52e05e2e803006256ddb66f48a0d51080f4b285",
remote = "https://github.com/primihub/cryptoTools.git",
# shallow_since = "1591047380 -0700",
git_repository(
name = "com_github_ladnir_aby3",
# branch = "bazel_branch",
commit = "60a5e0f0d35d610792f3ad668fe2e684c9b4d76b",
remote = "https://gitee.com/primihub/aby3.git",
)
# libote
git_repository(
name = "osu_libote",
# branch = "bazel_branch",
commit = "d797f316d94b2931505f1862515b6e161a24cacb",
remote = "https://github.com/primihub/libOTe.git",
)
new_git_repository(
name = "github_ntl",
build_file = "//bazel:ntl.BUILD",
......@@ -474,11 +491,11 @@ new_git_repository(
)
new_git_repository(
name = "toolkit_relic",
build_file = "//bazel:BUILD.relic",
remote = "https://github.com/relic-toolkit/relic.git",
commit = "3f616ad64c3e63039277b8c90915607b6a2c504c",
shallow_since = "1581106153 -0800",
name = "toolkit_relic",
build_file = "//bazel:BUILD.relic",
remote = "https://github.com/primihub/relic.git",
# tag 0.6.0
commit = "d7dcb22846e32172bb94111823bd3358ec9a49aa",
)
http_archive(
......@@ -489,24 +506,6 @@ http_archive(
strip_prefix = "hiredis-392de5d7f97353485df1237872cb682842e8d83f"
)
# libote
#http_archive(
# name = "osu_libote",
# build_file = "//external:libOTe.BUILD",
# #sha256 = "6f021f24136eb177af38af3bf5d53b3592a1fe1e71d1c098318488a85b0afc3a",
# strip_prefix = "libOTe-master",
# urls = ["https://github.com/osu-crypto/libOTe/archive/refs/heads/master.zip"],
#)
# cryptoTools
#http_archive(
# name = "ladnir_cryptoTools",
# build_file = "//external:cryptoTools.BUILD",
# #sha256 = "6f021f24136eb177af38af3bf5d53b3592a1fe1e71d1c098318488a85b0afc3a",
# strip_prefix = "cryptoTools-master",
# urls = ["https://github.com/ladnir/cryptoTools/archive/refs/heads/master.zip"],
#)
#PSI
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
......@@ -514,7 +513,7 @@ git_repository(
name = "org_openmined_psi",
remote = "https://github.com/primihub/PSI.git",
branch = "master",
init_submodules = True,
init_submodules = False,
)
load("@org_openmined_psi//private_set_intersection:preload.bzl", "psi_preload")
......@@ -533,7 +532,7 @@ git_repository(
name = "org_openmined_pir",
remote = "https://github.com/primihub/PIR.git",
branch = "master",
init_submodules = True,
init_submodules = False,
)
load("@org_openmined_pir//pir:preload.bzl", "pir_preload")
......@@ -616,11 +615,15 @@ git_repository(
remote = "https://github.com/primihub/cpp-base64.git",
)
new_git_repository(
name = "com_github_microsoft_gsl_v2_0_0",
build_file = "//bazel:BUILD.gsl_v_2_0_0",
tag = "v2.0.0",
remote = "https://github.com/microsoft/GSL.git",
http_archive(
name = "com_github_microsoft_gsl",
build_file = "//bazel:BUILD.gsl",
sha256 = "f0e32cb10654fea91ad56bde89170d78cfbf4363ee0b01d8f097de2ba49f6ce9",
strip_prefix = "GSL-4.0.0",
urls = [
"https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.tar.gz",
"https://primihub.oss-cn-beijing.aliyuncs.com/tools/GSL-4.0.0.tar.gz",
],
)
#python include
......
# Description:
package(default_visibility = ["//visibility:public"])
cc_library(
name = "gsl_lib",
textual_hdrs = glob([
"include/gsl",
"include/gsl/*",
]),
includes = ["include"],
copts = [
"--std=c++17",
],
)
\ No newline at end of file
......@@ -6,21 +6,42 @@ filegroup(
visibility = ["//visibility:public"]
)
cmake(
name = "relic",
cache_entries = {
"CMAKE_INSTALL_LIBDIR": "lib",
},
cache_entries = select({
"@platforms//cpu:aarch64": {
"CMAKE_INSTALL_LIBDIR": "lib",
"WSIZE": "64",
},
"@platforms//cpu:x86_64": {
"CMAKE_INSTALL_LIBDIR": "lib",
"WSIZE": "64",
},
"@platforms//cpu:i386": {
"CMAKE_INSTALL_LIBDIR": "lib",
"WSIZE": "32",
},
"@platforms//cpu:arm": {
"CMAKE_INSTALL_LIBDIR": "lib",
"WSIZE": "32",
},
"//conditions:default": {
"CMAKE_INSTALL_LIBDIR": "lib",
"WSIZE": "32",
},
}),
build_args = [
"-j4",
],
linkopts = [
"-lpthread",
],
#includes = ["include"],
lib_source = ":src",
out_include_dir = "include/relic",
#out_include_dir = "include/relic",
out_static_libs = ["librelic_s.a"],
postfix_script = "[ \"$OSTYPE\" == \"linux-gnu\" ] && objcopy --localize-symbol=bn_init $INSTALLDIR/lib/librelic_s.a",
#postfix_script = "[ \"$OSTYPE\" == \"linux-gnu\" ] && objcopy --localize-symbol=bn_init $INSTALLDIR/lib/librelic_s.a",
#postfix_script = "objcopy --localize-symbol=bn_init $INSTALLDIR/lib/librelic_s.a",
visibility = ["//visibility:public"],
)
......@@ -15,13 +15,7 @@ fi
bash pre_build.sh
ARCH=`arch`
bazel build --config=linux_$ARCH --define enable_mysql_driver=true //:node \
//:py_main \
//:cli \
//src/primihub/pybind_warpper:opt_paillier_c2py \
//src/primihub/pybind_warpper:linkcontext
make mysql=y
if [ $? -ne 0 ]; then
echo "Build failed!!!"
......@@ -33,6 +27,8 @@ git rev-parse HEAD >> commit.txt
tar zcf bazel-bin.tar.gz bazel-bin/cli \
bazel-bin/node \
primihub-cli \
primihub-node \
bazel-bin/py_main \
bazel-bin/src/primihub/pybind_warpper/opt_paillier_c2py.so \
bazel-bin/src/primihub/pybind_warpper/linkcontext.so \
......
......@@ -4,9 +4,13 @@ if [ ! -d example ]; then
echo "no example found"
exit -1
fi
case_path="example/FL/logistic_regression"
case_path="example/FL/xgboost"
case_path="example"
case_list=$(ls example)
case_list=$(ls ${case_path})
for case_info in ${case_list[@]}; do
[ -d "example/${case_info}" ] && continue
#echo "./bazel-bin/cli --server=${SERVER_INFO} --task_config_file=example/${case_info}"
./bazel-bin/cli --server="${SERVER_INFO}" --task_config_file="example/${case_info}"
echo "./bazel-bin/cli --server=${SERVER_INFO} --task_config_file=${case_path}/${case_info}"
./bazel-bin/cli --server="${SERVER_INFO}" --task_config_file="${case_path}/${case_info}"
done
{
"party_info": {
"task_manager": "127.0.0.1:50050"
},
"component_params": {
"roles": {
"host": "Bob",
"guest": [
"Charlie"
],
"coordinator": "David"
},
"common_params": {
"model": "VFL_logistic_regression",
"method": "CKKS",
"process": "train",
"task_name": "VFL_logistic_regression_binclass_ckks_train",
"learning_rate": 1e-1,
"alpha": 1e-4,
"epoch": 10,
"shuffle_seed": 0,
"batch_size": 100,
"print_metrics": true
},
"role_params": {
"Bob": {
"data_set": "binclass_vfl_train_host",
"selected_column": null,
"id": "id",
"label": "y",
"model_path": "data/result/host_model.pkl",
"metric_path": "data/result/metrics.json"
},
"Charlie": {
"data_set": "binclass_vfl_train_guest",
"selected_column": null,
"id": "id",
"model_path": "data/result/guest_model.pkl"
},
"David": {
"data_set": "fl_fake_data"
}
}
}
}
\ No newline at end of file
{
"party_info": {
"task_manager": "127.0.0.1:50050"
},
"component_params": {
"roles": {
"host": "Bob",
"guest": [
"Charlie"
],
"coordinator": "David"
},
"common_params": {
"model": "VFL_logistic_regression",
"method": "CKKS",
"process": "train",
"task_name": "VFL_logistic_regression_multiclass_ckks_train",
"learning_rate": 1e-1,
"alpha": 1e-4,
"epoch": 2,
"shuffle_seed": 0,
"batch_size": 100,
"print_metrics": true
},
"role_params": {
"Bob": {
"data_set": "multiclass_vfl_train_host",
"selected_column": null,
"id": "id",
"label": "y",
"model_path": "data/result/host_model.pkl",
"metric_path": "data/result/metrics.json"
},
"Charlie": {
"data_set": "multiclass_vfl_train_guest",
"selected_column": null,
"id": "id",
"model_path": "data/result/guest_model.pkl"
},
"David": {
"data_set": "fl_fake_data"
}
}
}
}
\ No newline at end of file
{
"task_type": "PIR_TASK",
"task_name": "keyword_pir_generate_db_task",
"task_lang": "proto",
"task_code": {
"code_file_path": "",
"code": ""
},
"params": {
"pirType": {
"description": "ID_PIR = 0 [Unimplement]; KEY_PIR = 1;",
"type": "INT32",
"value": 1
},
"DbInfo": {
"description": "create sender db offline",
"type": "STRING",
"value": "data/cache/keyword_pir_server_data"
}
},
"party_datasets": {
"SERVER": {
"SERVER": "keyword_pir_server_data"
}
}
}
......@@ -878,7 +878,7 @@ _quiet = False
# The allowed line length of files.
# This is set by --linelength flag.
_line_length = 100
_line_length = 80
# This allows to use different include order rule than default
_include_order = "default"
......
......@@ -61,12 +61,10 @@ class LogisticRegression:
error = self.predict_prob(x)
idx = np.arange(len(y))
error[idx, y] -= 1
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(axis=0, keepdims=True)
else:
error = self.predict_prob(x) - y
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(keepdims=True)
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(axis=0, keepdims=True)
return dw, db
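
The merged branch above works for both label shapes via numpy broadcasting: x.T.dot(error) yields a (d,) vector for binary labels and a (d, k) matrix for multiclass, while error.mean(axis=0, keepdims=True) produces the matching bias gradient in either case. A minimal standalone sketch (the sizes n=4, d=3, k=5 are illustrative assumptions, not from the diff):

import numpy as np

n, d, k, alpha = 4, 3, 5, 1e-4
x = np.random.rand(n, d)

# binary: weight (d,) and error (n,); multiclass: weight (d, k) and error (n, k)
for w, err in [(np.zeros(d), np.random.rand(n)),
               (np.zeros((d, k)), np.random.rand(n, k))]:
    dw = x.T.dot(err) / x.shape[0] + alpha * w   # (d,) or (d, k)
    db = err.mean(axis=0, keepdims=True)         # (1,) or (1, k)
    print(dw.shape, db.shape)
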
def gradient_descent(self, x, y):
......@@ -154,7 +152,7 @@ class LogisticRegression_DPSGD(LogisticRegression):
return dw, db
class PaillierFunc:
class Paillier:
def __init__(self, public_key, private_key):
self.public_key = public_key
......@@ -179,12 +177,9 @@ class PaillierFunc:
return [[self.private_key.encrypt(i) for i in pv] for pv in plain_matrix]
class LogisticRegression_Paillier(LogisticRegression, PaillierFunc):
class LogisticRegression_Paillier(LogisticRegression):
def __init__(self, x, y, learning_rate=0.2, alpha=0.0001, *args):
super().__init__(x, y, learning_rate, alpha, *args)
def compute_grad(self, x, y):
def gradient_descent(self, x, y):
if self.multiclass:
error_msg = "Paillier method doesn't support multiclass classification"
logger.error(error_msg)
......@@ -192,16 +187,17 @@ class LogisticRegression_Paillier(LogisticRegression, PaillierFunc):
else:
# Approximate gradient
# First order of taylor expansion: sigmoid(x) = 0.5 + 0.25 * (x.dot(w) + b)
z = 0.5 + 0.25 * (x.dot(self.weight) + self.bias) - y
dw = x.T.dot(z) / x.shape[0] + self.alpha * self.weight
db = z.mean(keepdims=True)
return dw, db
error = 2 + x.dot(self.weight) + self.bias - 4 * y
factor = -self.learning_rate / x.shape[0]
self.weight += (factor * x).T.dot(error) + \
(-self.learning_rate * self.alpha) * self.weight
self.bias += factor * error.sum(keepdims=True)
def BCELoss(self, x, y):
# Approximate loss: L(x) = (0.5 - y) * (x.dot(w) + b)
# Ignore the regularization term since Paillier doesn't support ciphertext multiplication
return (0.5 - y).dot(x.dot(self.weight) + self.bias) / x.shape[0]
return ((0.5 - y) / x.shape[0]).dot(x.dot(self.weight) + self.bias)
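
The two rewrites above follow from the first-order Taylor expansion of the sigmoid around zero; the algebra below is one reading of the diff, not spelled out in it. With logits t = xw + b,

    \sigma(t) \approx \tfrac{1}{2} + \tfrac{t}{4}
    \qquad \Rightarrow \qquad
    r = \sigma(t) - y \approx \tfrac{1}{2} + \tfrac{t}{4} - y .

Scaling the residual by 4 gives the integer-coefficient form error = 2 + xw + b - 4y used in gradient_descent, so the encrypted error is assembled from additions alone, and the only ciphertext-by-plaintext product per step is the dot with the precomputed plaintext factor * x (the constant 4 is effectively absorbed into the learning rate). The loss uses the same expansion, dropping the terms of binary cross-entropy that would require ciphertext multiplication:

    \ell(t, y) = \log(1 + e^{t}) - y\,t \approx \log 2 + \left(\tfrac{1}{2} - y\right) t ,

whose batch mean, up to the additive constant, is the ((0.5 - y) / x.shape[0]).dot(x.dot(self.weight) + self.bias) returned by BCELoss.
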
def CELoss(self, x, y, eps=1e-20):
error_msg = "Paillier method doesn't support multiclass classification"
......
......@@ -15,7 +15,8 @@ from primihub.FL.metrics.hfl_metrics import ks_from_fpr_tpr,\
auc_from_fpr_tpr
from .base import LogisticRegression,\
LogisticRegression_DPSGD,\
LogisticRegression_Paillier
LogisticRegression_Paillier,\
Paillier
class LogisticRegressionClient(BaseModel):
......@@ -24,10 +25,16 @@ class LogisticRegressionClient(BaseModel):
super().__init__(**kwargs)
def run(self):
if self.common_params['process'] == 'train':
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
elif self.common_params['process'] == 'predict':
elif process == 'predict':
self.predict()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
......@@ -196,7 +203,7 @@ class LogisticRegressionClient(BaseModel):
class Plaintext_Client:
def __init__(self, x, y, learning_rate, alpha, server_channel, *args):
def __init__(self, x, y, learning_rate, alpha, server_channel):
self.model = LogisticRegression(x, y, learning_rate, alpha)
self.param_init(x, y, server_channel)
......@@ -334,15 +341,15 @@ class DPSGD_Client(Plaintext_Client):
return accountant.get_epsilon(target_delta=delta)
class Paillier_Client(Plaintext_Client):
class Paillier_Client(Plaintext_Client, Paillier):
def __init__(self, x, y, learning_rate, alpha,
server_channel):
self.model = LogisticRegression_Paillier(x, y, learning_rate, alpha)
self.param_init(x, y, server_channel)
self.model.public_key = server_channel.recv("public_key")
self.model.set_theta(self.model.encrypt_vector(self.model.get_theta()))
self.public_key = server_channel.recv("public_key")
self.model.set_theta(self.encrypt_vector(self.model.get_theta()))
def send_loss(self, x, y):
# pallier only support compute approximate loss without penalty
......@@ -353,4 +360,4 @@ class Paillier_Client(Plaintext_Client):
def print_metrics(self, x, y):
# print loss
self.send_loss(x, y)
logger.info('no printed metrics during training when using paillier')
logger.info('View metrics at server while using Paillier')
......@@ -9,7 +9,7 @@ from phe import paillier
from primihub.FL.metrics.hfl_metrics import roc_vertical_avg,\
ks_from_fpr_tpr,\
auc_from_fpr_tpr
from .base import PaillierFunc
from .base import Paillier
class LogisticRegressionServer(BaseModel):
......@@ -18,8 +18,14 @@ class LogisticRegressionServer(BaseModel):
super().__init__(**kwargs)
def run(self):
if self.common_params['process'] == 'train':
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
......@@ -213,8 +219,7 @@ class Plaintext_DPSGD_Server:
logger.info(f"loss={loss}, acc={acc}")
class Paillier_Server(Plaintext_DPSGD_Server,
PaillierFunc):
class Paillier_Server(Plaintext_DPSGD_Server, Paillier):
def __init__(self, alpha, n_length, client_channel):
Plaintext_DPSGD_Server.__init__(self, alpha, client_channel)
......
import tenseal as ts
import numpy as np
from primihub.utils.logger_util import logger
from .base import LogisticRegression
......@@ -6,8 +7,11 @@ from .base import LogisticRegression
class LogisticRegression_Host_Plaintext(LogisticRegression):
def __init__(self, x, y, learning_rate=0.2, alpha=0.0001):
super().__init__(x, y, learning_rate, alpha)
if self.multiclass:
self.output_dim = self.weight.shape[1]
else:
self.output_dim = 1
def compute_z(self, x, guest_z):
z = x.dot(self.weight) + self.bias
......@@ -30,8 +34,8 @@ class LogisticRegression_Host_Plaintext(LogisticRegression):
return error
def compute_regular_loss(self, guest_regular_loss):
return 0.5 * self.alpha * (self.weight ** 2).sum() \
+ sum(guest_regular_loss)
return (0.5 * self.alpha) * (self.weight ** 2).sum() \
+ guest_regular_loss
def BCELoss(self, y, z, regular_loss):
return (np.maximum(z, 0.).sum() - y.dot(z) +
......@@ -50,12 +54,8 @@ class LogisticRegression_Host_Plaintext(LogisticRegression):
return self.BCELoss(y, z, regular_loss)
def compute_grad(self, x, error):
if self.multiclass:
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(axis=0, keepdims=True)
else:
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.sum(keepdims=True)
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
db = error.mean(axis=0, keepdims=True)
return dw, db
def gradient_descent(self, x, error):
......@@ -67,6 +67,60 @@ class LogisticRegression_Host_Plaintext(LogisticRegression):
self.gradient_descent(x, error)
class LogisticRegression_Host_CKKS(LogisticRegression_Host_Plaintext):
def compute_enc_z(self, x, guest_z):
z = self.weight.mm(x.T) + self.bias
z += sum(guest_z)
return z
def compute_error(self, y, z):
if self.multiclass:
error = z + 1 - self.output_dim * np.eye(self.output_dim)[y].T
else:
error = 2. + z - 4 * y
return error
def compute_regular_loss(self, guest_regular_loss):
if self.multiclass and isinstance(self.weight, ts.CKKSTensor):
return (0.5 * self.alpha) * (self.weight ** 2).sum().sum() \
+ guest_regular_loss
else:
return super().compute_regular_loss(guest_regular_loss)
def BCELoss(self, y, z, regular_loss):
return z.dot((0.5 - y) / y.shape[0]) + regular_loss
def CELoss(self, y, z, regular_loss):
factor = 1. / (y.shape[0] * self.output_dim)
if isinstance(z, ts.CKKSTensor):
# Todo: fix encrypted1 and encrypted2 parameter mismatch
return (z * factor \
- z * ((np.eye(self.output_dim)[y].T
+ np.random.normal(0, 1e-4, (self.output_dim, y.shape[0]))) \
* factor)).sum().sum() \
+ regular_loss
else:
return np.sum(np.sum(z, axis=1) - z[np.arange(len(y)), y]) \
* factor + regular_loss
def loss(self, y, z, regular_loss):
if self.multiclass:
return self.CELoss(y, z, regular_loss)
else:
return self.BCELoss(y, z, regular_loss)
def gradient_descent(self, x, error):
if self.multiclass:
factor = -self.learning_rate / (self.output_dim * x.shape[0])
self.bias += error.sum(axis=1).reshape((self.output_dim, 1)) * factor
else:
factor = -self.learning_rate / x.shape[0]
self.bias += error.sum() * factor
self.weight += error.mm(factor * x) \
+ (-self.learning_rate * self.alpha) * self.weight
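
The in-place update above keeps every operation ciphertext-friendly. As a reading of the code (not stated in the diff), with encrypted weight W of shape (k, d), plaintext batch X of shape (n, d) and encrypted error E of shape (k, n), the ridge gradient step is

    W \leftarrow W - \eta \left( \frac{1}{n} E X + \alpha W \right)
      = W + E \left( \frac{-\eta}{n} X \right) + (-\eta \alpha) W ,

so both products multiply a ciphertext by a plaintext (the rescaled batch and the scalar -\eta\alpha). Together with the logits z = W X^T this bounds the multiplicative depth consumed per iteration, which is what the multiply_per_iter = 2 budget in the CKKS host, guest and coordinator classes below accounts for.
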
class LogisticRegression_Guest_Plaintext:
def __init__(self, x, learning_rate=0.2, alpha=0.0001, output_dim=1):
......@@ -85,7 +139,7 @@ class LogisticRegression_Guest_Plaintext:
return x.dot(self.weight)
def compute_regular_loss(self):
return 0.5 * self.alpha * (self.weight ** 2).sum()
return (0.5 * self.alpha) * (self.weight ** 2).sum()
def compute_grad(self, x, error):
dw = x.T.dot(error) / x.shape[0] + self.alpha * self.weight
......@@ -96,4 +150,28 @@ class LogisticRegression_Guest_Plaintext:
self.weight -= self.learning_rate * dw
def fit(self, x, error):
self.gradient_descent(x, error)
\ No newline at end of file
self.gradient_descent(x, error)
class LogisticRegression_Guest_CKKS(LogisticRegression_Guest_Plaintext):
def __init__(self, x, learning_rate=0.2, alpha=0.0001, output_dim=1):
super().__init__(x, learning_rate, alpha, output_dim)
self.output_dim = output_dim
def compute_enc_z(self, x):
return self.weight.mm(x.T)
def compute_regular_loss(self):
if self.multiclass and isinstance(self.weight, ts.CKKSTensor):
return (0.5 * self.alpha) * (self.weight ** 2).sum().sum()
else:
return super().compute_regular_loss()
def gradient_descent(self, x, error):
if self.multiclass:
factor = -self.learning_rate / (self.output_dim * x.shape[0])
else:
factor = -self.learning_rate / x.shape[0]
self.weight += error.mm(factor * x) + \
(-self.learning_rate * self.alpha) * self.weight
from primihub.FL.utils.net_work import GrpcClient, MultiGrpcClients
from primihub.FL.utils.base import BaseModel
from primihub.utils.logger_util import logger
import math
import numpy as np
import tenseal as ts
class LogisticRegressionCoordinator(BaseModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def run(self):
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
host_channel = GrpcClient(local_party=self.role_params['self_name'],
remote_party=self.roles['host'],
node_info=self.node_info,
task_info=self.task_info)
guest_channel = MultiGrpcClients(local_party=self.role_params['self_name'],
remote_parties=self.roles['guest'],
node_info=self.node_info,
task_info=self.task_info)
# coordinator init
method = self.common_params['method']
if method == 'CKKS':
batch_size = host_channel.recv('batch_size')
coordinator = CKKSCoordinator(batch_size,
host_channel,
guest_channel)
else:
error_msg = f"Unsupported method: {method}"
logger.error(error_msg)
raise RuntimeError(error_msg)
# coordinator training
logger.info("-------- start training --------")
epoch = self.common_params['epoch']
for i in range(epoch):
logger.info(f"-------- epoch {i+1} / {epoch} --------")
coordinator.train()
# print metrics
if self.common_params['print_metrics']:
coordinator.compute_loss()
logger.info("-------- finish training --------")
# decrypt & send plaintext model
coordinator.update_plaintext_model()
class CKKS:
def __init__(self, context):
if isinstance(context, bytes):
context = ts.context_from(context)
self.context = context
self.multiply_depth = context.data.seal_context().first_context_data().chain_index()
def encrypt_vector(self, vector, context=None):
if context:
return ts.ckks_vector(context, vector)
else:
return ts.ckks_vector(self.context, vector)
def encrypt_tensor(self, tensor, context=None):
if context:
return ts.ckks_tensor(context, tensor)
else:
return ts.ckks_tensor(self.context, tensor)
def decrypt(self, ciphertext, secret_key=None):
if ciphertext.context().has_secret_key():
return ciphertext.decrypt()
else:
return ciphertext.decrypt(secret_key)
def load_vector(self, vector):
return ts.ckks_vector_from(self.context, vector)
def load_tensor(self, tensor):
return ts.ckks_tensor_from(self.context, tensor)
class CKKSCoordinator(CKKS):
def __init__(self, batch_size, host_channel, guest_channel):
self.t = 0
self.host_channel = host_channel
self.guest_channel = guest_channel
self.multiclass = host_channel.recv('multiclass')
# set CKKS params
# use larger poly_mod_degree to support more encrypted multiplications
poly_mod_degree = 8192
# the minimum number of multiplications per iteration of gradient descent
# more multiplications lead to larger context size
multiply_per_iter = 2
self.max_iter = 1
multiply_depth = multiply_per_iter * self.max_iter
# sum(coeff_mod_bit_sizes) <= max coeff_modulus bit-length
fe_bits_scale = 60
bits_scale = 49
# 60*2 + 49*1*2 = 218, the max coeff_modulus bit-length for N = 8192 at 128-bit security
coeff_mod_bit_sizes = [fe_bits_scale] + \
[bits_scale] * multiply_depth + \
[fe_bits_scale]
# create TenSEALContext
logger.info('create CKKS TenSEAL context')
secret_context = ts.context(ts.SCHEME_TYPE.CKKS,
poly_modulus_degree=poly_mod_degree,
coeff_mod_bit_sizes=coeff_mod_bit_sizes)
secret_context.global_scale = pow(2, bits_scale)
secret_context.generate_galois_keys()
context = secret_context.copy()
context.make_context_public()
super().__init__(context)
self.secret_context = secret_context
self.send_public_context()
num_examples = host_channel.recv('num_examples')
self.iter_per_epoch = math.ceil(num_examples / batch_size)
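
The context construction above can be exercised on its own. The sketch below mirrors the coordinator's parameters (same poly_modulus_degree, modulus chain and scale) and assumes only the tenseal package; the expected depth is my inference from the comments above:

import tenseal as ts

# Same parameters as CKKSCoordinator: N = 8192, chain [60, 49, 49, 60], scale 2^49
secret_context = ts.context(ts.SCHEME_TYPE.CKKS,
                            poly_modulus_degree=8192,
                            coeff_mod_bit_sizes=[60, 49, 49, 60])
secret_context.global_scale = 2 ** 49
secret_context.generate_galois_keys()

# The copy shared with host and guests carries no secret key
public_context = secret_context.copy()
public_context.make_context_public()

# chain_index() of the first context data counts the remaining multiplicative
# levels (one 49-bit prime per multiplication); with multiply_per_iter = 2 this
# yields max_iter = 1 gradient step per encryption
depth = public_context.data.seal_context().first_context_data().chain_index()
print(depth)  # expected: 2
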
def send_public_context(self):
serialize_context = self.context.serialize()
self.host_channel.send("public_context", serialize_context)
self.guest_channel.send_all("public_context", serialize_context)
def recv_model(self):
if self.multiclass:
host_weight = self.load_tensor(self.host_channel.recv('host_weight'))
host_bias = self.load_tensor(self.host_channel.recv('host_bias'))
guest_weight = self.guest_channel.recv_all('guest_weight')
guest_weight = [self.load_tensor(weight) for weight in guest_weight]
else:
host_weight = self.load_vector(self.host_channel.recv('host_weight'))
host_bias = self.load_vector(self.host_channel.recv('host_bias'))
guest_weight = self.guest_channel.recv_all('guest_weight')
guest_weight = [self.load_vector(weight) for weight in guest_weight]
return host_weight, host_bias, guest_weight
def send_model(self, host_weight, host_bias, guest_weight):
self.host_channel.send('host_weight', host_weight)
self.host_channel.send('host_bias', host_bias)
# send n sub-lists to n parties separately
self.guest_channel.send_seperately('guest_weight', guest_weight)
def decrypt_model(self, host_weight, host_bias, guest_weight):
host_weight = self.decrypt(host_weight, self.secret_context.secret_key())
host_bias = self.decrypt(host_bias, self.secret_context.secret_key())
guest_weight = [self.decrypt(weight, self.secret_context.secret_key()) \
for weight in guest_weight]
return host_weight, host_bias, guest_weight
def encrypt_model(self, host_weight, host_bias, guest_weight):
if self.multiclass:
host_weight = self.encrypt_tensor(host_weight)
host_bias = self.encrypt_tensor(host_bias)
guest_weight = [self.encrypt_tensor(weight) for weight in guest_weight]
else:
host_weight = self.encrypt_vector(host_weight)
host_bias = self.encrypt_vector(host_bias)
guest_weight = [self.encrypt_vector(weight) for weight in guest_weight]
return host_weight, host_bias, guest_weight
def update_ciphertext_model(self):
host_weight, host_bias, guest_weight = self.recv_model()
host_weight, host_bias, guest_weight = self.decrypt_model(
host_weight, host_bias, guest_weight)
host_weight, host_bias, guest_weight = self.encrypt_model(
host_weight, host_bias, guest_weight)
host_weight = host_weight.serialize()
host_bias = host_bias.serialize()
guest_weight = [weight.serialize() for weight in guest_weight]
self.send_model(host_weight, host_bias, guest_weight)
def update_plaintext_model(self):
host_weight, host_bias, guest_weight = self.recv_model()
host_weight, host_bias, guest_weight = self.decrypt_model(
host_weight, host_bias, guest_weight)
# list to numpy ndarray
if self.multiclass:
host_weight = np.array(host_weight.tolist()).T
host_bias = np.array(host_bias.tolist()).T
guest_weight = [np.array(weight.tolist()).T for weight in guest_weight]
else:
host_weight = np.array(host_weight)
host_bias = np.array(host_bias)
guest_weight = [np.array(weight) for weight in guest_weight]
self.send_model(host_weight, host_bias, guest_weight)
def train(self):
logger.info(f'iteration {self.t} / {self.max_iter}')
self.t += self.iter_per_epoch
num_dec = self.t // self.max_iter
self.t = self.t % self.max_iter
if self.t == 0:
num_dec -= 1
self.t = self.max_iter
for i in range(num_dec):
logger.warning(f'decrypt model #{i+1}')
self.update_ciphertext_model()
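
The counter arithmetic in train() above decides how many decrypt-and-re-encrypt round trips an epoch needs. A standalone trace of the same bookkeeping (iter_per_epoch = 3 is an illustrative assumption; max_iter = 1 does match the parameters above):

def refresh_schedule(t, iter_per_epoch, max_iter):
    # t counts gradient steps taken since the last re-encryption
    t += iter_per_epoch
    num_dec = t // max_iter
    t = t % max_iter
    if t == 0:
        # epoch ends exactly at the depth limit: defer that refresh until the
        # model is next used (compute_loss, the next epoch, or the final
        # plaintext decryption after training)
        num_dec -= 1
        t = max_iter
    return t, num_dec

t = 0
for epoch in range(3):
    t, num_dec = refresh_schedule(t, iter_per_epoch=3, max_iter=1)
    print(epoch, num_dec, t)  # (0, 2, 1), (1, 3, 1), (2, 3, 1)
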
def compute_loss(self):
logger.info(f'iteration {self.t} / {self.max_iter}')
if self.t >= self.max_iter:
self.t = 0
logger.warning('decrypt model')
self.update_ciphertext_model()
if self.multiclass:
loss = self.load_tensor(self.host_channel.recv('loss'))
loss = self.decrypt(loss, self.secret_context.secret_key()).tolist()
else:
loss = self.load_vector(self.host_channel.recv('loss'))
loss = self.decrypt(loss, self.secret_context.secret_key())[0]
logger.info(f'loss={loss}')
\ No newline at end of file
......@@ -5,9 +5,12 @@ from primihub.FL.utils.dataset import read_data, DataLoader
from primihub.utils.logger_util import logger
import pickle
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from .vfl_base import LogisticRegression_Guest_Plaintext,\
LogisticRegression_Guest_CKKS
from .vfl_coordinator import CKKS
from .vfl_base import LogisticRegression_Guest_Plaintext
class LogisticRegressionGuest(BaseModel):
......@@ -15,19 +18,31 @@ class LogisticRegressionGuest(BaseModel):
super().__init__(**kwargs)
def run(self):
if self.common_params['process'] == 'train':
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
elif self.common_params['process'] == 'predict':
elif process == 'predict':
self.predict()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
remote_party = self.roles[self.role_params['others_role']]
host_channel = GrpcClient(local_party=self.role_params['self_name'],
remote_party=remote_party,
remote_party=self.roles['host'],
node_info=self.node_info,
task_info=self.task_info)
method = self.common_params['method']
if method == 'CKKS':
coordinator_channel = GrpcClient(local_party=self.role_params['self_name'],
remote_party=self.roles['coordinator'],
node_info=self.node_info,
task_info=self.task_info)
# load dataset
selected_column = self.role_params['selected_column']
id = self.role_params['id']
......@@ -37,24 +52,29 @@ class LogisticRegressionGuest(BaseModel):
x = x.values
# guest init
method = self.common_params['method']
batch_size = min(x.shape[0], self.common_params['batch_size'])
if method == 'Plaintext':
guest = Plaintext_Guest(x,
self.common_params['learning_rate'],
self.common_params['alpha'],
host_channel)
elif method == 'CKKS':
guest = CKKS_Guest(x,
self.common_params['learning_rate'],
self.common_params['alpha'],
host_channel,
coordinator_channel)
else:
error_msg = f"Unsupported method: {method}"
logger.error(error_msg)
raise RuntimeError(error_msg)
# data preprocessing
# minmaxscaler
scaler = MinMaxScaler()
# StandardScaler
scaler = StandardScaler()
x = scaler.fit_transform(x)
# guest training
batch_size = min(x.shape[0], self.common_params['batch_size'])
train_dataloader = DataLoader(dataset=x,
label=None,
batch_size=batch_size,
......@@ -73,8 +93,12 @@ class LogisticRegressionGuest(BaseModel):
guest.compute_metrics(x)
logger.info("-------- finish training --------")
# receive plaintext model
if method == 'CKKS':
guest.update_plaintext_model()
# compute final metrics
guest.compute_metrics(x)
guest.compute_final_metrics(x)
# save model for prediction
modelFile = {
......@@ -156,4 +180,91 @@ class Plaintext_Guest:
def compute_metrics(self, x):
self.send_z(x)
self.send_regular_loss()
\ No newline at end of file
def compute_final_metrics(self, x):
self.compute_metrics(x)
class CKKS_Guest(Plaintext_Guest, CKKS):
def __init__(self, x, learning_rate, alpha,
host_channel, coordinator_channel):
self.t = 0
output_dim = self.recv_output_dim(host_channel)
self.model = LogisticRegression_Guest_CKKS(x,
learning_rate,
alpha,
output_dim)
self.recv_public_context(coordinator_channel)
CKKS.__init__(self, self.context)
multiply_per_iter = 2
self.max_iter = self.multiply_depth // multiply_per_iter
self.encrypt_model()
def recv_public_context(self, coordinator_channel):
self.coordinator_channel = coordinator_channel
self.context = coordinator_channel.recv('public_context')
def encrypt_model(self):
if self.model.multiclass:
self.model.weight = self.encrypt_tensor(self.model.weight.T)
else:
self.model.weight = self.encrypt_vector(self.model.weight)
def update_ciphertext_model(self):
self.coordinator_channel.send('guest_weight',
self.model.weight.serialize())
if self.model.multiclass:
self.model.weight = self.load_tensor(
self.coordinator_channel.recv('guest_weight'))
else:
self.model.weight = self.load_vector(
self.coordinator_channel.recv('guest_weight'))
def update_plaintext_model(self):
self.coordinator_channel.send('guest_weight',
self.model.weight.serialize())
self.model.weight = self.coordinator_channel.recv('guest_weight')
def send_enc_z(self, x):
guest_z = self.model.compute_enc_z(x)
self.host_channel.send('guest_z', guest_z.serialize())
def send_enc_regular_loss(self):
if self.model.alpha != 0.:
guest_regular_loss = self.model.compute_regular_loss()
self.host_channel.send('guest_regular_loss',
guest_regular_loss.serialize())
def train(self, x):
logger.info(f'iteration {self.t} / {self.max_iter}')
if self.t >= self.max_iter:
self.t = 0
logger.warning(f'decrypt model')
self.update_ciphertext_model()
self.t += 1
self.send_enc_z(x)
if self.model.multiclass:
error = self.load_tensor(self.host_channel.recv('error'))
else:
error = self.load_vector(self.host_channel.recv('error'))
self.model.fit(x, error)
def compute_metrics(self, x):
logger.info(f'iteration {self.t} / {self.max_iter}')
if self.t >= self.max_iter:
self.t = 0
logger.warning(f'decrypt model')
self.update_ciphertext_model()
self.send_enc_z(x)
self.send_enc_regular_loss()
logger.info('View metrics at coordinator while using CKKS')
def compute_final_metrics(self, x):
super().compute_metrics(x)
\ No newline at end of file
......@@ -11,9 +11,11 @@ import numpy as np
from sklearn import metrics
from primihub.FL.metrics.hfl_metrics import ks_from_fpr_tpr,\
auc_from_fpr_tpr
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from .vfl_base import LogisticRegression_Host_Plaintext
from .vfl_base import LogisticRegression_Host_Plaintext,\
LogisticRegression_Host_CKKS
from .vfl_coordinator import CKKS
class LogisticRegressionHost(BaseModel):
......@@ -22,19 +24,31 @@ class LogisticRegressionHost(BaseModel):
super().__init__(**kwargs)
def run(self):
if self.common_params['process'] == 'train':
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
elif self.common_params['process'] == 'predict':
elif process == 'predict':
self.predict()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
remote_parties = self.roles[self.role_params['others_role']]
guest_channel = MultiGrpcClients(local_party=self.role_params['self_name'],
remote_parties=remote_parties,
remote_parties=self.roles['guest'],
node_info=self.node_info,
task_info=self.task_info)
method = self.common_params['method']
if method == 'CKKS':
coordinator_channel = GrpcClient(local_party=self.role_params['self_name'],
remote_party=self.roles['coordinator'],
node_info=self.node_info,
task_info=self.task_info)
# load dataset
selected_column = self.role_params['selected_column']
id = self.role_params['id']
......@@ -46,24 +60,30 @@ class LogisticRegressionHost(BaseModel):
x = x.values
# host init
method = self.common_params['method']
batch_size = min(x.shape[0], self.common_params['batch_size'])
if method == 'Plaintext':
host = Plaintext_Host(x, y,
self.common_params['learning_rate'],
self.common_params['alpha'],
guest_channel)
elif method == 'CKKS':
coordinator_channel.send('batch_size', batch_size)
host = CKKS_Host(x, y,
self.common_params['learning_rate'],
self.common_params['alpha'],
guest_channel,
coordinator_channel)
else:
error_msg = f"Unsupported method: {method}"
logger.error(error_msg)
raise RuntimeError(error_msg)
# data preprocessing
# minmaxscaler
scaler = MinMaxScaler()
# StandardScaler
scaler = StandardScaler()
x = scaler.fit_transform(x)
# host training
batch_size = min(x.shape[0], self.common_params['batch_size'])
train_dataloader = DataLoader(dataset=x,
label=y,
batch_size=batch_size,
......@@ -82,8 +102,12 @@ class LogisticRegressionHost(BaseModel):
host.compute_metrics(x, y)
logger.info("-------- finish training --------")
# receive plaintext model
if method == 'CKKS':
host.update_plaintext_model()
# compute final metrics
trainMetrics = host.compute_metrics(x, y)
trainMetrics = host.compute_final_metrics(x, y)
metric_path = self.role_params['metric_path']
check_directory_exist(metric_path)
logger.info(f"metric path: {metric_path}")
......@@ -171,13 +195,7 @@ class Plaintext_Host:
def send_output_dim(self, guest_channel):
self.guest_channel = guest_channel
if self.model.multiclass:
output_dim = self.model.weight.shape[1]
else:
output_dim = 1
guest_channel.send_all('output_dim', output_dim)
guest_channel.send_all('output_dim', self.model.output_dim)
def compute_z(self, x):
guest_z = self.guest_channel.recv_all('guest_z')
......@@ -186,7 +204,7 @@ class Plaintext_Host:
def compute_regular_loss(self):
if self.model.alpha != 0:
guest_regular_loss = self.guest_channel.recv_all('guest_regular_loss')
return self.model.compute_regular_loss(guest_regular_loss)
return self.model.compute_regular_loss(sum(guest_regular_loss))
else:
return 0.
......@@ -236,4 +254,113 @@ class Plaintext_Host:
'train_ks': ks,
'train_auc': auc
}
\ No newline at end of file
def compute_final_metrics(self, x, y):
return self.compute_metrics(x, y)
class CKKS_Host(Plaintext_Host, CKKS):
def __init__(self, x, y, learning_rate, alpha,
guest_channel, coordinator_channel):
self.t = 0
self.model = LogisticRegression_Host_CKKS(x, y,
learning_rate,
alpha)
self.send_output_dim(guest_channel)
coordinator_channel.send('multiclass', self.model.multiclass)
self.recv_public_context(coordinator_channel)
coordinator_channel.send('num_examples', x.shape[0])
CKKS.__init__(self, self.context)
multiply_per_iter = 2
self.max_iter = self.multiply_depth // multiply_per_iter
self.encrypt_model()
def recv_public_context(self, coordinator_channel):
self.coordinator_channel = coordinator_channel
self.context = coordinator_channel.recv('public_context')
def encrypt_model(self):
if self.model.multiclass:
self.model.weight = self.encrypt_tensor(self.model.weight.T)
self.model.bias = self.encrypt_tensor(self.model.bias.T)
else:
self.model.weight = self.encrypt_vector(self.model.weight)
self.model.bias = self.encrypt_vector(self.model.bias)
def update_ciphertext_model(self):
self.coordinator_channel.send('host_weight',
self.model.weight.serialize())
self.coordinator_channel.send('host_bias',
self.model.bias.serialize())
if self.model.multiclass:
self.model.weight = self.load_tensor(
self.coordinator_channel.recv('host_weight'))
self.model.bias = self.load_tensor(
self.coordinator_channel.recv('host_bias'))
else:
self.model.weight = self.load_vector(
self.coordinator_channel.recv('host_weight'))
self.model.bias = self.load_vector(
self.coordinator_channel.recv('host_bias'))
def update_plaintext_model(self):
self.coordinator_channel.send('host_weight',
self.model.weight.serialize())
self.coordinator_channel.send('host_bias',
self.model.bias.serialize())
self.model.weight = self.coordinator_channel.recv('host_weight')
self.model.bias = self.coordinator_channel.recv('host_bias')
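update_ciphertext_model and update_plaintext_model above are the host's half of a refresh protocol: serialized ciphertexts go to the coordinator, which holds the secret key, decrypts, and sends back either fresh ciphertexts (training continues) or plaintext weights (training is done). A hypothetical sketch of the coordinator's half, assuming TenSEAL and the same channel API; this is illustrative, not the repo's coordinator implementation:

import tenseal as ts

def refresh_model(channel, context):
    # assumption: context holds the secret key on the coordinator side
    weight = ts.ckks_vector_from(context, channel.recv('host_weight')).decrypt()
    bias = ts.ckks_vector_from(context, channel.recv('host_bias')).decrypt()
    # re-encrypt so the host can keep training with a fresh noise budget
    channel.send('host_weight', ts.ckks_vector(context, weight).serialize())
    channel.send('host_bias', ts.ckks_vector(context, bias).serialize())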
def compute_enc_z(self, x):
guest_z = self.guest_channel.recv_all('guest_z')
if self.model.multiclass:
guest_z = [self.load_tensor(z) for z in guest_z]
else:
guest_z = [self.load_vector(z) for z in guest_z]
return self.model.compute_enc_z(x, guest_z)
def compute_enc_regular_loss(self):
if self.model.alpha != 0:
guest_regular_loss = self.guest_channel.recv_all('guest_regular_loss')
if self.model.multiclass:
guest_regular_loss = [self.load_tensor(s) for s in guest_regular_loss]
else:
guest_regular_loss = [self.load_vector(s) for s in guest_regular_loss]
return self.model.compute_regular_loss(sum(guest_regular_loss))
else:
return 0.
def train(self, x, y):
logger.info(f'iteration {self.t} / {self.max_iter}')
if self.t >= self.max_iter:
self.t = 0
logger.warning(f'decrypt model')
self.update_ciphertext_model()
self.t += 1
z = self.compute_enc_z(x)
error = self.model.compute_error(y, z)
self.guest_channel.send_all('error', error.serialize())
self.model.fit(x, error)
def compute_metrics(self, x, y):
logger.info(f'iteration {self.t} / {self.max_iter}')
if self.t >= self.max_iter:
self.t = 0
logger.warning(f'decrypt model')
self.update_ciphertext_model()
z = self.compute_enc_z(x)
regular_loss = self.compute_enc_regular_loss()
loss = self.model.loss(y, z, regular_loss)
self.coordinator_channel.send('loss', loss.serialize())
logger.info('View metrics at coordinator while using CKKS')
def compute_final_metrics(self, x, y):
return super().compute_metrics(x, y)
\ No newline at end of file
......@@ -9,7 +9,8 @@
},
"VFL_logistic_regression": {
"guest": "primihub.FL.logistic_regression.vfl_guest.LogisticRegressionGuest",
"host": "primihub.FL.logistic_regression.vfl_host.LogisticRegressionHost"
"host": "primihub.FL.logistic_regression.vfl_host.LogisticRegressionHost",
"coordinator": "primihub.FL.logistic_regression.vfl_coordinator.LogisticRegressionCoordinator"
},
"HFL_neural_network": {
"client": "primihub.FL.neural_network.hfl_client.NeuralNetworkClient",
......
......@@ -23,10 +23,16 @@ class NeuralNetworkClient(BaseModel):
super().__init__(**kwargs)
def run(self):
if self.common_params['process'] == 'train':
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
elif self.common_params['process'] == 'predict':
elif process == 'predict':
self.predict()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
......@@ -257,6 +263,7 @@ class Plaintext_Client:
def set_model(self, model):
self.model.load_state_dict(model)
self.model.to(self.device)
def send_output_dim(self, y):
if self.task == 'regression':
......@@ -344,8 +351,8 @@ class Plaintext_Client:
pred = self.model(x)
if self.task == 'classification':
y_true = torch.cat((y_true, y))
y_score = torch.cat((y_score, pred))
y_true = torch.cat((y_true, y.cpu()))
y_score = torch.cat((y_score, pred.cpu()))
loss += self.loss_fn(pred, y).item() * len(x)
if self.output_dim == 1:
......@@ -353,8 +360,8 @@ class Plaintext_Client:
else:
acc += (pred.argmax(1) == y).type(torch.float).sum().item()
elif self.task == 'regression':
mae += F.l1_loss(pred, y) * len(x)
mse += F.mse_loss(pred, y) * len(x)
mae += F.l1_loss(pred, y).cpu() * len(x)
mse += F.mse_loss(pred, y).cpu() * len(x)
client_metrics = {}
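The .cpu() calls added above matter once training runs on a GPU: metric accumulators and the tensors concatenated into y_true/y_score must share a device, and keeping them on the CPU avoids pinning the whole prediction history in GPU memory. A minimal sketch of the pattern, assuming PyTorch (toy tensors rather than the repo's model):

import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pred = torch.rand(32, 1, device=device)      # stand-in for a GPU model output
y = torch.rand(32, 1, device=device)

mae = torch.tensor(0.)                       # CPU-side accumulators
mse = torch.tensor(0.)
mae += F.l1_loss(pred, y).cpu() * len(y)     # move each batch statistic to CPU
mse += F.mse_loss(pred, y).cpu() * len(y)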
......@@ -384,11 +391,11 @@ class Plaintext_Client:
elif self.task == 'regression':
mse /= size
client_metrics['train_mse'] = mse
self.server_channel.send("mse", mse.type(torch.float64))
self.server_channel.send("mse", mse)
mae /= size
client_metrics['train_mae'] = mae
self.server_channel.send("mae", mae.type(torch.float64))
self.server_channel.send("mae", mae)
logger.info(f"mse={mse}, mae={mae}")
......@@ -414,7 +421,7 @@ class DPSGD_Client(Plaintext_Client):
input_shape = list(self.input_shape)
# set batch size to 1 to initialize the lazy module
input_shape.insert(0, 1)
self.model.forward(torch.ones(input_shape))
self.model.forward(torch.ones(input_shape).to(self.device))
super().lazy_module_init()
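The .to(self.device) fix above exists because torch.nn.Lazy* modules only materialize their parameters on the first forward pass, and that dummy pass fails if the dummy input and the model sit on different devices. A minimal sketch, assuming PyTorch:

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Sequential(nn.LazyLinear(16), nn.ReLU(), nn.LazyLinear(1)).to(device)

input_shape = [1, 8]                               # batch size 1 is enough
model.forward(torch.ones(input_shape).to(device))  # materializes parameters
print(model[0].weight.shape)                       # torch.Size([16, 8])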
def enable_DP_training(self, train_dataloader):
......
......@@ -22,10 +22,16 @@ class CNNClient(BaseModel):
super().__init__(**kwargs)
def run(self):
if self.common_params['process'] == 'train':
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
elif self.common_params['process'] == 'predict':
elif process == 'predict':
self.predict()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
......@@ -140,6 +146,7 @@ class CNNClient(BaseModel):
model.eval()
with torch.no_grad():
for x, y in data_loader:
x = x.to(device)
pred = model(x)
pred_prob = torch.softmax(pred, dim=1)
pred_y = pred_prob.argmax(1)
......
......@@ -15,8 +15,14 @@ class CNNServer(BaseModel):
super().__init__(**kwargs)
def run(self):
if self.common_params['process'] == 'train':
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
......
......@@ -17,8 +17,14 @@ class NeuralNetworkServer(BaseModel):
super().__init__(**kwargs)
def run(self):
if self.common_params['process'] == 'train':
process = self.common_params['process']
logger.info(f"process: {process}")
if process == 'train':
self.train()
else:
error_msg = f"Unsupported process: {process}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def train(self):
# setup communication channels
......@@ -146,7 +152,7 @@ class Plaintext_Server:
input_shape = list(self.input_shape)
# set batch size to 1 to initialize the lazy module
input_shape.insert(0, 1)
self.model.forward(torch.ones(input_shape))
self.model.forward(torch.ones(input_shape).to(self.device))
self.model.load_state_dict(self.model.state_dict())
self.server_model_broadcast()
......@@ -164,7 +170,7 @@ class Plaintext_Server:
np.array(self.num_positive_examples_weights)).tolist()
self.num_examples_weights = torch.tensor(self.num_examples_weights,
dtype=torch.float32)
dtype=torch.float32).to(self.device)
self.num_examples_weights_sum = self.num_examples_weights.sum()
def client_model_aggregate(self):
......@@ -201,9 +207,11 @@ class Plaintext_Server:
raise RuntimeError(error_msg)
client_metrics = self.client_channel.recv_all(metrics_name)
return np.average(client_metrics,
weights=self.num_examples_weights)
metrics = torch.tensor(client_metrics, dtype=torch.float).to(self.device) \
@ self.num_examples_weights \
/ self.num_examples_weights_sum
return float(metrics)
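The replacement above computes the same weighted mean as np.average, but as a tensor dot product so the metrics stay on self.device alongside num_examples_weights. A quick equivalence check, assuming NumPy and PyTorch (CPU, toy numbers):

import numpy as np
import torch

client_metrics = [0.91, 0.87, 0.95]        # e.g. per-client accuracies
weights = torch.tensor([100., 50., 150.])  # num_examples_weights analogue

avg_np = np.average(client_metrics, weights=weights.numpy())
avg_pt = float(torch.tensor(client_metrics) @ weights / weights.sum())
assert abs(avg_np - avg_pt) < 1e-6         # identical weighted means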
def get_fpr_tpr(self):
client_fpr = self.client_channel.recv_all('fpr')
......
......@@ -32,9 +32,9 @@ class GrpcClient:
def send(self, key, val):
key = self.local_party + '_' + key
logger.info(f"Start send {key}")
logger.info(f"Start send {key} to {self.remote_party}")
self.send_channel.send(key, pickle.dumps(val))
logger.info(f"End send {key}")
logger.info(f"End send {key} to {self.remote_party}")
def recv(self, key):
key = self.remote_party + '_' + key
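Note the asymmetry around the logging change above: send prefixes the key with the local party's name while recv prefixes it with the remote party's, so every message key is namespaced by its sender and two parties can exchange the same logical key without collision. A toy illustration of the convention (stand-in functions, not the repo's transport):

def send_key(local_party, key):
    return local_party + '_' + key    # sender namespaces its outgoing keys

def recv_key(remote_party, key):
    return remote_party + '_' + key   # receiver reconstructs the sender's key

# 'host' sends 'batch_size'; 'guest' looks it up under the host's namespace.
assert send_key('host', 'batch_size') == recv_key('host', 'batch_size')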
......@@ -61,11 +61,20 @@ class MultiGrpcClients:
logger.info("End send all")
def send_selected(self, key, val, selected_remote):
logger.info(f"Start send {selected_remote}")
logger.info(f"Start send to {selected_remote}")
for remote_party in selected_remote:
client = self.Clients[remote_party]
client.send(key, val)
logger.info(f"End send {selected_remote}")
logger.info(f"End send to {selected_remote}")
def send_seperately(self, key, valList):
assert len(valList) == len(self.Clients)
i = 0
logger.info(f"Start send seperately")
for client in self.Clients.values():
client.send(key, valList[i])
i += 1
logger.info(f"End send seperately")
def recv_all(self, key):
logger.info("Start receive all")
......
......@@ -21,6 +21,7 @@ from concurrent.futures import ThreadPoolExecutor
import json
import csv
import copy
import codecs
import mysql.connector
from datetime import datetime
from primihub.utils.logger_util import logger
......@@ -181,8 +182,6 @@ class DataAlign:
def generate_new_datast_from_mysql(self, meta_info, query_thread_num):
if not self.has_data_rows(meta_info["psiPath"]):
raise Exception("PSI result is empty, no intersection is found")
db_info = meta_info["localdata_path"]
# Connect to mysql server and create cursor.
try:
......@@ -239,17 +238,21 @@ class DataAlign:
# Collect all ids that PSI output.
intersect_ids = []
try:
in_f = open(meta_info["psiPath"])
reader = csv.reader(in_f)
next(reader)
for id in reader:
intersect_ids.append(id[0])
code_type = "utf-8"
if self.hasBOM(meta_info["psiPath"]):
code_type = "utf-8-sig"
with open(meta_info["psiPath"], encoding=code_type) as in_f:
reader = csv.reader(in_f)
next(reader)
for id in reader:
intersect_ids.append(id[0])
except OSError as e:
logger.error("Open file {} for read failed.".format(
meta_info["psiPath"]))
logger.error(e)
raise e
if (len(intersect_ids) == 0):
raise Exception("PSI result is empty, no intersection is found")
# Open new file to save all value of these ids.
writer = None
try:
......@@ -332,7 +335,6 @@ class DataAlign:
return
out_f.close()
in_f.close()
if num_rows != len(intersect_ids):
raise RuntimeError("Expect query {} rows from mysql but mysql return {} rows, this should be a bug.".format(
......@@ -347,9 +349,13 @@ class DataAlign:
return
def has_data_rows(self, fname):
with open(fname) as file:
return file.readline() and file.readline()
def hasBOM(self, csv_file):
with open(csv_file, 'rb') as f:
bom = f.read(3)
if bom.startswith(codecs.BOM_UTF8):
return True
else:
return False
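hasBOM above pairs with the recent "add BOM header to csv file" change: a UTF-8 BOM left in place would otherwise glue '\ufeff' onto the first header cell of the PSI result. A minimal round-trip showing the detection and the utf-8-sig read, assuming only the standard library (the path is illustrative):

import codecs
import csv

path = '/tmp/psi_result.csv'
with open(path, 'w', encoding='utf-8-sig', newline='') as f:
    csv.writer(f).writerows([['id'], ['42']])      # writes a BOM first

with open(path, 'rb') as f:
    has_bom = f.read(3).startswith(codecs.BOM_UTF8)

code_type = 'utf-8-sig' if has_bom else 'utf-8'    # utf-8-sig strips the BOM
with open(path, encoding=code_type) as f:
    assert next(csv.reader(f)) == ['id']           # no stray '\ufeff'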
def generate_new_dataset_from_csv(self, meta_info):
psi_path = meta_info["psiPath"]
......@@ -366,14 +372,15 @@ class DataAlign:
psi_index, old_dataset_path)
logger.info(log_msg)
if not self.has_data_rows(psi_path):
raise Exception("PSI result is empty, no intersection is found")
intersection_map = {}
intersection_set = set()
intersection_list = list()
with open(psi_path) as f:
code_type = "utf-8"
if self.hasBOM(psi_path):
code_type = "utf-8-sig"
with open(psi_path, encoding=code_type) as f:
f_csv = csv.reader(f)
header = next(f_csv)
for items in f_csv:
......@@ -382,6 +389,8 @@ class DataAlign:
continue
intersection_set.add(item)
intersection_list.append(item)
if (len(intersection_list) == 0):
raise Exception("PSI result is empty, no intersection is found")
with open(old_dataset_path) as old_f, open(new_dataset_output_path, 'w') as new_f:
f_csv = csv.reader(old_f)
......
......@@ -29,8 +29,10 @@ class Client:
party_datasets = {}
for party_name, role_param in self.role_params.items():
Dataset = common_pb2.Dataset()
Dataset.data['data_set'] = role_param.get('data_set','{}')
party_datasets[party_name] = Dataset
data_key = role_param.get('data_set')
if data_key:
Dataset.data['data_set'] = data_key
party_datasets[party_name] = Dataset
# construct 'task_info'
task_info = common_pb2.TaskContext()
......
# --extra-index-url https://download.pytorch.org/whl/cpu
--extra-index-url https://download.pytorch.org/whl/cpu
pyarrow==6.0.1
pandas
......@@ -14,6 +14,7 @@ protobuf==3.20.0
sphinx
scikit-learn==1.2.2
phe==1.5.0
tenseal==0.3.14; platform_machine != "arm64" and platform_machine != "aarch64"
mysql-connector-python
sqlalchemy==2.0.16
......@@ -28,7 +29,7 @@ scipy~=1.7.1
modin
opacus==1.4.0
torch==1.13.1
torchvision==0.14.1
# torch==1.13.1+cpu
# torchvision==0.14.1+cpu
torch==1.13.1+cpu; platform_machine != "arm64" and platform_machine != "aarch64"
torchvision==0.14.1+cpu; platform_machine != "arm64" and platform_machine != "aarch64"
torch==1.13.1; platform_machine == "arm64" or platform_machine == "aarch64"
torchvision==0.14.1; platform_machine == "arm64" or platform_machine == "aarch64"
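The platform_machine conditions above are PEP 508 environment markers: pip evaluates them on the installing machine, so x86_64 hosts pull the +cpu wheels from the extra index while arm64/aarch64 hosts get the default wheels (and skip tenseal entirely). To see what the marker evaluates to locally, assuming only the standard library:

import platform

machine = platform.machine()   # e.g. 'x86_64', 'arm64', 'aarch64'
wants_cpu_wheels = machine != 'arm64' and machine != 'aarch64'
print(machine, '+cpu wheels' if wants_cpu_wheels else 'default wheels')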
......@@ -14,74 +14,154 @@ config_setting(
values = {"define": "cpu_arch=darwin_x86_64"},
)
DEFAULT_DEPS_OPT = [
"//src/primihub/common:common_defination",
"@eigen//:eigen",
"@com_github_ladnir_aby3//aby3:aby3_lib",
"@ladnir_cryptoTools//:libcryptoTools",
"@com_github_glog_glog//:glog",
]
cc_library(
name = "algorithm_lib",
srcs = glob([
"aby3ML.cc",
"linear_model_gen.cc",
"logistic.cc",
"logistic_plain.cc",
"plainML.cc",
"falcon_lenet.cc",
"arithmetic.cc",
# "src/primihub/executor/express.cc",
# "src/primihub/operator/aby3_operator.cc",
"missing_val_processing.cc",
"mpc_statistics.cc",
]),
hdrs = glob([
"base.h",
"aby3ML.h",
"linear_model_gen.h",
"logistic.h",
"logistic_plain.h",
"plainML.h",
"regression.h",
"falcon_lenet.h",
"arithmetic.h",
"missing_val_processing.h",
"mpc_statistics.h",
]),
linkstatic = False,
deps = [
# ":eigen",
"//src/primihub/common:common_lib",
"//src/primihub/protocol:protocol_aby3_lib",
"//src/primihub/protocol:protocol_falcon_lib",
"//src/primihub/data_store:data_store_lib",
"//src/primihub/util:model_util_lib",
"//src/primihub/util:util_lib",
"//src/primihub/service:dataset_service",
"//src/primihub/executor:mpc_express_executor",
"//src/primihub/data_store:data_store_util",
"//src/primihub/util/crypto:crypto_lib",
#TODO condition select
"//src/primihub/util/network:communication_lib",
"//src/primihub/util/network:mpc_commpkg"
],
name = "aby3_ml",
srcs = [
"aby3ML.cc",
],
hdrs = [
"aby3ML.h",
],
deps = DEFAULT_DEPS_OPT,
)
cc_library(
name = "cryptflow2_algorithm_lib",
srcs = glob([
"cryptflow2_maxpool.cc",
]),
hdrs = glob([
"base.h",
"cryptflow2_maxpool.h",
]),
copts = select({
":x86_64": ["-maes", "-mrdseed", "-mavx2"],
":aarch64": [],
"//conditions:default": [],
}),
deps = [
"//src/primihub/common:common_lib",
"//src/primihub/data_store:data_store_lib",
"//src/primihub/service:dataset_service",
"//src/primihub/protocol:protocol_cryptflow2_ot_lib",
"//src/primihub/util/network:network_lib",
"//src/primihub/util:instruction_check_util",
"//src/primihub/util/network:communication_lib",
],
)
\ No newline at end of file
name = "generate_linear_model",
srcs = [
"linear_model_gen.cc",
],
hdrs = [
"linear_model_gen.h",
],
deps = DEFAULT_DEPS_OPT,
)
cc_library(
name = "regression",
hdrs = ["regression.h"],
srcs = ["regression.cc"],
deps = DEFAULT_DEPS_OPT,
)
cc_library(
name = "plain_ml",
srcs = ["plainML.cc",],
hdrs = ["plainML.h",],
deps = DEFAULT_DEPS_OPT,
)
cc_library(
name = "logistic_plain",
srcs = ["logistic_plain.cc",],
hdrs = ["logistic_plain.h",],
deps = [
"//src/primihub/util:eigen_util",
":regression",
":generate_linear_model",
":plain_ml",
],
)
cc_library(
name = "algorithm_base",
hdrs = ["base.h"],
srcs = ["base.cc"],
deps = [
"//src/primihub/common:party_config",
"//src/primihub/service:dataset_service",
"//src/primihub/util/network:communication_lib",
"//src/primihub/util/network:message_exchange_interface",
],
)
cc_library(
name = "logistic",
srcs = ["logistic.cc",],
hdrs = ["logistic.h",],
deps = [
":aby3_ml",
":algorithm_base",
":generate_linear_model",
":plain_ml",
":regression",
"//src/primihub/service:dataset_service",
"//src/primihub/util/network:message_exchange_interface",
"@ladnir_cryptoTools//:libcryptoTools",
"@arrow",
"@com_github_glog_glog//:glog",
"@eigen//:eigen",
],
)
cc_library(
name = "arithmetic",
srcs = ["arithmetic.cc",],
hdrs = ["arithmetic.h",],
deps = [
":algorithm_base",
"//src/primihub/executor:mpc_express_executor",
"//src/primihub/service:dataset_service",
"//src/primihub/util/network:communication_lib",
"//src/primihub/common:common_defination",
"//src/primihub/util/network:message_exchange_interface",
"//src/primihub/util:util_lib",
"@com_github_ladnir_aby3//aby3:aby3_lib",
"@arrow",
],
)
cc_library(
name = "missing_val_proc",
srcs = ["missing_val_processing.cc",],
hdrs = ["missing_val_processing.h",],
deps = [
":algorithm_base",
"//src/primihub/executor:mpc_express_executor",
"//src/primihub/service:dataset_service",
"//src/primihub/util/network:communication_lib",
"//src/primihub/util/network:message_exchange_interface",
"@arrow",
],
)
cc_library(
name = "mpc_statistics",
srcs = ["mpc_statistics.cc"],
hdrs = ["mpc_statistics.h"],
deps = [
":algorithm_base",
"//src/primihub/executor:mpc_express_executor",
"//src/primihub/service:dataset_service",
"//src/primihub/util/network:communication_lib",
"//src/primihub/util/network:message_exchange_interface",
"@arrow",
],
)
cc_library(
name = "algorithm_lib",
deps = [
":algorithm_base",
":mpc_statistics",
":missing_val_proc",
":arithmetic",
":logistic",
":logistic_plain",
":plain_ml",
":regression",
":aby3_ml",
":generate_linear_model",
],
)
cc_library(
name = "lib_opt_paillier",
deps = [
"//src/primihub/algorithm/opt_paillier:lib_opt_paillier_impl",
],
visibility = ["//visibility:public"],
)
// "Copyright [2021] <Primihub>"
#include "src/primihub/algorithm/aby3ML.h"
#include <glog/logging.h>
#include <memory>
#include <utility>
namespace primihub {
#ifdef MPC_SOCKET_CHANNEL
void aby3ML::init(u64 partyIdx, Session& prev, Session& next,
block seed) {
mPreproPrev = prev.addChannel();
mPreproNext = next.addChannel();
mPrev = prev.addChannel();
mNext = next.addChannel();
auto commPtr = std::make_shared<CommPkg>(mPrev, mNext);
mRt.init(partyIdx, commPtr);
PRNG prng(seed);
mEnc.init(partyIdx, *commPtr.get(), prng.get<block>());
mEval.init(partyIdx, *commPtr.get(), prng.get<block>());
void aby3ML::init(u64 partyIdx, Session& prev,
Session& next, block seed) {
auto comm_pkg_ = std::make_unique<aby3::CommPkg>();
comm_pkg_->mNext = next.addChannel();
comm_pkg_->mPrev = prev.addChannel();
mRt.init(partyIdx, *comm_pkg_);
osuCrypto::PRNG prng(seed);
mEnc.init(partyIdx, *comm_pkg_, prng.get<block>());
mEval.init(partyIdx, *comm_pkg_, prng.get<block>());
}
void aby3ML::fini(void) {
mPreproPrev.close();
mPreproNext.close();
mPrev.close();
mNext.close();
void aby3ML::init(u64 partyIdx,
std::unique_ptr<aby3::CommPkg> comm_pkg,
block seed) {
auto comm_pkg_ = std::move(comm_pkg);
mRt.init(partyIdx, *comm_pkg_);
LOG(INFO) << "Runtime init finish.";
osuCrypto::PRNG prng(seed);
mEnc.init(partyIdx, *comm_pkg_, prng.get<block>());
LOG(INFO) << "Encryptor init finish.";
mEval.init(partyIdx, *comm_pkg_, prng.get<block>());
LOG(INFO) << "Evaluator init finish.";
}
#else
void aby3ML::init(u64 partyIdx, MpcChannel &prev, MpcChannel &next, block seed) {
mNext = next;
mPrev = prev;
auto commPtr = std::make_shared<CommPkg>(mPrev, mNext);
mRt.init(partyIdx, commPtr);
void aby3ML::init(u64 partyIdx, aby3::CommPkg* comm_pkg, block seed) {
this->comm_pkg_ref_ = comm_pkg;
mRt.init(partyIdx, *comm_pkg_ref_);
LOG(INFO) << "Runtime init finish.";
PRNG prng(seed);
mEnc.init(partyIdx, *commPtr.get(), prng.get<block>());
osuCrypto::PRNG prng(seed);
mEnc.init(partyIdx, *comm_pkg_ref_, prng.get<block>());
LOG(INFO) << "Encryptor init finish.";
mEval.init(partyIdx, *commPtr.get(), prng.get<block>());
mEval.init(partyIdx, *comm_pkg_ref_, prng.get<block>());
LOG(INFO) << "Evaluator init finish.";
}
void aby3ML::fini(void) {}
#endif
void aby3ML::fini(void) {
// this->mNext().close();
// this->mPrev().close();
}
} // namespace primihub
#ifndef SRC_primihub_ALGORITHM_ABY3ML_H_
#define SRC_primihub_ALGORITHM_ABY3ML_H_
// "Copyright [2021] <Primihub>"
#ifndef SRC_PRIMIHUB_ALGORITHM_ABY3ML_H_
#define SRC_PRIMIHUB_ALGORITHM_ABY3ML_H_
#include <algorithm>
#include <random>
#include <vector>
#include <memory>
#include "cryptoTools/Common/Defines.h"
#include "aby3/sh3/Sh3Types.h"
#include "aby3/sh3/Sh3Encryptor.h"
#include "aby3/sh3/Sh3Evaluator.h"
#include "aby3/sh3/Sh3Piecewise.h"
#include "aby3/sh3/Sh3ShareGen.h"
#include "aby3/sh3/Sh3FixedPoint.h"
#include "src/primihub/common/defines.h"
#include "src/primihub/common/type/fixed_point.h"
#include "src/primihub/protocol/aby3/encryptor.h"
#include "src/primihub/protocol/aby3/evaluator/evaluator.h"
#include "src/primihub/protocol/aby3/evaluator/piecewise.h"
#include "src/primihub/protocol/aby3/sh3_gen.h"
#include "src/primihub/util/crypto/prng.h"
#ifdef MPC_SOCKET_CHANNEL
#include "src/primihub/util/network/socket/channel.h"
#include "src/primihub/util/network/socket/session.h"
#else
#include "src/primihub/util/network/mpc_channel.h"
#endif
#include "cryptoTools/Network/Channel.h"
#include "cryptoTools/Network/Session.h"
using Channel = osuCrypto::Channel;
using Session = osuCrypto::Session;
namespace primihub {
using namespace aby3; // NOLINT
class aby3ML {
public:
#ifdef MPC_SOCKET_CHANNEL
Channel mPreproNext;
Channel mPreproPrev;
Channel mNext;
Channel mPrev;
#else
MpcChannel mNext;
MpcChannel mPrev;
#endif
// std::unique_ptr<aby3::CommPkg> comm_pkg_{nullptr};
aby3::CommPkg* comm_pkg_ref_{nullptr};
Sh3Encryptor mEnc;
Sh3Evaluator mEval;
Sh3Runtime mRt;
bool mPrint = true;
u64 partyIdx() {
return mRt.mPartyIdx;
}
u64 partyIdx() { return mRt.mPartyIdx;}
#ifdef MPC_SOCKET_CHANNEL
void init(u64 partyIdx, Session& prev, Session& next, block seed);
#else
void init(u64 partyIdx, MpcChannel &prev, MpcChannel &next, block seed);
#endif
void init(u64 partyIdx, std::unique_ptr<aby3::CommPkg> comm_pkg, block seed);
void init(u64 partyIdx, aby3::CommPkg* comm_pkg, block seed);
void fini(void);
Channel& mNext() {return comm_pkg_ref_->mNext;}
Channel& mPrev() {return comm_pkg_ref_->mPrev;}
template<Decimal D>
sf64Matrix<D> localInput(const f64Matrix<D>& val) {
std::array<u64, 2> size{ val.rows(), val.cols() };
mNext.asyncSendCopy(size);
mPrev.asyncSendCopy(size);
this->mNext().asyncSendCopy(size);
this->mPrev().asyncSendCopy(size);
sf64Matrix<D> dest(size[0], size[1]);
mEnc.localFixedMatrix(mRt.noDependencies(), val, dest).get();
return dest;
......@@ -71,19 +62,20 @@ class aby3ML {
void localInputSize(const eMatrix<double>& val) {
std::array<u64, 2> size{
static_cast<unsigned long long>(val.rows()),
static_cast<unsigned long long>(val.cols())
static_cast<u64>(val.rows()),
static_cast<u64>(val.cols())
};
mNext.send(size);
mPrev.send(size);
this->mNext().send(size);
this->mPrev().send(size);
return;
}
template<Decimal D>
sf64Matrix<D> localInput(const eMatrix<double>& vals) {
f64Matrix<D> v2(vals.rows(), vals.cols());
for (u64 i = 0; i < vals.size(); ++i)
for (i64 i = 0; i < vals.size(); ++i) {
v2(i) = vals(i);
}
return localInput(v2);
}
......@@ -97,12 +89,13 @@ class aby3ML {
template<Decimal D>
sf64Matrix<D> remoteInput(u64 partyIdx) {
std::array<u64, 2> size;
if (partyIdx == (mRt.mPartyIdx + 1) % 3)
mNext.recv(size);
else if (partyIdx == (mRt.mPartyIdx + 2) % 3)
mPrev.recv(size);
else
if (partyIdx == (mRt.mPartyIdx + 1) % 3) {
this->mNext().recv(size);
} else if (partyIdx == (mRt.mPartyIdx + 2) % 3) {
this->mPrev().recv(size);
} else {
throw RTE_LOC;
}
sf64Matrix<D> dest(size[0], size[1]);
mEnc.remoteFixedMatrix(mRt.noDependencies(), dest).get();
......@@ -119,13 +112,13 @@ class aby3ML {
std::array<u64, 2> remoteInputSize(u64 partyIdx) {
std::array<u64, 2> size;
if (partyIdx == (mRt.mPartyIdx + 1) % 3)
mNext.recv(size);
else if (partyIdx == (mRt.mPartyIdx + 2) % 3)
mPrev.recv(size);
else
if (partyIdx == (mRt.mPartyIdx + 1) % 3) {
this->mNext().recv(size);
} else if (partyIdx == (mRt.mPartyIdx + 2) % 3) {
this->mPrev().recv(size);
} else {
throw RTE_LOC;
}
return size;
}
......@@ -139,8 +132,9 @@ class aby3ML {
mEnc.revealAll(mRt.noDependencies(), vals, temp).get();
eMatrix<double> ret(vals.rows(), vals.cols());
for (u64 i = 0; i < ret.size(); ++i)
for (i64 i = 0; i < ret.size(); ++i) {
ret(i) = static_cast<double>(temp(i));
}
return ret;
}
......@@ -149,8 +143,9 @@ class aby3ML {
mEnc.revealAll(mRt.noDependencies(), vals, temp).get();
eMatrix<double> ret(vals.rows(), vals.cols());
for (u64 i = 0; i < ret.size(); ++i)
for (i64 i = 0; i < ret.size(); ++i) {
ret(i) = static_cast<double>(temp(i));
}
return ret;
}
......@@ -386,9 +381,9 @@ class aby3ML {
e.seed(time(nullptr));
if (e() != 0) {
return int(e());
return static_cast<int>(e());
} else {
return int(e() + 1);
return static_cast<int>(e() + 1);
}
}
......@@ -525,4 +520,4 @@ class aby3ML {
} // namespace primihub
#endif // SRC_primihub_ALGORITHM_ABY3ML_H_
#endif // SRC_PRIMIHUB_ALGORITHM_ABY3ML_H_
#include "src/primihub/algorithm/base.h"
#include "src/primihub/common/defines.h"
#include "src/primihub/common/type/type.h"
#include "src/primihub/data_store/driver.h"
#include "src/primihub/executor/express.h"
#include "src/primihub/util/network/mpc_channel.h"
#include "src/primihub/service/dataset/service.h"
#include <stdlib.h>
#include <time.h>
#include <algorithm>
#include <iostream>
#include <math.h>
#include <cmath>
#include <sstream>
#include <stdlib.h>
#include <string>
#include <time.h>
#include <vector>
#include "src/primihub/data_store/driver.h"
#include "src/primihub/executor/express.h"
// #include "src/primihub/service/dataset/service.h"
#include "src/primihub/common/type.h"
namespace primihub {
template <Decimal Dbit> class ArithmeticExecutor : public AlgorithmBase {
template <Decimal Dbit>
class ArithmeticExecutor : public AlgorithmBase {
public:
explicit ArithmeticExecutor(PartyConfig &config,
std::shared_ptr<DatasetService> dataset_service);
int loadParams(primihub::rpc::Task &task) override;
int loadDataset(void) override;
int initPartyComm(void) override;
int execute() override;
int finishPartyComm(void) override;
int saveModel(void);
retcode InitEngine() override;
private:
int _LoadDatasetFromCSV(std::string &filename);
......@@ -40,30 +40,11 @@ private:
std::unique_ptr<MPCOperator> mpc_op_exec_;
std::unique_ptr<MPCExpressExecutor<Dbit>> mpc_exec_;
#ifdef MPC_SOCKET_CHANNEL
Session ep_next_;
Session ep_prev_;
IOService ios_;
std::string next_ip_, prev_ip_;
uint16_t next_port_, prev_port_;
#else
primihub::Node local_node_;
std::map<uint16_t, primihub::Node> node_map_;
std::shared_ptr<network::IChannel> base_channel_next_;
std::shared_ptr<network::IChannel> base_channel_prev_;
std::shared_ptr<MpcChannel> mpc_channel_next_;
std::shared_ptr<MpcChannel> mpc_channel_prev_;
#endif
ABY3PartyConfig party_config_;
std::string task_id_;
std::string job_id_;
// For MPC compare task.
bool is_cmp;
bool is_cmp{false};
std::vector<bool> cmp_res_;
// For MPC express task.
......@@ -76,44 +57,44 @@ private:
std::map<std::string, std::vector<int64_t>> col_and_val_int;
};
#ifndef MPC_SOCKET_CHANNEL
// This class just run send and recv many type of value with MPC channel.
class MPCSendRecvExecutor : public AlgorithmBase {
public:
explicit MPCSendRecvExecutor(PartyConfig &config,
std::shared_ptr<DatasetService> dataset_service);
using TaskGetChannelFunc =
std::function<std::shared_ptr<network::IChannel>(primihub::Node &node)>;
using TaskGetRecvQueueFunc =
std::function<ThreadSafeQueue<std::string> &(const std::string &key)>;
int loadParams(rpc::Task &task) override;
int loadDataset(void) override;
int initPartyComm(void) override;
int execute() override;
int finishPartyComm(void) override;
int saveModel(void);
private:
std::string job_id_;
std::string task_id_;
std::unique_ptr<MPCOperator> mpc_op_;
std::shared_ptr<network::IChannel> base_channel_next_;
std::shared_ptr<network::IChannel> base_channel_prev_;
std::shared_ptr<MpcChannel> mpc_channel_next_;
std::shared_ptr<MpcChannel> mpc_channel_prev_;
std::map<uint16_t, primihub::Node> partyid_node_map_;
uint16_t local_party_id_;
uint16_t next_party_id_;
uint16_t prev_party_id_;
primihub::Node local_node_;
};
#endif
// #ifndef MPC_SOCKET_CHANNEL
// // This class just run send and recv many type of value with MPC channel.
// class MPCSendRecvExecutor : public AlgorithmBase {
// public:
// explicit MPCSendRecvExecutor(PartyConfig &config,
// std::shared_ptr<DatasetService> dataset_service);
// using TaskGetChannelFunc =
// std::function<std::shared_ptr<network::IChannel>(primihub::Node &node)>;
// using TaskGetRecvQueueFunc =
// std::function<ThreadSafeQueue<std::string> &(const std::string &key)>;
// int loadParams(rpc::Task &task) override;
// int loadDataset(void) override;
// int initPartyComm(void) override;
// int execute() override;
// int finishPartyComm(void) override;
// int saveModel(void);
// private:
// std::string job_id_;
// std::string task_id_;
// std::unique_ptr<MPCOperator> mpc_op_;
// // std::shared_ptr<network::IChannel> base_channel_next_;
// // std::shared_ptr<network::IChannel> base_channel_prev_;
// // std::shared_ptr<MpcChannel> mpc_channel_next_;
// // std::shared_ptr<MpcChannel> mpc_channel_prev_;
// // std::map<uint16_t, primihub::Node> partyid_node_map_;
// // uint16_t local_party_id_;
// // uint16_t next_party_id_;
// // uint16_t prev_party_id_;
// // primihub::Node local_node_;
// };
// #endif
} // namespace primihub
/*
Copyright 2022 PrimiHub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "src/primihub/algorithm/base.h"
#include <map>
#include "src/primihub/util/network/message_interface.h"
#include "src/primihub/util/network/link_context.h"
namespace primihub {
oc::IOService g_ios_;
AlgorithmBase::AlgorithmBase(const PartyConfig& config,
std::shared_ptr<DatasetService> dataset_service)
: dataset_service_(std::move(dataset_service)) {
party_config_.Init(config);
this->set_party_name(config.party_name());
this->set_party_id(config.party_id());
#ifdef MPC_SOCKET_CHANNEL
auto &node_map = config.node_map;
std::map<uint16_t, rpc::Node> party_id_node_map;
for (auto iter = node_map.begin(); iter != node_map.end(); iter++) {
auto& node = iter->second;
uint16_t party_id = static_cast<uint16_t>(node.vm(0).party_id());
party_id_node_map[party_id] = node;
}
auto iter = node_map.find(config.node_id); // node_id
if (iter == node_map.end()) {
std::stringstream ss;
ss << "Can't find " << config.node_id << " in node_map.";
throw std::runtime_error(ss.str());
}
uint16_t local_id_ = iter->second.vm(0).party_id();
LOG(INFO) << "Note party id of this node is " << local_id_ << ".";
if (local_id_ == 0) {
rpc::Node &node = party_id_node_map[0];
uint16_t port = 0;
// Two Local server addr.
auto& next_ip = node.ip();
uint16_t next_port = node.vm(0).next().port();
next_addr_ = std::make_pair(next_ip, next_port);
auto& prev_ip = node.ip();
uint16_t prev_port = node.vm(0).prev().port();
prev_addr_ = std::make_pair(prev_ip, prev_port);
} else if (local_id_ == 1) {
rpc::Node &node = party_id_node_map[1];
// A local server addr.
auto& next_ip = node.ip();
uint16_t next_port = node.vm(0).next().port();
next_addr_ = std::make_pair(next_ip, next_port);
// A remote server addr.
auto& prev_ip = node.vm(0).prev().ip();
uint16_t prev_port = node.vm(0).prev().port();
prev_addr_ = std::make_pair(prev_ip, prev_port);
} else {
rpc::Node &node = party_id_node_map[2];
// Two remote server addr.
auto& next_ip = node.vm(0).next().ip();
uint16_t next_port = node.vm(0).next().port();
next_addr_ = std::make_pair(next_ip, next_port);
auto& prev_ip = node.vm(0).prev().ip();
uint16_t prev_port = node.vm(0).prev().port();
prev_addr_ = std::make_pair(prev_ip, prev_port);
}
#endif
}
#ifdef MPC_SOCKET_CHANNEL
int AlgorithmBase::initPartyComm() {
VLOG(3) << "Next addr: " << next_addr_.first << ":" << next_addr_.second << ".";
VLOG(3) << "Prev addr: " << prev_addr_.first << ":" << prev_addr_.second << ".";
if (party_id_ == 0) {
std::ostringstream ss;
ss << "sess_" << party_id_ << "_1";
std::string sess_name_1 = ss.str();
ss.str("");
ss << "sess_" << party_id_ << "_2";
std::string sess_name_2 = ss.str();
ep_next_.start(ios_, next_addr_.first, next_addr_.second,
oc::SessionMode::Server, sess_name_1);
LOG(INFO) << "[Next] Init server session, party " << party_id_ << ", "
<< "ip " << next_addr_.first << ", port " << next_addr_.second
<< ", name " << sess_name_1 << ".";
ep_prev_.start(ios_, prev_addr_.first, prev_addr_.second,
oc::SessionMode::Server, sess_name_2);
LOG(INFO) << "[Prev] Init server session, party " << party_id_ << ", "
<< "ip " << prev_addr_.first << ", port " << prev_addr_.second
<< ", name " << sess_name_2 << ".";
} else if (party_id_ == 1) {
std::ostringstream ss;
ss << "sess_" << party_id_ << "_1";
std::string sess_name_1 = ss.str();
ss.str("");
ss << "sess_" << party_config_.PrevPartyId() << "_1";
std::string sess_name_2 = ss.str();
ep_next_.start(ios_, next_addr_.first, next_addr_.second,
oc::SessionMode::Server, sess_name_1);
LOG(INFO) << "[Next] Init server session, party " << party_id_ << ", "
<< "ip " << next_addr_.first << ", port " << next_addr_.second
<< ", name " << sess_name_1 << ".";
ep_prev_.start(ios_, prev_addr_.first, prev_addr_.second,
oc::SessionMode::Client, sess_name_2);
LOG(INFO) << "[Prev] Init client session, party " << party_id_ << ", "
<< "ip " << prev_addr_.first << ", port " << prev_addr_.second
<< ", name " << sess_name_2 << ".";
} else {
std::ostringstream ss;
ss.str("");
ss << "sess_" << party_config_.NextPartyId() << "_2";
std::string sess_name_1 = ss.str();
ss.str("");
ss << "sess_" << party_config_.PrevPartyId() << "_1";
std::string sess_name_2 = ss.str();
ep_next_.start(ios_, next_addr_.first, next_addr_.second,
oc::SessionMode::Client, sess_name_1);
LOG(INFO) << "[Next] Init client session, party " << party_id_ << ", "
<< "ip " << next_addr_.first << ", port " << next_addr_.second
<< ", name " << sess_name_1 << ".";
ep_prev_.start(ios_, prev_addr_.first, prev_addr_.second,
oc::SessionMode::Client, sess_name_2);
LOG(INFO) << "[Prev] Init client session, party " << party_id_ << ", "
<< "ip " << prev_addr_.first << ", port " << prev_addr_.second
<< ", name " << sess_name_2 << ".";
}
session_enabled = true;
// init communication channel
comm_pkg_ = std::make_unique<aby3::CommPkg>();
comm_pkg_->mNext = ep_next_.addChannel();
comm_pkg_->mPrev = ep_prev_.addChannel();
this->mNext().waitForConnection();
this->mPrev().waitForConnection();
this->mNext().send(party_id_);
this->mPrev().send(party_id_);
uint16_t prev_party = 0;
uint16_t next_party = 0;
this->mNext().recv(next_party);
this->mPrev().recv(prev_party);
if (next_party != party_config_.NextPartyId()) {
LOG(ERROR) << "Party " << party_id_ << ", expect next party id "
<< party_config_.NextPartyId() << ", but give " << next_party << ".";
return -3;
}
if (prev_party != party_config_.PrevPartyId()) {
LOG(ERROR) << "Party " << party_id_ << ", expect prev party id "
<< party_config_.PrevPartyId() << ", but give " << prev_party << ".";
return -3;
}
return 0;
}
int AlgorithmBase::finishPartyComm() {
if (comm_pkg_ == nullptr) {
return 0;
}
this->mNext().close();
this->mPrev().close();
if (session_enabled) {
ep_next_.stop();
ep_prev_.stop();
}
return 0;
}
#else // GRPC MPC_SOCKET_CHANNEL
int AlgorithmBase::initPartyComm() {
uint16_t prev_party_id = this->party_config_.PrevPartyId();
uint16_t next_party_id = this->party_config_.NextPartyId();
auto link_ctx = this->GetLinkContext();
if (link_ctx == nullptr) {
LOG(ERROR) << "link context is not available";
return -1;
}
// construct channel for communication to next party
std::string party_name_next = this->party_config_.NextPartyName();
auto party_node_next = this->party_config_.NextPartyInfo();
auto base_channel_next = link_ctx->getChannel(party_node_next);
// construct channel for communication to prev party
std::string party_name_prev = this->party_config_.PrevPartyName();
auto party_node_prev = this->party_config_.PrevPartyInfo();
auto base_channel_prev = link_ctx->getChannel(party_node_prev);
std::string job_id = link_ctx->job_id();
std::string task_id = link_ctx->task_id();
std::string request_id = link_ctx->request_id();
LOG(INFO) << "local_id_local_id_: " << this->party_id();
LOG(INFO) << "next_party: " << party_name_next << " detail: " << party_node_next.to_string();
LOG(INFO) << "prev_party: " << party_name_prev << " detail: " << party_node_prev.to_string();
LOG(INFO) << "job_id: " << job_id << " task_id: " << task_id << " request id: " << request_id;
// The 'osuCrypto::Channel' takes ownership of the raw interface pointer
// (it wraps it in a unique_ptr internally), so the Channel deletes it.
auto msg_interface_prev = std::make_unique<network::TaskMessagePassInterface>(
this->party_name(), party_name_prev, link_ctx, base_channel_prev);
auto msg_interface_next = std::make_unique<network::TaskMessagePassInterface>(
this->party_name(), party_name_next, link_ctx, base_channel_next);
oc::Channel chl_prev(g_ios_, msg_interface_prev.release());
oc::Channel chl_next(g_ios_, msg_interface_next.release());
comm_pkg_ = std::make_unique<aby3::CommPkg>();
comm_pkg_->mPrev = std::move(chl_prev);
comm_pkg_->mNext = std::move(chl_next);
return 0;
}
int AlgorithmBase::finishPartyComm() {
if (comm_pkg_ == nullptr) {
return 0;
}
// std::this_thread::sleep_for(std::chrono::milliseconds(500));
VLOG(5) << "stop next channel, " << link_ctx_ref_->request_id();
this->mNext().close();
VLOG(5) << "stop prev channel " << link_ctx_ref_->request_id();
this->mPrev().close();
return 0;
}
#endif // MPC_SOCKET_CHANNEL
} // namespace primihub
/*
Copyright 2022 Primihub
Copyright 2022 PrimiHub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -18,24 +18,35 @@
#define SRC_PRIMIHUB_ALGORITHM_BASE_H_
#include <string>
#include <memory>
#include <utility>
#include "src/primihub/service/dataset/service.h"
#include "src/primihub/protos/common.grpc.pb.h"
#include "src/primihub/protos/common.pb.h"
#include "src/primihub/util/network/link_context.h"
#include "src/primihub/util/network/socket/session.h"
#include "src/primihub/common/party_config.h"
#include "src/primihub/common/common.h"
#include "cryptoTools/Common/Defines.h"
#include "cryptoTools/Network/IOService.h"
#include "cryptoTools/Network/Channel.h"
#include "cryptoTools/Network/Session.h"
#include "aby3/sh3/Sh3Types.h"
using primihub::rpc::Task;
using primihub::service::DatasetService;
namespace primihub {
extern oc::IOService g_ios_;
struct ABY3PartyConfig {
ABY3PartyConfig() = default;
ABY3PartyConfig(const PartyConfig& config) {
explicit ABY3PartyConfig(const PartyConfig& config) {
party_config.CopyFrom(config);
}
retcode Init(const PartyConfig& config) {
party_config.CopyFrom(config);
return retcode::SUCCESS;
}
uint16_t NextPartyId() {
return (SelfPartyId() + 1) % ABY3_TOTAL_PARTY_NUM;
......@@ -114,14 +125,17 @@ struct ABY3PartyConfig {
class AlgorithmBase {
public:
explicit AlgorithmBase(std::shared_ptr<DatasetService> dataset_service)
: dataset_service_(dataset_service) {};
virtual ~AlgorithmBase(){};
: dataset_service_(dataset_service) {}
AlgorithmBase(const PartyConfig& party_config,
std::shared_ptr<DatasetService> dataset_service);
virtual ~AlgorithmBase() = default;
virtual int loadParams(primihub::rpc::Task &task) = 0;
virtual int loadDataset() = 0;
virtual int initPartyComm() = 0;
virtual int initPartyComm();
virtual retcode InitEngine() {return retcode::SUCCESS;} // to be pure virtual
virtual int execute() = 0;
virtual int finishPartyComm() = 0;
virtual int finishPartyComm();
virtual int saveModel() = 0;
std::shared_ptr<DatasetService>& datasetService() {
......@@ -142,13 +156,30 @@ class AlgorithmBase {
std::string party_name() {return party_name_;}
void set_party_name(const std::string& party_name) {party_name_ = party_name;}
oc::Channel& mNext() {return comm_pkg_->mNext;}
oc::Channel& mPrev() {return comm_pkg_->mPrev;}
aby3::CommPkg* CommPkgPtr() {return comm_pkg_.get();}
protected:
std::shared_ptr<DatasetService> dataset_service_;
std::string algorithm_name_;
network::LinkContext* link_ctx_ref_{nullptr};
std::string party_name_;
uint16_t party_id_;
#ifdef MPC_SOCKET_CHANNEL
std::pair<std::string, uint16_t> next_addr_; // ip:port
std::pair<std::string, uint16_t> prev_addr_; // ip::port
oc::Session ep_next_;
oc::Session ep_prev_;
oc::IOService ios_;
bool session_enabled{false};
#endif
// communication related
std::unique_ptr<aby3::CommPkg> comm_pkg_{nullptr};
// oc::IOService ios_;
ABY3PartyConfig party_config_;
};
} // namespace primihub
} // namespace primihub
#endif // SRC_PRIMIHUB_ALGORITHM_BASE_H_
#endif // SRC_PRIMIHUB_ALGORITHM_BASE_H_
......@@ -21,13 +21,12 @@
#include "src/primihub/algorithm/base.h"
#include "src/primihub/common/clp.h"
#include "src/primihub/common/defines.h"
#include "src/primihub/common/common.h"
#include "src/primihub/common/type/type.h"
#include "src/primihub/protocol/cryptflow2/NonLinear/maxpool.h"
#include "src/primihub/protocol/cryptflow2/globals.h"
#include "src/primihub/util/network/socket/session.h"
#include "src/primihub/common/clp.h"
#include "src/primihub/common/defines.h"
#include "src/primihub/common/type/type.h"
#include "src/primihub/data_store/driver.h"
#include "src/primihub/data_store/factory.h"
......
......@@ -24,7 +24,7 @@
#include "Eigen/Dense"
#include "assert.h"
#include "src/primihub/common/clp.h"
#include "src/primihub/common/defines.h"
#include "src/primihub/common/common.h"
#include "src/primihub/common/type/type.h"
#include "src/primihub/util/network/socket/session.h"
......
......@@ -2,18 +2,59 @@
#include "src/primihub/algorithm/linear_model_gen.h"
namespace primihub {
void LinearModelGen::setModel(eMatrix<double>& model,
double noise, double sd) {
void LinearModelGen::setModel(aby3::eMatrix<double>& model,
double noise, double sd) {
mModel = model;
mNoise = noise;
mSd = sd;
}
void LogisticModelGen::setModel(eMatrix<double>& model,
double noise, double sd) {
void LogisticModelGen::setModel(aby3::eMatrix<double>& model,
double noise, double sd) {
mModel = model;
mNoise = noise;
mSd = sd;
}
aby3::eMatrix<double> LoadDataLocalLogistic(const std::string& full_path) {
std::cout << "go into function: load_data_local" << full_path << std::endl;
std::string line, num;
std::fstream fs;
fs.open(full_path.c_str(), std::fstream::in);
if (!fs) {
std::cout << "File not exists: " << full_path << std::endl;
return aby3::eMatrix<double>(0, 0);
} else {
std::cout << "Start to load data from a fixed full path: "
<< full_path << std::endl;
}
unsigned int rowInd = 0;
unsigned int colInd;
std::vector<std::vector<double>> tmpVec;
std::vector<double> tmpRow;
while (getline(fs, line)) {
rowInd += 1;
colInd = 0;
std::istringstream readstr(line);
tmpRow.clear();
// while (readstr >> num) {
while (getline(readstr, num, ',')) {
colInd += 1;
tmpRow.push_back(std::stold(num.c_str()));
}
tmpVec.push_back(tmpRow);
}
fs.close();
aby3::eMatrix<double> res(rowInd, colInd);
for (unsigned int i = 0; i < rowInd; ++i) {
for (unsigned int j = 0; j < colInd; ++j) {
res(i, j) = tmpVec[i][j];
}
}
return res;
}
} // namespace primihub
// Copyright [2021] <primihub.com>
#ifndef SRC_primihub_ALGORITHM_LINEAR_MODEL_GEN_H_
#define SRC_primihub_ALGORITHM_LINEAR_MODEL_GEN_H_
#ifndef SRC_PRIMIHUB_ALGORITHM_LINEAR_MODEL_GEN_H_
#define SRC_PRIMIHUB_ALGORITHM_LINEAR_MODEL_GEN_H_
#include <fstream>
#include <random>
......@@ -8,30 +8,26 @@
#include <iostream>
#include "Eigen/Dense"
#include "src/primihub/common/type/type.h"
#include "src/primihub/common/defines.h"
#include "cryptoTools/Common/Defines.h"
#include "aby3/sh3/Sh3Types.h"
namespace primihub {
class LinearModelGen {
public:
eMatrix<double> mModel;
aby3::eMatrix<double> mModel;
double mNoise, mSd;
void setModel(eMatrix<double>& model, double noise = 1,
double sd = 1);
void setModel(aby3::eMatrix<double>& model, double noise = 1, double sd = 1);
};
class LogisticModelGen {
public:
eMatrix<double> mModel;
aby3::eMatrix<double> mModel;
double mNoise, mSd;
void setModel(eMatrix<double>& model, double noise = 1,
double sd = 1);
void setModel(aby3::eMatrix<double>& model, double noise = 1, double sd = 1);
};
aby3::eMatrix<double> LoadDataLocalLogistic(const std::string& full_path);
} // namespace primihub
#endif // SRC_primihub_ALGORITHM_LINEAR_MODEL_GEN_H_
#endif // SRC_PRIMIHUB_ALGORITHM_LINEAR_MODEL_GEN_H_
......@@ -24,14 +24,13 @@
#include "src/primihub/data_store/dataset.h"
#include "src/primihub/data_store/factory.h"
#include "src/primihub/service/dataset/model.h"
#include "src/primihub/util/network/message_interface.h"
#include "src/primihub/util/network/link_context.h"
using namespace std;
using namespace Eigen;
using arrow::Array;
using arrow::DoubleArray;
using arrow::Int64Array;
using arrow::Table;
namespace primihub {
eMatrix<double> logistic_main(sf64Matrix<D> &train_data_0_1,
sf64Matrix<D> &train_label_0_1,
......@@ -61,64 +60,9 @@ eMatrix<double> logistic_main(sf64Matrix<D> &train_data_0_1,
LogisticRegressionExecutor::LogisticRegressionExecutor(
PartyConfig &config, std::shared_ptr<DatasetService> dataset_service)
: AlgorithmBase(dataset_service) {
: AlgorithmBase(config, dataset_service) {
this->algorithm_name_ = "logistic_regression";
this->set_party_name(config.party_name());
this->set_party_id(config.party_id());
local_id_ = config.party_id();
#ifdef MPC_SOCKET_CHANNEL
auto &node_map = config.node_map;
std::map<uint16_t, rpc::Node> party_id_node_map;
for (auto iter = node_map.begin(); iter != node_map.end(); iter++) {
rpc::Node &node = iter->second;
uint16_t party_id = static_cast<uint16_t>(node.vm(0).party_id());
party_id_node_map[party_id] = node;
}
auto iter = node_map.find(config.node_id); // node_id
if (iter == node_map.end()) {
stringstream ss;
ss << "Can't find " << config.node_id << " in node_map.";
throw std::runtime_error(ss.str());
}
local_id_ = iter->second.vm(0).party_id();
LOG(INFO) << "Note party id of this node is " << local_id_ << ".";
if (local_id_ == 0) {
rpc::Node &node = party_id_node_map[0];
uint16_t port = 0;
// Two Local server addr.
port = node.vm(0).next().port();
next_addr_ = std::make_pair(node.ip(), port);
port = node.vm(0).prev().port();
prev_addr_ = std::make_pair(node.ip(), port);
} else if (local_id_ == 1) {
rpc::Node &node = party_id_node_map[1];
// A local server addr.
uint16_t port = node.vm(0).next().port();
next_addr_ = std::make_pair(node.ip(), port);
// A remote server addr.
prev_addr_ =
std::make_pair(node.vm(0).prev().ip(), node.vm(0).prev().port());
} else {
rpc::Node &node = party_id_node_map[2];
// Two remote server addr.
next_addr_ =
std::make_pair(node.vm(0).next().ip(), node.vm(0).next().port());
prev_addr_ =
std::make_pair(node.vm(0).prev().ip(), node.vm(0).prev().port());
}
#else
this->party_config_.CopyFrom(config);
#endif
// Key when save model.
std::stringstream ss;
ss << config.job_id << "_" << config.task_id << "_party_" << local_id_
......@@ -148,9 +92,9 @@ int LogisticRegressionExecutor::loadParams(primihub::rpc::Task &task) {
num_iter_ = param_map["NumIters"].value_int32();
model_file_name_ = param_map["modelName"].value_string();
if (model_file_name_ == "")
if (model_file_name_ == "") {
model_file_name_ = "./" + model_name_ + ".csv";
}
} catch (std::exception &e) {
LOG(ERROR) << "Failed to load params: " << e.what();
return -1;
......@@ -231,8 +175,9 @@ int LogisticRegressionExecutor::_LoadDatasetFromCSV(std::string &dataset_id) {
for (int64_t j = 0; j < array_len; j++) {
if (j < train_length) {
train_input_(j, i) = 1;
} else
} else {
test_input_(j - train_length, i) = 1;
}
// m(j, i) = array->Value(j);
}
} else {
......@@ -242,10 +187,11 @@ int LogisticRegressionExecutor::_LoadDatasetFromCSV(std::string &dataset_id) {
table->column(i - 1)->chunk(0));
for (int64_t j = 0; j < array->length(); j++) {
if (j < train_length)
if (j < train_length) {
train_input_(j, i) = array->Value(j);
else
} else {
test_input_(j - train_length, i) = array->Value(j);
}
// m(j, i) = array->Value(j);
}
} else {
......@@ -297,161 +243,10 @@ int LogisticRegressionExecutor::loadDataset() {
return 0;
}
#ifdef MPC_SOCKET_CHANNEL
int LogisticRegressionExecutor::initPartyComm(void) {
VLOG(3) << "Next addr: " << next_addr_.first << ":" << next_addr_.second
<< ".";
VLOG(3) << "Prev addr: " << prev_addr_.first << ":" << prev_addr_.second
<< ".";
if (local_id_ == 0) {
std::ostringstream ss;
ss << "sess_" << local_id_ << "_1";
std::string sess_name_1 = ss.str();
ss.str("");
ss << "sess_" << local_id_ << "_2";
std::string sess_name_2 = ss.str();
ep_next_.start(ios_, next_addr_.first, next_addr_.second,
SessionMode::Server, sess_name_1);
LOG(INFO) << "[Next] Init server session, party " << local_id_ << ", "
<< "ip " << next_addr_.first << ", port " << next_addr_.second
<< ", name " << sess_name_1 << ".";
ep_prev_.start(ios_, prev_addr_.first, prev_addr_.second,
SessionMode::Server, sess_name_2);
LOG(INFO) << "[Prev] Init server session, party " << local_id_ << ", "
<< "ip " << prev_addr_.first << ", port " << prev_addr_.second
<< ", name " << sess_name_2 << ".";
} else if (local_id_ == 1) {
std::ostringstream ss;
ss << "sess_" << local_id_ << "_1";
std::string sess_name_1 = ss.str();
ss.str("");
ss << "sess_" << PrevPartyId() << "_1";
std::string sess_name_2 = ss.str();
ep_next_.start(ios_, next_addr_.first, next_addr_.second,
SessionMode::Server, sess_name_1);
LOG(INFO) << "[Next] Init server session, party " << local_id_ << ", "
<< "ip " << next_addr_.first << ", port " << next_addr_.second
<< ", name " << sess_name_1 << ".";
ep_prev_.start(ios_, prev_addr_.first, prev_addr_.second,
SessionMode::Client, sess_name_2);
LOG(INFO) << "[Prev] Init client session, party " << local_id_ << ", "
<< "ip " << prev_addr_.first << ", port " << prev_addr_.second
<< ", name " << sess_name_2 << ".";
} else {
std::ostringstream ss;
ss.str("");
ss << "sess_" << this->NextPartyId() << "_2";
std::string sess_name_1 = ss.str();
ss.str("");
ss << "sess_" << this->PrevPartyId() << "_1";
std::string sess_name_2 = ss.str();
ep_next_.start(ios_, next_addr_.first, next_addr_.second,
SessionMode::Client, sess_name_1);
LOG(INFO) << "[Next] Init client session, party " << local_id_ << ", "
<< "ip " << next_addr_.first << ", port " << next_addr_.second
<< ", name " << sess_name_1 << ".";
ep_prev_.start(ios_, prev_addr_.first, prev_addr_.second,
SessionMode::Client, sess_name_2);
LOG(INFO) << "[Prev] Init client session, party " << local_id_ << ", "
<< "ip " << prev_addr_.first << ", port " << prev_addr_.second
<< ", name " << sess_name_2 << ".";
}
auto chann_next = ep_next_.addChannel();
auto chann_prev = ep_prev_.addChannel();
chann_next.waitForConnection();
chann_prev.waitForConnection();
chann_next.send(local_id_);
chann_prev.send(local_id_);
uint16_t prev_party = 0;
uint16_t next_party = 0;
chann_next.recv(next_party);
chann_prev.recv(prev_party);
if (next_party != this->NextPartyId()) {
LOG(ERROR) << "Party " << local_id_ << ", expect next party id "
<< this->NextPartyId() << ", but give " << next_party << ".";
return -3;
}
if (prev_party != this->PrevPartyId()) {
LOG(ERROR) << "Party " << local_id_ << ", expect prev party id "
<< this->PrevPartyId() << ", but give " << prev_party << ".";
return -3;
}
chann_next.close();
chann_prev.close();
engine_.init(local_id_, ep_prev_, ep_next_, toBlock(local_id_));
LOG(INFO) << "Init party: " << local_id_ << " communication finish.";
return 0;
}
#else
int LogisticRegressionExecutor::initPartyComm(void) {
uint16_t prev_party_id = this->PrevPartyId();
uint16_t next_party_id = this->NextPartyId();
auto link_ctx = this->GetLinkContext();
if (link_ctx == nullptr) {
LOG(ERROR) << "link context is not available";
return -1;
}
auto& party_id_map = party_config_.PartyId2PartyNameMap();
auto& party_info_map = party_config_.PartyName2PartyInfoMap();
// construct channel for communication to next party
std::string party_name_next = party_id_map[next_party_id];
auto pb_party_node_next = party_info_map[party_name_next];
Node party_node_next;
pbNode2Node(pb_party_node_next, &party_node_next);
auto base_channel_next = link_ctx->getChannel(party_node_next);
// construct channel for communication to prev party
std::string party_name_prev = party_id_map[prev_party_id];
auto pb_party_node_prev = party_info_map[party_name_prev];
Node party_node_prev;
pbNode2Node(pb_party_node_prev, &party_node_prev);
auto base_channel_prev = link_ctx->getChannel(party_node_prev);
LOG(INFO) << "local_id_local_id_: " << local_id_;
LOG(INFO) << "next_party: " << party_name_next << " detail: " << party_node_next.to_string();
LOG(INFO) << "prev_party: " << party_name_prev << " detail: " << party_node_prev.to_string();
MpcChannel channel_next(this->party_name(), link_ctx);
MpcChannel channel_prev(this->party_name(), link_ctx);
channel_next.SetupBaseChannel(party_name_next, base_channel_next);
channel_prev.SetupBaseChannel(party_name_prev, base_channel_prev);
engine_.init(local_id_, channel_prev, channel_next, toBlock(local_id_));
return 0;
}
#endif
#ifdef MPC_SOCKET_CHANNEL
int LogisticRegressionExecutor::finishPartyComm(void) {
ep_next_.stop();
ep_prev_.stop();
engine_.fini();
return 0;
retcode LogisticRegressionExecutor::InitEngine() {
engine_.init(this->party_id(), this->comm_pkg_.get(), oc::toBlock(local_id_));
return retcode::SUCCESS;
}
#else
int LogisticRegressionExecutor::finishPartyComm(void) { return 0; }
#endif
int LogisticRegressionExecutor::_ConstructShares(sf64Matrix<D> &w,
sf64Matrix<D> &train_data,
......@@ -491,8 +286,8 @@ int LogisticRegressionExecutor::_ConstructShares(sf64Matrix<D> &w,
int row_index = 0;
for (int h = 0; h < 3; h++) {
for (int i = 0; i < train_shares[h].rows(); i++) {
for (int j = 0; j < train_shares[h].cols() - 1; j++) {
for (u64 i = 0; i < train_shares[h].rows(); i++) {
for (u64 j = 0; j < train_shares[h].cols() - 1; j++) {
train_data[0](row_index, j) = train_shares[h][0](i, j);
train_data[1](row_index, j) = train_shares[h][1](i, j);
}
......@@ -502,7 +297,7 @@ int LogisticRegressionExecutor::_ConstructShares(sf64Matrix<D> &w,
row_index = 0;
for (int h = 0; h < 3; h++) {
for (int i = 0; i < train_shares[h].rows(); i++) {
for (u64 i = 0; i < train_shares[h].rows(); i++) {
train_label[0](row_index, 0) = train_shares[h][0](i, num_cols);
train_label[1](row_index, 0) = train_shares[h][1](i, num_cols);
row_index++;
......@@ -551,8 +346,8 @@ int LogisticRegressionExecutor::_ConstructShares(sf64Matrix<D> &w,
row_index = 0;
for (int h = 0; h < 3; h++) {
for (int i = 0; i < test_shares[h].rows(); i++) {
for (int j = 0; j < test_shares[h].cols() - 1; j++) {
for (u64 i = 0; i < test_shares[h].rows(); i++) {
for (u64 j = 0; j < test_shares[h].cols() - 1; j++) {
test_data[0](row_index, j) = test_shares[h][0](i, j);
test_data[1](row_index, j) = test_shares[h][1](i, j);
}
......@@ -562,7 +357,7 @@ int LogisticRegressionExecutor::_ConstructShares(sf64Matrix<D> &w,
row_index = 0;
for (int h = 0; h < 3; h++) {
for (int i = 0; i < test_shares[h].rows(); i++) {
for (u64 i = 0; i < test_shares[h].rows(); i++) {
// train_label(row_index++, 1) = train_shares[h](i, num_cols);
test_label[0](row_index, 0) = test_shares[h][0](i, num_cols);
test_label[1](row_index, 0) = test_shares[h][1](i, num_cols);
......@@ -672,4 +467,4 @@ int LogisticRegressionExecutor::saveModel(void) {
return 0;
}
} // namespace primihub
} // namespace primihub
......@@ -17,33 +17,37 @@
#ifndef SRC_PRIMIHUB_ALGORITHM_LOGISTIC_H_
#define SRC_PRIMIHUB_ALGORITHM_LOGISTIC_H_
#include <time.h>
#include <stdlib.h>
#include <math.h>
#include <algorithm>
#include <exception>
#include <fstream>
#include <iostream>
#include <math.h>
#include <sstream>
#include <stdlib.h>
#include <string>
#include <time.h>
#include <vector>
#include <utility>
#include <memory>
#include "Eigen/Dense"
#include "src/primihub/algorithm/aby3ML.h"
#include "src/primihub/algorithm/base.h"
#include "src/primihub/algorithm/linear_model_gen.h"
#include "src/primihub/algorithm/plainML.h"
#include "src/primihub/algorithm/regression.h"
#include "src/primihub/common/clp.h"
#include "src/primihub/common/defines.h"
#include "src/primihub/common/type/type.h"
#include "src/primihub/data_store/driver.h"
#include "src/primihub/util/network/socket/channel.h"
#include "src/primihub/util/network/socket/ioservice.h"
#include "src/primihub/util/network/socket/session.h"
#include "cryptoTools/Common/Defines.h"
#include "aby3/sh3/Sh3FixedPoint.h"
namespace primihub {
#ifdef MPC_SOCKET_CHANNEL
using Session = oc::Session;
using IOService = oc::IOService;
using SessionMode = oc::SessionMode;
#endif
using Decimal = aby3::Decimal;
const Decimal D = Decimal::D20;
eMatrix<double> logistic_main(sf64Matrix<D> &train_data_0_1,
sf64Matrix<D> &train_label_0_1,
sf64Matrix<D> &W2_0_1,
@@ -52,19 +56,18 @@ eMatrix<double> logistic_main(sf64Matrix<D> &train_data_0_1,
int IT, int pIdx);
class LogisticRegressionExecutor : public AlgorithmBase {
public:
explicit LogisticRegressionExecutor(
PartyConfig &config, std::shared_ptr<DatasetService> dataset_service);
int loadParams(primihub::rpc::Task &task) override;
int loadDataset(void) override;
int initPartyComm(void) override;
int execute() override;
int finishPartyComm(void) override;
int constructShares(void);
int saveModel(void);
retcode InitEngine() override;
private:
int _ConstructShares(sf64Matrix<D> &w, sf64Matrix<D> &train_data,
sf64Matrix<D> &train_label, sf64Matrix<D> &test_data,
sf64Matrix<D> &test_label);
@@ -80,24 +83,15 @@ private:
eMatrix<double> train_input_;
eMatrix<double> test_input_;
eMatrix<double> model_;
aby3ML engine_;
#ifdef MPC_SOCKET_CHANNEL
std::pair<std::string, uint16_t> next_addr_;
std::pair<std::string, uint16_t> prev_addr_;
Session ep_next_;
Session ep_prev_;
IOService ios_;
#else
PartyConfig party_config_;
#endif
// Logistic regression parameters
std::string train_input_filepath_;
std::string test_input_filepath_;
int batch_size_;
int num_iter_;
};
} // namespace primihub
#endif // SRC_PRIMIHUB_ALGORITHM_LOGISTIC_H_
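
The overrides declared above imply a fixed driver sequence: loadParams, loadDataset, initPartyComm, execute, finishPartyComm, then saveModel. A hypothetical sketch of that lifecycle against a stand-in base class (the real AlgorithmBase lives in src/primihub/algorithm/base.h and is not reproduced here):

#include <iostream>

struct AlgorithmLike {  // stand-in for AlgorithmBase
  virtual ~AlgorithmLike() = default;
  virtual int loadParams() = 0;
  virtual int loadDataset() = 0;
  virtual int initPartyComm() = 0;
  virtual int execute() = 0;
  virtual int finishPartyComm() = 0;
  virtual int saveModel() = 0;
};

// Each stage must succeed before the next runs; comm teardown always follows
// a successful setup, even when execute() fails.
int RunAlgorithm(AlgorithmLike& alg) {
  if (alg.loadParams() != 0) return -1;
  if (alg.loadDataset() != 0) return -1;
  if (alg.initPartyComm() != 0) return -1;
  int ret = alg.execute();
  alg.finishPartyComm();
  return ret == 0 ? alg.saveModel() : ret;
}

struct NoopAlgorithm : AlgorithmLike {
  int loadParams() override { return 0; }
  int loadDataset() override { return 0; }
  int initPartyComm() override { return 0; }
  int execute() override { return 0; }
  int finishPartyComm() override { return 0; }
  int saveModel() override { return 0; }
};

int main() {
  NoopAlgorithm alg;
  std::cout << "RunAlgorithm -> " << RunAlgorithm(alg) << std::endl;
}
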
// Copyright [2021] <primihub.com>
#include "src/primihub/algorithm/logistic_plain.h"
#include <random>
#include "src/primihub/util/eigen_util.h"
#include "aby3/sh3/Sh3Types.h"
#include "cryptoTools/Crypto/PRNG.h"
#include "cryptoTools/Common/Defines.h"
#include "src/primihub/common/common.h"
#include "src/primihub/algorithm/linear_model_gen.h"
#include "src/primihub/algorithm/regression.h"
#include "src/primihub/algorithm/plainML.h"
// using namespace std;
// using namespace Eigen;
namespace primihub {
template<typename T>
using eMatrix = aby3::eMatrix<T>;
using PRNG = oc::PRNG;
template <typename Derived>
void writeToCSVfile(std::string name, const Eigen::MatrixBase<Derived>& matrix)
@@ -15,8 +28,9 @@ void writeToCSVfile(std::string name, const Eigen::MatrixBase<Derived>& matrix)
}
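
The body of writeToCSVfile is elided by the hunk above. A common Eigen idiom for such a helper looks like this (a sketch under that assumption, not necessarily the exact primihub implementation):

#include <Eigen/Dense>
#include <fstream>
#include <string>

template <typename Derived>
void writeToCSVfile(std::string name,
                    const Eigen::MatrixBase<Derived>& matrix) {
  // One matrix row per line, comma-separated, full precision.
  const static Eigen::IOFormat kCsvFormat(Eigen::FullPrecision,
                                          Eigen::DontAlignCols, ", ", "\n");
  std::ofstream file(name);
  file << matrix.format(kCsvFormat);
}

int main() {
  writeToCSVfile("matrix.csv", Eigen::MatrixXd::Identity(3, 3));
}
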
// Plaintext logistic regression
void plain_Logistic_sample(eMatrix<double>& X, eMatrix<double>& Y,
eMatrix<double>& mModel, double mNoise = 1,
double mSd = 1, bool print = false) {
if (X.rows() != Y.rows()) throw std::runtime_error(LOCATION);
if (1 != Y.cols()) throw std::runtime_error(LOCATION);
if (X.cols() != mModel.rows()) throw std::runtime_error(LOCATION);
@@ -34,7 +48,7 @@ void plain_Logistic_sample(eMatrix<double>& X, eMatrix<double>& Y, eMatrix<doubl
Y = X * mModel + noise;
Eigen::IOFormat HeavyFmt(Eigen::FullPrecision, 0, ", ", ";\n", "[", "]");
for (int i = 0; i < Y.size(); ++i) {
if (print) {
std::cout << X.row(i).format(HeavyFmt);
@@ -47,9 +61,9 @@ void plain_Logistic_sample(eMatrix<double>& X, eMatrix<double>& Y, eMatrix<doubl
int logistic_plain_main() {
int N = 1000, D = 100, B = 128, IT = 10, testN = 100;
PRNG prng(oc::toBlock(1));
eMatrix<double> model(D, 1);
for (int i = 0; i < D; ++i) {
model(i, 0) = prng.get<int>() % 10;
}
@@ -57,7 +71,7 @@ int logistic_plain_main() {
eMatrix<double> test_data(testN, D), test_label(testN, 1);
plain_Logistic_sample(train_data, train_label, model);
Eigen::MatrixXd result;
result.resize(train_data.rows(), train_data.cols() + 1);
result << train_data, train_label;
writeToCSVfile("matrix.csv", result);
@@ -78,7 +92,7 @@ int logistic_plain_main() {
SGD_Logistic(params, engine, train_data, train_label, W2,
&test_data, &test_label);
for (int i = 0; i < D; ++i) {
std::cout << i << " " << model(i, 0) << " " << W2(i, 0) << std::endl;
}
@@ -88,13 +102,13 @@ int logistic_plain_main() {
int logistic_2plain_main(std::string &filename) {
int B = 128, IT = 10000, testN = 1000;
PRNG prng(oc::toBlock(1));
LogisticModelGen gen;
eMatrix<double> bytematrix;
u64 row, col;
bytematrix = LoadDataLocalLogistic(filename);
row = bytematrix.rows();
col = bytematrix.cols();
@@ -139,7 +153,7 @@ int logistic_2plain_main(std::string &filename) {
SGD_Logistic(params, engine, train_data, train_label, W2,
&test_data, &test_label);
for (int i = 0; i < D; ++i) {
std::cout << i << " " << W2(i, 0) << std::endl;
}
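
SGD_Logistic is declared in regression.h and its body is not part of this diff. For reference, a plain single-machine mini-batch update for logistic regression (vanilla SGD with none of the MPC machinery; an assumption, not the primihub implementation):

#include <Eigen/Dense>
#include <iostream>

// One epoch of mini-batch SGD for labels in {0, 1}:
//   w <- w - (lr / B) * X_b^T (sigmoid(X_b w) - y_b)
void SgdLogisticEpoch(const Eigen::MatrixXd& X, const Eigen::VectorXd& y,
                      Eigen::VectorXd& w, int B, double lr) {
  for (Eigen::Index start = 0; start + B <= X.rows(); start += B) {
    const auto Xb = X.middleRows(start, B);
    const auto yb = y.segment(start, B);
    Eigen::VectorXd p =
        (1.0 / (1.0 + (-(Xb * w).array()).exp())).matrix();  // sigmoid
    w -= (lr / B) * (Xb.transpose() * (p - yb));
  }
}

int main() {
  Eigen::MatrixXd X = Eigen::MatrixXd::Random(128, 4);
  Eigen::VectorXd w_true = Eigen::VectorXd::Ones(4);
  Eigen::VectorXd y = ((X * w_true).array() > 0).cast<double>().matrix();
  Eigen::VectorXd w = Eigen::VectorXd::Zero(4);
  for (int it = 0; it < 100; ++it) SgdLogisticEpoch(X, y, w, 16, 0.5);
  std::cout << "learned w: " << w.transpose() << std::endl;
}
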
// Copyright [2021] <primihub.com>
#ifndef SRC_PRIMIHUB_ALGORITHM_LOGISTIC_PLAIN_H_
#define SRC_PRIMIHUB_ALGORITHM_LOGISTIC_PLAIN_H_
#include <cmath>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <math.h>
#include <vector>
#include <algorithm>
#include <stdlib.h>
#include <exception>
#include <time.h>
#include "Eigen/Dense"
#include "src/primihub/common/defines.h"
#include "src/primihub/common/clp.h"
#include "src/primihub/common/type/type.h"
// #include "src/primihub/data_service/dataload.h"
#include "src/primihub/algorithm/regression.h"
#include "src/primihub/algorithm/linear_model_gen.h"
#include "src/primihub/algorithm/aby3ML.h"
#include "src/primihub/algorithm/plainML.h"
#include "src/primihub/util/network/socket/ioservice.h"
#include "src/primihub/util/network/socket/channel.h"
#include "src/primihub/data_store/driver_legcy.h"
namespace primihub {
int logistic_plain_main();
int logistic_2plain_main(std::string &filename);
}
#endif // SRC_PRIMIHUB_ALGORITHM_LOGISTIC_PLAIN_H_
@@ -12,8 +12,7 @@
#include <vector>
#include "src/primihub/algorithm/base.h"
#include "src/primihub/common/defines.h"
#include "src/primihub/common/type/type.h"
#include "src/primihub/common/type.h"
#include "src/primihub/data_store/driver.h"
#include "src/primihub/executor/express.h"
#include "src/primihub/service/dataset/service.h"
@@ -21,91 +20,48 @@
namespace primihub {
class MissingProcess : public AlgorithmBase {
public:
explicit MissingProcess(PartyConfig &config,
std::shared_ptr<DatasetService> dataset_service);
int loadParams(primihub::rpc::Task &task) override;
int loadDataset(void) override;
int initPartyComm(void) override;
int execute() override;
int finishPartyComm(void) override;
int saveModel(void);
retcode InitEngine() override;
int set_task_info(std::string platform_type, std::string job_id,
std::string task_id);
inline std::string platform() { return platform_type_; }
inline std::string job_id() { return job_id_; }
inline std::string task_id() { return task_id_; }
private:
using NestedVectorI32 = std::vector<std::vector<uint32_t>>;
int _strToInt64(const std::string &str, int64_t &i64_val);
int _strToDouble(const std::string &str, double &d_val);
int _avoidStringArray(std::shared_ptr<arrow::Array> array);
void _buildNewColumn(std::vector<std::string> &col_val,
std::shared_ptr<arrow::Array> &array);
void _buildNewColumn(std::shared_ptr<arrow::Table> table,
int col_index, const std::string &replace,
NestedVectorI32 &abnormal_index, bool need_double,
std::shared_ptr<arrow::Array> &new_array);
void _buildNewColumn(std::shared_ptr<arrow::Table> table,
int col_index, const std::string &replace,
std::vector<int> both_index, bool need_double,
std::shared_ptr<arrow::Array> &new_array);
int _LoadDatasetFromCSV(std::string &filename);
int _LoadDatasetFromDB(std::string &source);
void _spiltStr(std::string str,
const std::string &split,
std::vector<std::string> &strlist);
template <class T>
void replaceValue(std::map<std::string, uint32_t>::iterator &iter,
std::shared_ptr<arrow::Table> &table, int &col_index,
T &col_value,
std::vector<std::vector<unsigned int>> &abnormal_index,
@@ -144,6 +100,30 @@ private:
table = result.ValueOrDie();
LOG(INFO) << "Finish.";
}
private:
std::unique_ptr<MPCOperator> mpc_op_exec_{nullptr};
std::string job_id_{""};
std::string task_id_{""};
std::map<std::string, uint32_t> col_and_dtype_;
std::vector<std::string> local_col_names;
std::string data_file_path_{""};
std::string replace_type_{""};
std::string conn_info_{""};
std::shared_ptr<arrow::Table> table{nullptr};
std::map<std::string, std::vector<int>> db_both_index;
bool use_db{false};
std::string table_name{""};
std::string node_id_{""};
uint32_t party_id_{0};
std::string new_dataset_id_{""};
std::string new_dataset_path_{""};
std::string platform_type_{""};
};
} // namespace primihub
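
replaceValue above ends with table = result.ValueOrDie(), i.e. it rebuilds one column and swaps it into the arrow::Table. A minimal self-contained sketch of that column swap with the stock Arrow C++ API (the column name and fill value below are made up):

#include <arrow/api.h>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  // Build a one-column table with a missing (null) slot.
  arrow::DoubleBuilder builder;
  (void)builder.AppendValues(std::vector<double>{1.0, 2.0});
  (void)builder.AppendNull();
  std::shared_ptr<arrow::Array> col;
  (void)builder.Finish(&col);
  auto schema = arrow::schema({arrow::field("x", arrow::float64())});
  auto table = arrow::Table::Make(schema, {col});

  // Rebuild the column with the null replaced, then swap it into the table,
  // mirroring the `table = result.ValueOrDie();` step above.
  arrow::DoubleBuilder fixed;
  (void)fixed.AppendValues(std::vector<double>{1.0, 2.0, /*replaced*/ 1.5});
  std::shared_ptr<arrow::Array> new_col;
  (void)fixed.Finish(&new_col);
  auto result =
      table->SetColumn(0, arrow::field("x", arrow::float64()),
                       std::make_shared<arrow::ChunkedArray>(new_col));
  table = result.ValueOrDie();
  std::cout << table->ToString() << std::endl;
}
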
// "Copyright [2023] <Primihub>"
#include "src/primihub/algorithm/mpc_statistics.h"
#include "src/primihub/common/common.h"
#include "src/primihub/util/file_util.h"
#include <arrow/api.h>
#include <arrow/csv/api.h>
#include <arrow/csv/writer.h>
@@ -9,24 +7,29 @@
#include <arrow/result.h>
#include <arrow/status.h>
#include <arrow/table.h>
#include <rapidjson/document.h>
#include <algorithm>
#include <utility>
#include "src/primihub/common/common.h"
#include "src/primihub/util/file_util.h"
#include "src/primihub/util/network/message_interface.h"
using namespace rapidjson;
// using primihub::columnDtypeToString;
namespace primihub {
MPCStatisticsExecutor::MPCStatisticsExecutor(
PartyConfig &config, std::shared_ptr<DatasetService> dataset_service)
: AlgorithmBase(dataset_service) {
party_config_.Init(config);
// party_id_ = party_config_.SelfPartyId();
this->set_party_name(party_config_.SelfPartyName());
this->set_party_id(party_config_.SelfPartyId());
// node_id_ = config.node_id;
// job_id_ = config.job_id;
// task_id_ = config.task_id;
// // Save all party's node config.
// const auto &node_map = config.node_map;
// for (auto iter = node_map.begin(); iter != node_map.end(); iter++) {
@@ -358,48 +361,11 @@ int MPCStatisticsExecutor::execute() {
return 0;
}
retcode MPCStatisticsExecutor::InitEngine() {
executor_->setupChannel(this->party_id(), this->CommPkgPtr());
return retcode::SUCCESS;
}
int MPCStatisticsExecutor::finishPartyComm() { return 0; }
int MPCStatisticsExecutor::loadDataset() {
if (do_nothing_) {
LOG(WARNING) << "Skip load dataset due to nothing to do.";
@@ -467,11 +433,16 @@ int MPCStatisticsExecutor::saveModel() {
#else
int MPCStatisticsExecutor::loadParams(primihub::rpc::Task &task) {return 0;}
int MPCStatisticsExecutor::loadDataset() {return 0;}
int MPCStatisticsExecutor::initPartyComm() {return 0;}
int MPCStatisticsExecutor::execute() {return 0;}
int MPCStatisticsExecutor::finishPartyComm() {return 0;}
int MPCStatisticsExecutor::saveModel() {return 0;}
retcode MPCStatisticsExecutor::InitEngine() {
return retcode::SUCCESS;
}
retcode MPCStatisticsExecutor::_parseColumnName(const std::string &json_str) {
return retcode::SUCCESS;
}
retcode MPCStatisticsExecutor::_parseColumnDtype(const std::string &json_str) {
return retcode::SUCCESS;
}
#endif // MPC_SOCKET_CHANNEL
}; // namespace primihub
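
The initPartyComm body removed by the hunk above built MpcChannels to the ring neighbors by hand, computing them as (party_id_ + 1) % 3 and (party_id_ + 2) % 3; InitEngine now delegates that wiring to the shared comm package. For reference, the standard replicated-3PC neighbor arithmetic (helper names here are hypothetical):

#include <cstdint>
#include <iostream>

constexpr uint16_t NextPartyId(uint16_t id) { return (id + 1) % 3; }
constexpr uint16_t PrevPartyId(uint16_t id) { return (id + 2) % 3; }

int main() {
  for (uint16_t id = 0; id < 3; ++id) {
    std::cout << "party " << id << ": next=" << NextPartyId(id)
              << " prev=" << PrevPartyId(id) << "\n";
  }
}
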
package(default_visibility = ["//visibility:public"])
cc_library(
name = "lib_opt_paillier_impl",
srcs = glob([
@@ -9,5 +10,4 @@ cc_library(
deps = [
"@com_github_gmp//:gmp",
],
)
\ No newline at end of file
// Copyright [2021] <primihub.com>
#include "src/primihub/algorithm/plainML.h"
(diff collapsed)
// Copyright [2021] <primihub.com>
#include "src/primihub/util/network/socket/iobuffer.h"
namespace primihub {
}
\ No newline at end of file
#include "src/primihub/algorithm/regression.h"
(diff collapsed)
@@ -18,7 +18,6 @@ cc_library(
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/memory",
"@com_github_glog_glog//:glog",
"@toolkit_relic//:relic",
"//src/primihub/protos:worker_proto",
"//src/primihub/protos:service_proto",
"//src/primihub/common:common_lib",
@@ -325,7 +325,10 @@ retcode BuildFederatedRequest(const nlohmann::json& js_task_config, rpc::Task* t
continue;
}
auto dataset_ptr = (*party_dataset_ptr)[party_name].mutable_data();
(*dataset_ptr)["data_set"] = role_param["data_set"].get<std::string>();
std::string data_key = role_param["data_set"].get<std::string>();
if (!data_key.empty()) {
(*dataset_ptr)["data_set"] = data_key;
}
}
return retcode::SUCCESS;
}
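
This guard implements "check if data_set is empty before submit task" (#570): a party that carries no dataset binding (for example a pure aggregator) no longer gets an empty "data_set" entry. A hypothetical task-config fragment and the same filtering in nlohmann::json (the key names below are illustrative, not the real schema):

#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

int main() {
  auto js = nlohmann::json::parse(R"({
    "party_info": {
      "Alice":   {"data_set": "train_party_0"},
      "Arbiter": {"data_set": ""}
    }
  })");
  for (auto& [party, role_param] : js["party_info"].items()) {
    std::string data_key = role_param["data_set"].get<std::string>();
    if (!data_key.empty()) {  // skip parties with no dataset
      std::cout << party << " -> " << data_key << "\n";
    }
  }
}
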
(remaining collapsed file diffs omitted)