未验证 提交 84408c03 编写于 作者: PhoenixTree2013's avatar PhoenixTree2013 提交者: GitHub

Merge pull request #218 from primihub/develop

Release 1.5.0 PR
name: Backup Git repository
on:
workflow_dispatch:
push:
branches:
- develop
jobs:
BackupGit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: backup
uses: jenkins-zh/git-backup-actions@v0.0.7
env:
GIT_DEPLOY_KEY: ${{ secrets.GIT_DEPLOY_KEY }}
TARGET_GIT: "git@gitee.com:primihub/primihub.git"
BRANCH: develop
...@@ -28,18 +28,36 @@ RUN apt update \ ...@@ -28,18 +28,36 @@ RUN apt update \
# install bazelisk # install bazelisk
RUN npm install -g @bazel/bazelisk RUN npm install -g @bazel/bazelisk
# Install keyword PIR dependencies
WORKDIR /opt
RUN wget https://github.com/zeromq/libzmq/archive/refs/tags/v4.3.4.tar.gz \
&& tar -zxf v4.3.4.tar.gz && mkdir libzmq-4.3.4/build && cd libzmq-4.3.4/build \
&& cmake .. && make && make install
RUN wget https://github.com/zeromq/cppzmq/archive/refs/tags/v4.9.0.tar.gz \
&& tar -zxf v4.9.0.tar.gz && mkdir cppzmq-4.9.0/build && cd cppzmq-4.9.0/build \
&& cmake .. && make && make install
RUN wget https://github.com/google/flatbuffers/archive/refs/tags/v2.0.0.tar.gz \
&& tar -zxf v2.0.0.tar.gz && mkdir flatbuffers-2.0.0/build && cd flatbuffers-2.0.0/build \
&& cmake .. && make && make install
RUN wget https://sourceforge.net/projects/tclap/files/tclap-1.2.5.tar.gz \
&& tar -zxvf tclap-1.2.5.tar.gz && cd tclap-1.2.5 && ./configure \
&& make && make install
WORKDIR /src WORKDIR /src
ADD . /src ADD . /src
# Bazel build primihub-node & primihub-cli & paillier shared library # Bazel build primihub-node & primihub-cli & paillier shared library
RUN bash pre_build.sh \ RUN bash pre_build.sh \
&& bazel build --cxxopt=-D_AMD64_ --config=linux :node :cli :opt_paillier_c2py && bazel build --cxxopt=-D_AMD64_ --config=linux --define microsoft-apsi=true :node :cli :opt_paillier_c2py
FROM ubuntu:20.04 as runner FROM ubuntu:20.04 as runner
# Install python3 and GCC openmp (Depends with cryptFlow2 library) # Install python3 and GCC openmp (Depends with cryptFlow2 library)
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y python3 python3-dev libgomp1 python3-pip \ && apt-get install -y python3 python3-dev libgomp1 python3-pip libzmq5 \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
ARG TARGET_PATH=/root/.cache/bazel/_bazel_root/f8087e59fd95af1ae29e8fcb7ff1a3dc/execroot/primihub/bazel-out/k8-fastbuild/bin ARG TARGET_PATH=/root/.cache/bazel/_bazel_root/f8087e59fd95af1ae29e8fcb7ff1a3dc/execroot/primihub/bazel-out/k8-fastbuild/bin
......
...@@ -6,7 +6,7 @@ ENV DEBIAN_FRONTEND=noninteractive ...@@ -6,7 +6,7 @@ ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y python3 python3-dev libgmp-dev python3-pip git \ && apt-get install -y python3 python3-dev libgmp-dev python3-pip libzmq5 git \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
ARG TARGET_PATH=/root/.cache/bazel/_bazel_root/17a1cd4fb136f9bc7469e0db6305b35a/execroot/__main__/bazel-out/k8-fastbuild/bin ARG TARGET_PATH=/root/.cache/bazel/_bazel_root/17a1cd4fb136f9bc7469e0db6305b35a/execroot/__main__/bazel-out/k8-fastbuild/bin
...@@ -33,7 +33,7 @@ COPY ./src/primihub/protos/ ./src/primihub/protos/ ...@@ -33,7 +33,7 @@ COPY ./src/primihub/protos/ ./src/primihub/protos/
WORKDIR /app/python WORKDIR /app/python
RUN python3 -m pip install --upgrade pip \ RUN python3 -m pip install --upgrade pip \
&& python3 -m pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple \ && python3 -m pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple \
&& python3 setup.py develop \ && python3 setup.py develop \
&& python3 setup.py solib --solib-path $TARGET_PATH && python3 setup.py solib --solib-path $TARGET_PATH
......
...@@ -29,11 +29,12 @@ curl https://get.primihub.com/release/1.3.9/docker-compose.yml -s -o docker-comp ...@@ -29,11 +29,12 @@ curl https://get.primihub.com/release/1.3.9/docker-compose.yml -s -o docker-comp
Start three docker containers using docker-compose. Start three docker containers using docker-compose.
The container includes: one simple bootstrap node, three nodes The container includes: one simple bootstrap node, three nodes
```shell
docker-compose up -d docker-compose up -d
```
or, you could specific the container register and version, such as: or, you could specific the container register and version, such as:
```shell
REGISTRY=registry.cn-beijing.aliyuncs.com TAG=1.4.0 docker-compose up -d REGISTRY=registry.cn-beijing.aliyuncs.com TAG=1.4.0 docker-compose up -d
``` ```
......
...@@ -9,11 +9,14 @@ filegroup( ...@@ -9,11 +9,14 @@ filegroup(
cmake( cmake(
name = "APSI", name = "APSI",
env={ #env={
"HTTPS_PROXY": "http://127.0.0.1:7890", # "HTTPS_PROXY": "http://127.0.0.1:7890",
"HTTP_PROXY": "http://127.0.0.1:7890", # "HTTP_PROXY": "http://127.0.0.1:7890",
"https_proxy": "http://127.0.0.1:7890", # "https_proxy": "http://127.0.0.1:7890",
"http_proxy": "http://127.0.0.1:7890", # "http_proxy": "http://127.0.0.1:7890",
#},
cache_entries = {
"SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT": "OFF",
}, },
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
# cache_entries = { # cache_entries = {
......
...@@ -3,19 +3,22 @@ ...@@ -3,19 +3,22 @@
load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake") load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake")
filegroup( filegroup(
name = "src", name = "src",
srcs = glob(["**"]), srcs = glob(["**"]),
visibility = ["//visibility:public"] visibility = ["//visibility:public"]
) )
cmake( cmake(
name = "seal", name = "seal",
env={ #env={
"HTTPS_PROXY": "http://127.0.0.1:1080", # "HTTPS_PROXY": "http://127.0.0.1:1080",
"HTTP_PROXY": "http://127.0.0.1:1080", # "HTTP_PROXY": "http://127.0.0.1:1080",
"https_proxy": "http://127.0.0.1:1080", # "https_proxy": "http://127.0.0.1:1080",
"http_proxy": "http://127.0.0.1:1080", # "http_proxy": "http://127.0.0.1:1080",
}, #},
cache_entries = {
"SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT": "OFF",
},
defines = [ defines = [
"SEAL_USE_CXX17=ON", "SEAL_USE_CXX17=ON",
"SEAL_USE_INTRIN=ON", "SEAL_USE_INTRIN=ON",
......
...@@ -9,7 +9,7 @@ fi ...@@ -9,7 +9,7 @@ fi
bash pre_build.sh bash pre_build.sh
bazel build --config=linux :node :cli :opt_paillier_c2py bazel build --cxxopt=-D_AMD64_ --config=linux --define microsoft-apsi=true :node :cli :opt_paillier_c2py
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Build failed!!!" echo "Build failed!!!"
......
...@@ -42,6 +42,11 @@ datasets: ...@@ -42,6 +42,11 @@ datasets:
model: "csv" model: "csv"
source: "data/pir_server.csv" source: "data/pir_server.csv"
# keyword PIR test case dataset
- description: "keyword_pir_server_data"
model: "csv"
source: "data/keyword_pir_server.csv"
# PSI test case dataset # PSI test case dataset
- description: "psi_server_data" - description: "psi_server_data"
model: "csv" model: "csv"
...@@ -59,6 +64,12 @@ datasets: ...@@ -59,6 +64,12 @@ datasets:
- description: "test_dataset" - description: "test_dataset"
model: "csv" model: "csv"
source: "data/FL/wisconsin_test.data" source: "data/FL/wisconsin_test.data"
- description: "breast_1"
model: "csv"
source: "data/FL/wisconsin.data"
- description: "homo_lr_test"
model: "csv"
source: "data/FL/homo_lr_test.data"
localkv: localkv:
......
...@@ -8,33 +8,38 @@ datasets: ...@@ -8,33 +8,38 @@ datasets:
# ABY3 LR test case datasets # ABY3 LR test case datasets
- description: "train_party_1" - description: "train_party_1"
model: "csv" model: "csv"
source: "/tmp/train_party_1.csv" source: "/tmp/train_party_1.csv"
- description: "test_party_1" - description: "test_party_1"
model: "csv" model: "csv"
source: "/tmp/test_party_1.csv" source: "/tmp/test_party_1.csv"
# MNIST test case datasets # MNIST test case datasets
- description: "test_party_1_self" - description: "test_party_1_self"
model: "csv" model: "csv"
source: "/tmp/falcon/dataset/MNIST/input_1" source: "/tmp/falcon/dataset/MNIST/input_1"
- description: "test_party_1_next" - description: "test_party_1_next"
model: "csv" model: "csv"
source: "/tmp/falcon/dataset/MNIST/input_2" source: "/tmp/falcon/dataset/MNIST/input_2"
# PIR test case dataset # PIR test case dataset
- description: "pir_server_data" - description: "pir_server_data"
model: "csv" model: "csv"
source: "/tmp/pir_server.csv" source: "/tmp/pir_server.csv"
# keyword PIR test case dataset
- description: "keyword_pir_server_data"
model: "csv"
source: "/tmp/keyword_pir_server.csv"
# PSI test case dataset # PSI test case dataset
- description: "psi_server_data" - description: "psi_server_data"
model: "csv" model: "csv"
source: "/tmp/server_e.csv" source: "/tmp/server_e.csv"
# FL xgb test case datasets # FL xgb test case datasets
- description: "label_dataset" - description: "label_dataset"
model: "csv" model: "csv"
source: "/tmp/FL/wisconsin_host.data" source: "/tmp/FL/wisconsin_host.data"
- description: "test_dataset" - description: "test_dataset"
model: "csv" model: "csv"
source: "/tmp/FL/wisconsin_test.data" source: "/tmp/FL/wisconsin_test.data"
...@@ -45,7 +50,7 @@ datasets: ...@@ -45,7 +50,7 @@ datasets:
model: "csv" model: "csv"
source: "/tmp/FL/homo_lr_test.data" source: "/tmp/FL/homo_lr_test.data"
localkv: localkv:
model: "leveldb" model: "leveldb"
path: "/data/localdb1" path: "/data/localdb1"
...@@ -56,4 +61,5 @@ p2p: ...@@ -56,4 +61,5 @@ p2p:
multi_addr: "/ip4/172.28.1.11/tcp/8887" multi_addr: "/ip4/172.28.1.11/tcp/8887"
dht_get_value_timeout: 120 dht_get_value_timeout: 120
notify_server: 0.0.0.0:6667 notify_server: 0.0.0.0:6667
\ No newline at end of file
label1,label2
lLAnqwEihmGXxVPZZESncfgaaIZIhoPpMEmHPSFUoqUgUHBnMUddmTVwHfxEsqGg,SifXKvuBLDWNptdGIwpLEpcMqYlGQkALiJXABVpYsNmVMwSkUIWSbjDjChInhzwF
tsjVctWbHnsvMcnOVRKcHbTeVFifXdNAcCJYPCAtLRAMOmzqMGVEEWgcTOgoEwoc,jajADuiaxdQSRvfqmuKyyXOggleDMgGPTWjsYzIzpXjwNlcoXvGRigVLYaPZVvWY
GmbwOtczJMlNaFTlGhQVFEcYgfMcMOXLjbvcqUWJCkvFdEXRlCXylFcFToedcrdc,JmibcbboYpjyuShOyqYvjYbZunltYcKWOuVzxCRicEfbZZRTdeoCkvzqKxhRmdXM
IoqdwHwrBsNRrWngkYZiBzUMqWWjfkpiEXZcvDKhYGJvIibBxWgkufFymbzTzDJe,cVDnRiDfYjFrqRiLILJugXjYxCYTCjIqGTEMBLcKOsNyHZNSinBJiHZyISKwnCoB
VjBWFAmqrPKraFuJwuiFaXJXvLskePqSqVKVwumyfulYWJPNkwfgHVyISSxsBKBt,vlBtRulmBgompFuufjJMHxRLYDSWSxyFOlYgYoqMBWhyszwLaVRnwsoKuNKVyHOF
HXfUhjJCfMssfPIjhDBXeMyZFmfbIAYvijkSCsyqvoGsJwcFhZiYIYSpFDdTUxvG,oaNQNbHOYCfJpPxqNvNBRbpAbPibqXtlMUvGLIQhEaLktOjdRXlpjilcTdzbOIDD
yubIODoRnLsFPrQSpwAkZKxFoUIzBrbUJWSMlvBqtjPbLKRVuTxgUclEefUUcNDN,UEuVGElnzIvwigsKdRJVkmYxNCGqjCKEusvksuIFShpdEcemGWujQOwiFvsgBGGO
RXryiPrKLaQJNrLqwkRBhfKoalfzWWFQUjzITCiBJczKPcucKBpzqcXRBpyfbNSw,RpFTWbBAHXNQihMXezkGATVukNTeAsxFvBYuEwFLQzmnladOWrtpBAlslIAbYFen
eDmAByAWSXcbazPARJcxQUNPpQeZOwZIqrChHbeEVGXoJypEDwZmvGJZeMiqMipC,crtTjWmHxeFyRuMOrHHHoBCceGUJKbnhQMXOjqrIQXEPNEYmwPSiBvAOLoeItIsE
NTfMsymXWcKePZGHhrWssKblnkbOYziixjpYeofoEaTiclmcNqGdEGAtsXVcyjMt,cQIeVdQUfovclnKVhgdFxdNFfsLCisYgIKlXycEIXjmslvlOtGROMFkVcZfylSBK
kBhvqaanVZdoaPgrPLFJraHGKlkonygVWNOGPsLzfLIbOOhajevgWyibKTvrmkNj,epwxZEYjJtSnQEaprMRkOyZZTxekMKEXWvFZmWHJSQMiHobJWGUwJUNGmDbwjdyy
SWvjYXDRJCmyZdrvQdGCGQjjSTWtnvSFBHeHWewgyXVMnWtLXwuLngTRXipovvlo,krYfUkQoURITLRngnLWTRsLooDuNFCLbLswdwkagACbFTrmKamNIfYRSUJFyGMXe
eXSZIVckiSpbQekmbAhvTuuFQuQMqyQNuMFyWllbDGVMXhTampdxcAWARJgZPAmw,KFENkfPFWMDqOVuBzKrZktsxatqWBRylQEZSjtPCbbWghZwPpNnyZQoWrLjYbhqa
xutpaCDOGrpYOiZNEeZNUUuJZCuGRhXRpGhYoefWnjUoNEnUjaxedZktbIpcVzZg,oHJqpqDZHzrFzHpAohUZszYWllJvXijtQCYehmFrOjUwLPAEisgXKauAwYzviaVa
SXUiwPmLNohCROXPMZIqbnLrfhCtREPzCiDYZaDghlkfGBCTqyPdqnjoaWzyNzBT,qTtpUGIhnpCnsnrgTktWTgEGlariwydiHZZiYInWjxqHrYXPpxirZKQRaHTiQDTy
iWmeZawoNiYkvQXAUPEvTNPFVuqqfiKgJfunkQKvIemNCrjZmXmhvfLIJRiEuUkd,ChWkJFntjBPAOBSzyBGQiRcpGUTTmfmkRxmfrNRnvHlYjIASbZlLBOcrZubbgaYz
RdtLFTxlxrqGlohnSMfYculjIOgWagTnxGXpIYdGeFwxieQCEinNYapTNhicCcsA,BwvjWbBNBBkpmDDPQrtuLPrDWbfleviuwIFKLlmAqTAOMwMAEaCIwUBgHNhESSHO
RdVbZxUnQmBOaISAHewIxGiMserWZIwsUrNzoQBGqevyFDCoFKoyiFCrEpMxUtuQ,SXeTBUElBaAROKAzzdMfodmWQLxxwzIfnkRCtqRXpwbfKFcRGHgpXOyadCkOcodH
IfJPsaDSeYRWmwKBnOSRYXbrDZFoAOJeOmBbsWOwoiqELKummuoSJGZchZDECoLB,mbwtlADLEFJHjpIjWklrqCKbXXgSTRaOlCMsXLTvhIYnHKnKgNXaCLoXNfYlpGYq
sJAicHrwiiQKwAyhUpdNFTofiPmsOtsHnEuatdnlbNLdobvUaOsKxuKOBuZTSJjo,NVcUinUjKhVpyVsYeVogtnQYRJeFaBAoUgyWorNHqnYbRWysFpEUueggumLBONES
YpNhFNfrreTFHgbkbRSyyJZPJsenNRgrGEmSINyuzzdpCPJfsXqiGMtYLzebidzr,THuPvAlfFtWMbQkaZertVJPCWIJQNaYAORTLFONxnXVkVRSOdmxqFFrnrNwtnQUQ
uIzkWOiwASWsvNZTxoESJbfHCgPqDIJHEjaSRDIKZSCcwbmaMjMtywXXQEiFJCDB,jqrZyVUYISeuAIqlSZrSulAFlljcAmWvaxwlFUdOHywMyVioAicxwQbTZzsLvGex
iZdWVDpgPxRTVNGGLysBmxSNAKPHOvcvYntzXNAAHjtpAZAtlXoAHdEkULwdHTJm,sihCzBdpPTdNWrhljNwIguVAetQLBWTSIjuWYWbFQPWtjBUhKQkpbDZAHDqRVdhf
KLbZtoFzpWdewNEcEScBoQrRGXFGlFnzcSZBCBrgPNeEtHsZEKcikgAODVwDUYmJ,tyDLJzuFImYiNYrcqaCFuShtiXYpgCaOPnEZADCRYpPtebHMOpnrXABxmoEpyLSz
IGvEVefbEuKPEIEoRxBUhGaiJlbaQqoGtXhzLFdrKMdBpqAkJJOqNvomqPBvBKmL,KDhfQGgBrXDtbJOJFKvwuqMOgkIaMSMziPwAilPVLqoJTsbBbSjblTVEPlQBmbly
SJUAOMgVISbvIVYnGKcfBfxeVdAbGGktoVdOHYIcYaTsfRXZbEBWDObVtNzLOrWZ,CAzvsxZrtwSlLwVMXPPKoFGNIvHcPWhnkJxHESLnGIQEqlrunPDQmvSXkIrcQLoa
IUEZVRnmTKueoTKSrnQMoehJQWDBLmLqFiIaIqcVqxilwmOYllbZXhmFdRsCUHOD,BgIakKWAojTnTxRrvwzvHzDnUfJarWmIiMCDOznAPiOkQZMNLpxdSvbtRkqWIKkL
wiJRcQKXGuwqeqLrnjLYupYVsURcbSXYexHcxNUnRyHtVcIHBOikCYwpFqAsSRlL,jHCttxmbPHFNlhfGewKqWfoybSoIhoFpJphMQqdLOVvDkgEMDnyCZARBinTIyASh
JVuoSYkCRtBujTBwwkHEADtlTSfJoMPGTnGgerhmOOqckSNyVpXIRJqGAlpEFUvR,uvRvEbmLydLhlAxGLSKYuSWTqJOGfSMJhQywCecvBWcRUmhPjfKyyBxfrqiJQuod
rTpDlOHZmzDtidZgEzclzfwEakEjvRYMoiWCMnckLtXWTDwukHSjLBDVSLAvvssR,ZcqqLJUSYVRXizdrshlwkfTVpzmUrXTzquilLwAXsSFfxTNgEJRpodQLahXAzdYo
kwyOBZMqorvmRikqlzkHdlWQWZOlZDdvfhFeuUWkQMqWEqVYpsYxdDTxwvNwnbuy,VHnDORLXKhlrruTzTdfElvuLKRkKhMRnSGXvezGeCKwhKjUiyeEEzSsPmuKkCAjp
sVBccRvJypBLliCWqsonCQpymWzjQdPVAlAOOxfHMIHLyMaqGrAccBIHhyCMyucr,EpwnJGDamJUmUAIvWRKekIbEDbTGkuasOXRFQuDSXECCtkFPILWBHUeKOsbTCmdD
GDdfeQLBJTkRvSENhVLyuajxWoinZTaXVvvpuXtsLVaezncWRYmoERXTXhZrzdSy,tqfEYftKTxzOVROdpyvmOdlgPDZEuewfohklNQNHKlhbdDDzvhHwUKpzVWRmQoEP
gVUDEuBVxLaPtmmpQPvmIeOTDzvppblsGQvpirnJEAjVIEPuzulRWuSjikjBwOrp,CQdqBChnRQfFFtkUKRNXwMaTfwKUydarcCqfpvYahWAHdGQJEgEUpqvZjZhqoKxU
nWOEJocoUzMZZsqnfOVOwYYnOWovSglYfAZbHYwMrxPKOBTOiCikBmuwwvAMoMli,eyjEABNYoZiOScfXswwVmiNSGrzIcnebXQEYiETTMpDfTFQUdWbmvCmyBJkuSuGW
rWtPQVzObqycWSbOZtwQejEpfDhHdGXTNGdkvwDJkprFBGgMRyIxIniLZBhNpPbd,yfhEKbQcryFtTTJjLYhoFZylworfLqNewmxvDIhyVGtOysqaJtQxatsZfGxerEJB
TogpcosRbcgXzjhJVvwvtSWkHXAsXbolUTbuBzYucKEGLpDAZXakSghsAVBZuSkv,kjvMohEOPSWWpxnoOvVZHOpJJOtGUkYkBGwDoMEsUENlBedkCyZWboDBmwvBksNs
xZBYSaMxPHZgQNoxkpKGcoUyPDuPNrDZDeMFCAvCrzFiyNdKXRHOfjYBXfVzOcbg,lQuGFfIfahLUwmhdFfkdZxqoRnUMaiRHQMnFtINqgAVVwjhUquMybHZRmJVIlwUI
zosVlPrzbKTnUafihtOABliTsHQDGbWRJhdcEgRoZaGTJvUSfTunMsWNKJxHFUGw,RHrdqTxgsduUWyvYgXiQJgqNbnRMdonDgiBFZbbpbGCKNmpsjIkOIxkMrcvcEbIZ
IKwUYfLFtxOEzEoKrCyKIXroEhYLuQRzZLFMGhbbuPRBpiaugnVVihcIBdfTseUR,pyiuqnXvyzGkslCNVreqWnZUYzkXtdyBFDmnPhFAVPzzusIsdXzFvClbDbriVcdL
iyxSOOdhkZPhiYeSsUXZbNcMDvoVyJcOeEPgKNZiWTJQGipysRVBhCKUHiJZuKIv,pbUDdoYPYoTzkmiGBRqYEqGeYNZZBwlzgDbCoPSzZZYvxiChSvdSgvThZqYTawIN
AiRZePzswhhcNIiGPYSqxFTIbUdUepbIsMlCBwAsEGhRaVfAwnCwBBxVufAbdkaj,DLBGQOzjBDwuGtkOQEJSCLlXYNJKMjYsamIuRcqLZqEPYSknxLqbxjTOqmHtGpix
nLmrRIMTaYRaeMRoGHcTGumAJsgsGjNKvJkEYgEEspLHDUIaxUtqEdowjDWOjsJn,wCDBBveDmjXTRySpTMdQfYsoOgouVLBZkByuHQnLfQvnBpLpsEfLSyvbFfyCCSoY
MYhJLzFDzmDMwonNcHuxQLwLURzYyTqkmdAHSeywjJHWOkTNmMKiOMSMBFEXVksa,kRiDIDQpsowkqzefUqXlrblSFBeYiOWbdXFoOXHSSZxaOQNjshnffLPXzJCYMRMR
XpgRCQeYmCRVJGHVmcXbJHtqdDJUFSjysHGKukoFREEYifcAfpBbpgAuzAmibSpP,cYlRUVJhUtvNybVETieMjbIWwjatEJSTiSpeZSXIEUVSeAkdWpjAKpOvaHUSkiuU
MceaUZrnLzudOunKIEJRydcrjEdsBlgrnQabBybkhmDUzsracIjcPncDHyxzSton,gywlrOLjlaeosQUuVZcBCzqcggTGMGwLfhiGlxSEWAWRVxTVWYyBXvkqKqIpCFJl
vyYSnQwtdiYfYseKvBOdxIqKBbucJSxrNyfPUxQKoBjjWGktTwuTCqluNXDbkyMG,XnBvnTkToYYodCITlVVUFzCVxAbTUuUZAKXCRCbYUKElOJcKdGGIeNGXFkggvoJt
FhvmerDVIJWnsVsoIhFWVQbKSraHQiQUAvEPMzAWzHrbebIVVFnUJiLwHJXnPQcA,CLCrBRSVGQGeOUaQbBJCEqEVjQoexSgEtHlqeuQVtjKKXOeNBZPdbeunZmEmNfdm
UCLJwzuOSWbnIyfDWBrpxgIFFWQMYZnTTBFChcQSFnjclzhVhTistEdolFxdYEIM,zHWdLfKIEhgBlMWjDACTvwooQjPZAfcxlaYdVluvRONbcqQHkvXIpsABIxSLBYla
x0,x1,x2,x3,x4,x5,x6,x7,x8,y
5.0,1.0,1.0,1.0,2.0,1.0,3.0,1.0,1.0,0 5.0,1.0,1.0,1.0,2.0,1.0,3.0,1.0,1.0,0
5.0,4.0,4.0,5.0,7.0,10.0,3.0,2.0,1.0,0 5.0,4.0,4.0,5.0,7.0,10.0,3.0,2.0,1.0,0
3.0,1.0,1.0,1.0,2.0,2.0,3.0,1.0,1.0,0 3.0,1.0,1.0,1.0,2.0,2.0,3.0,1.0,1.0,0
......
x0,x1,x2,x3,x4,x5,x6,x7,x8,y
9.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,1.0,1 9.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,1.0,1
5.0,3.0,6.0,1.0,2.0,1.0,1.0,1.0,1.0,0 5.0,3.0,6.0,1.0,2.0,1.0,1.0,1.0,1.0,0
8.0,7.0,8.0,2.0,4.0,2.0,5.0,10.0,1.0,1 8.0,7.0,8.0,2.0,4.0,2.0,5.0,10.0,1.0,1
......
x0,x1,x2,x3,x4,x5,x6,x7,x8,y
7.0,8.0,8.0,7.0,3.0,10.0,7.0,2.0,3.0,1 7.0,8.0,8.0,7.0,3.0,10.0,7.0,2.0,3.0,1
1.0,1.0,1.0,1.0,2.0,1.0,1.0,1.0,1.0,0 1.0,1.0,1.0,1.0,2.0,1.0,1.0,1.0,1.0,0
1.0,1.0,1.0,1.0,2.0,1.0,2.0,1.0,1.0,0 1.0,1.0,1.0,1.0,2.0,1.0,2.0,1.0,1.0,0
......
5.0,1.0,1.0,1.0,2.0,1.0,3.0,1.0,1.0,0 x0,x1,x2,x3,x4,x5,x6,x7,x8,y
5.0,NA,1.0,1.0,2.0,1.0,3.0,1.0,1.0,0
5.0,4.0,4.0,5.0,7.0,10.0,3.0,2.0,1.0,0 5.0,4.0,4.0,5.0,7.0,10.0,3.0,2.0,1.0,0
3.0,1.0,1.0,1.0,2.0,2.0,3.0,1.0,1.0,0 3.0,1.0,1.0,1.0,2.0,2.0,3.0,1.0,1.0,0
6.0,8.0,8.0,1.0,3.0,4.0,3.0,7.0,1.0,0 6.0,8.0,8.0,1.0,3.0,4.0,3.0,7.0,1.0,0
......
x0,x1,x2,x3,x4,x5,x6,x7,x8,y
10.0,4.0,5.0,5.0,5.0,10.0,4.0,1.0,1.0,1 10.0,4.0,5.0,5.0,5.0,10.0,4.0,1.0,1.0,1
3.0,3.0,2.0,1.0,3.0,1.0,3.0,6.0,1.0,0 3.0,3.0,2.0,1.0,3.0,1.0,3.0,6.0,1.0,0
3.0,1.0,4.0,1.0,2.0,1.0,3.0,1.0,1.0,0 3.0,1.0,4.0,1.0,2.0,1.0,3.0,1.0,1.0,0
......
x0,x1,x2,x3,x4,x5,x6,x7,x8,y
10.0,6.0,6.0,2.0,4.0,10.0,9.0,7.0,1.0,1 10.0,6.0,6.0,2.0,4.0,10.0,9.0,7.0,1.0,1
6.0,6.0,6.0,5.0,4.0,10.0,7.0,6.0,2.0,1 6.0,6.0,6.0,5.0,4.0,10.0,7.0,6.0,2.0,1
4.0,1.0,1.0,1.0,2.0,1.0,1.0,1.0,1.0,0 4.0,1.0,1.0,1.0,2.0,1.0,1.0,1.0,1.0,0
......
PRIMIHUB_FUSION=primihub/primihub-fusion:latest
PRIMIHUB_PLATFORM=primihub/primihub-platform:latest
PRIMIHUB_WEB_MANAGE=primihub/primihub-web:latest
PRIMIHUB_NODE=primihub/primihub-node:latest
\ No newline at end of file
# README
### docker-compose部署
#### 部署要求
* 机器配置最低4核8G,磁盘40G
* 系统支持`CentOS 7``Ubuntu 18.04+` (推荐使用`Ubuntu`)
#### 执行deploy.sh脚本,完成部署
```bash
bash deploy.sh
```
#### 查看部署结果
```
# docker-compose ps -a
NAME COMMAND SERVICE STATUS PORTS
application1 "/bin/sh -c 'java -j…" application1 running
application2 "/bin/sh -c 'java -j…" application2 running
application3 "/bin/sh -c 'java -j…" application3 running
bootstrap-node "/app/simple-bootstr…" simple-bootstrap-node running 4001/tcp
fusion "/bin/sh -c 'java -j…" fusion running
gateway1 "/bin/sh -c 'java -j…" gateway1 running
gateway2 "/bin/sh -c 'java -j…" gateway2 running
gateway3 "/bin/sh -c 'java -j…" gateway3 running
mysql "docker-entrypoint.s…" mysql running 0.0.0.0:3306->3306/tcp, :::3306->3306/tcp
nacos-server "bin/docker-startup.…" nacos running 0.0.0.0:8848->8848/tcp, 0.0.0.0:9555->9555/tcp, 0.0.0.0:9848->9848/tcp, :::8848->8848/tcp, :::9555->9555/tcp, :::9848->9848/tcp
primihub-node0 "/bin/bash -c './pri…" node0 running 50050/tcp
primihub-node1 "/bin/bash -c './pri…" node1 running 50050/tcp
primihub-node2 "/bin/bash -c './pri…" node2 running 50050/tcp
primihub-web1 "/docker-entrypoint.…" nginx1 running 0.0.0.0:30811->80/tcp, :::30811->80/tcp
primihub-web2 "/docker-entrypoint.…" nginx2 running 0.0.0.0:30812->80/tcp, :::30812->80/tcp
primihub-web3 "/docker-entrypoint.…" nginx3 running 0.0.0.0:30813->80/tcp, :::30813->80/tcp
rabbitmq1 "docker-entrypoint.s…" rabbitmq1 running 25672/tcp
rabbitmq2 "docker-entrypoint.s…" rabbitmq2 running 25672/tcp
rabbitmq3 "docker-entrypoint.s…" rabbitmq3 running 25672/tcp
redis "docker-entrypoint.s…" redis running 6379/tcp
```
#### 说明
docker-compose.yaml 文件中的nginx1、nginx2、nginx3 模拟 3 个机构的管理后台,启动完成后在浏览器分别访问
http://机器IP:30811
http://机器IP:30812
http://机器IP:30813
默认用户密码都是 admin / 123456
具体的联邦建模、隐私求交、匿踪查询等功能的操作步骤请参考 [快速试用管理平台](https://docs.primihub.com/docs/quick-start-platform)
\ No newline at end of file
server {
listen 80;
server_name localhost;
location / {
root /usr/local/nginx/html;
index index.html index.htm;
try_files $uri $uri/ /index.html;
}
location ^~ /prod-api/ {
proxy_pass http://gateway1:8080/;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
location /fileimages {
root /data;
}
error_page 500 502 503 504 /50x.html;
location = /50x.html {
root html;
}
}
server {
listen 80;
server_name localhost;
location / {
root /usr/local/nginx/html;
index index.html index.htm;
try_files $uri $uri/ /index.html;
}
location ^~ /prod-api/ {
proxy_pass http://gateway2:8080/;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
location /fileimages {
root /data;
}
error_page 500 502 503 504 /50x.html;
location = /50x.html {
root html;
}
}
server {
listen 80;
server_name localhost;
location / {
root /usr/local/nginx/html;
index index.html index.htm;
try_files $uri $uri/ /index.html;
}
location ^~ /prod-api/ {
proxy_pass http://gateway3:8080/;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
location /fileimages {
root /data;
}
error_page 500 502 503 504 /50x.html;
location = /50x.html {
root html;
}
}
daemonize no
pidfile /data/redis.pid
port 6379
bind 0.0.0.0
requirepass primihub
rdbcompression yes
dbfilename dump.rdb
dir /data
MYSQL_ROOT_PASSWORD=root
MYSQL_DATABASE=nacos_config
MYSQL_USER=primihub
MYSQL_PASSWORD=primihub@123
\ No newline at end of file
PREFER_HOST_MODE=hostname
MODE=standalone
SPRING_DATASOURCE_PLATFORM=mysql
MYSQL_SERVICE_HOST=mysql
MYSQL_SERVICE_DB_NAME=nacos_config
MYSQL_SERVICE_PORT=3306
MYSQL_SERVICE_USER=primihub
MYSQL_SERVICE_PASSWORD=primihub@123
MYSQL_SERVICE_DB_PARAM=characterEncoding=utf8&connectTimeout=10000&socketTimeout=30000&autoReconnect=true&useSSL=false
\ No newline at end of file
此差异已折叠。
#!/bin/bash
# primihub deploy script
# First, install docker and docker-compose
centos(){
iptables -F
systemctl stop firewalld
systemctl disable firewalld
setenforce 0
sed -i s#SELINUX=enforcing#SELINUX=disabled# /etc/selinux/config
docker version
if [ $? -eq 0 ];
then
echo "docker installed"
else
wget -O /etc/yum.repos.d/docker-ce.repo https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo
yum -y install docker-ce device-mapper-persistent-data lvm2
systemctl daemon-reload
systemctl start docker && systemctl enable docker
echo "docker install succeed !"
fi
}
ubuntu(){
docker version
if [ $? -eq 0 ];
then
echo "docker installed"
else
apt-get update
apt-get -y install apt-transport-https ca-certificates curl software-properties-common
curl -fsSL https://mirrors.aliyun.com/docker-ce/linux/ubuntu/gpg | sudo apt-key add -
add-apt-repository "deb [arch=amd64] https://mirrors.aliyun.com/docker-ce/linux/ubuntu $(lsb_release -cs) stable"
apt-get -y update
apt-get -y install docker-ce
echo "docker install succeed !"
fi
}
if [ $(uname -s) == "Linux" ];
then
grep "Ubuntu" /etc/issue >> /dev/null
if [ $? -eq 0 ];
then
ubuntu
else
centos
fi
elif [ $(uname -s) == "Darwin" ]; then
which docker-compose > /dev/null
if [ $? != 0 ];
then
echo "Cannot find docker compose, please install it first."
echo "Read the official document from https://docs.docker.com/compose/install/"
exit 1
fi
else
echo "not support yet"
exit 1
fi
docker-compose version
if [ $? -eq 0 ];
then
echo "docker-compose installed"
else
curl -L "https://github.com/docker/compose/releases/download/v2.2.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/bin/docker-compose
if [ $? -eq 0 ];
then
chmod +x /usr/bin/docker-compose
echo "docker-compose install succeed !"
else
echo "Download docker-compose failed!"
exit
fi
fi
# Pull all the necessary images to avoid pulling multiple times
for i in `cat .env | cut -d '=' -f 2`
do
docker pull $i
done
# Finally, start the application
docker-compose up -d
version: '3'
services:
mysql:
image: nacos/nacos-mysql:5.7
container_name: mysql
restart: always
env_file:
- ./data/env/mysql.env
volumes:
- ./data/mysql:/var/lib/mysql
- ./data/initsql/:/docker-entrypoint-initdb.d/
ports:
- "3306:3306"
networks:
primihub_net:
ipv4_address: 172.28.1.14
redis:
image: redis:5
container_name: redis
restart: always
volumes:
- "./data:/data"
- "./config/redis.conf:/usr/local/etc/redis/redis.conf"
command:
"redis-server /usr/local/etc/redis/redis.conf"
depends_on:
- nacos
networks:
primihub_net:
ipv4_address: 172.28.1.15
rabbitmq1:
image: rabbitmq:3.6.15-management
container_name: rabbitmq1
restart: always
depends_on:
- redis
networks:
primihub_net:
ipv4_address: 172.28.1.16
rabbitmq2:
image: rabbitmq:3.6.15-management
container_name: rabbitmq2
restart: always
depends_on:
- redis
networks:
primihub_net:
ipv4_address: 172.28.1.17
rabbitmq3:
image: rabbitmq:3.6.15-management
container_name: rabbitmq3
restart: always
depends_on:
- redis
networks:
primihub_net:
ipv4_address: 172.28.1.18
nacos:
image: nacos/nacos-server:v2.0.4
container_name: nacos-server
restart: always
env_file:
- ./data/env/nacos-mysql.env
volumes:
- ./data/log/:/home/nacos/logs
#- ./nacos/init.d/custom.properties:/home/nacos/init.d/custom.properties
ports:
- "8848:8848"
- "9848:9848"
- "9555:9555"
depends_on:
- node0
networks:
primihub_net:
ipv4_address: 172.28.1.19
fusion:
image: $PRIMIHUB_FUSION
container_name: fusion
restart: always
entrypoint:
- "/bin/sh"
- "-c"
- "java -jar -Dfile.encoding=UTF-8 /applications/fusion.jar --spring.profiles.active=dc --server.port=8080"
# ports:
# - "8080:8080"
depends_on:
- gateway1
networks:
primihub_net:
ipv4_address: 172.28.1.20
application1:
image: $PRIMIHUB_PLATFORM
container_name: application1
restart: always
volumes:
- "./data:/data"
entrypoint:
- "/bin/sh"
- "-c"
- "java -jar -Dfile.encoding=UTF-8 /applications/application.jar --spring.profiles.active=dc1"
depends_on:
- rabbitmq1
networks:
primihub_net:
ipv4_address: 172.28.1.21
application2:
image: $PRIMIHUB_PLATFORM
container_name: application2
restart: always
volumes:
- "./data:/data"
entrypoint:
- "/bin/sh"
- "-c"
- "java -jar -Dfile.encoding=UTF-8 /applications/application.jar --spring.profiles.active=dc2"
depends_on:
- rabbitmq2
networks:
primihub_net:
ipv4_address: 172.28.1.22
application3:
image: $PRIMIHUB_PLATFORM
container_name: application3
restart: always
volumes:
- "./data:/data"
entrypoint:
- "/bin/sh"
- "-c"
- "java -jar -Dfile.encoding=UTF-8 /applications/application.jar --spring.profiles.active=dc3"
depends_on:
- rabbitmq3
networks:
primihub_net:
ipv4_address: 172.28.1.23
gateway1:
image: $PRIMIHUB_PLATFORM
container_name: gateway1
restart: always
entrypoint:
- "/bin/sh"
- "-c"
- "java -jar -Dfile.encoding=UTF-8 /applications/gateway.jar --spring.profiles.active=dc1 --server.port=8080"
depends_on:
- application1
networks:
primihub_net:
ipv4_address: 172.28.1.24
gateway2:
image: $PRIMIHUB_PLATFORM
container_name: gateway2
restart: always
entrypoint:
- "/bin/sh"
- "-c"
- "java -jar -Dfile.encoding=UTF-8 /applications/gateway.jar --spring.profiles.active=dc2 --server.port=8080"
depends_on:
- application2
networks:
primihub_net:
ipv4_address: 172.28.1.25
gateway3:
image: $PRIMIHUB_PLATFORM
container_name: gateway3
restart: always
entrypoint:
- "/bin/sh"
- "-c"
- "java -jar -Dfile.encoding=UTF-8 /applications/gateway.jar --spring.profiles.active=dc3 --server.port=8080"
depends_on:
- application3
networks:
primihub_net:
ipv4_address: 172.28.1.26
nginx1:
image: $PRIMIHUB_WEB_MANAGE
container_name: primihub-web1
restart: always
volumes:
- "./config/default1.conf:/etc/nginx/conf.d/default.conf"
- "./data:/data"
ports:
- "30811:80"
depends_on:
- gateway1
networks:
primihub_net:
ipv4_address: 172.28.1.27
nginx2:
image: $PRIMIHUB_WEB_MANAGE
container_name: primihub-web2
restart: always
volumes:
- "./config/default2.conf:/etc/nginx/conf.d/default.conf"
- "./data:/data"
ports:
- "30812:80"
depends_on:
- gateway2
networks:
primihub_net:
ipv4_address: 172.28.1.28
nginx3:
image: $PRIMIHUB_WEB_MANAGE
container_name: primihub-web3
restart: always
volumes:
- "./config/default3.conf:/etc/nginx/conf.d/default.conf"
- "./data:/data"
ports:
- "30813:80"
depends_on:
- gateway3
networks:
primihub_net:
ipv4_address: 172.28.1.29
node0:
image: $PRIMIHUB_NODE
container_name: primihub-node0
restart: "always"
ports:
- "50050:50050"
- "6666:6666"
# - "10120:12120"
# - "10121:12121"
volumes:
- ../../config:/app/config
- ./data:/data
entrypoint:
- "/bin/bash"
- "-c"
- "./primihub-node --service_port=50050 --node_id=node0 --config=/app/config/primihub_node0.yaml"
depends_on:
- simple-bootstrap-node
networks:
primihub_net:
ipv4_address: 172.28.1.10
node1:
image: $PRIMIHUB_NODE
container_name: primihub-node1
restart: "always"
ports:
- "50051:50051"
- "6667:6667"
# - "11120:12120"
# - "11121:12121"
volumes:
- ../../config:/app/config
- ./data:/data
entrypoint:
- "/bin/bash"
- "-c"
- "./primihub-node --service_port=50051 --node_id=node1 --config=/app/config/primihub_node1.yaml"
depends_on:
- simple-bootstrap-node
networks:
primihub_net:
ipv4_address: 172.28.1.11
node2:
image: $PRIMIHUB_NODE
container_name: primihub-node2
restart: "always"
ports:
- "50052:50052"
- "6668:6668"
# - "12120:12120"
# - "12121:12121"
volumes:
- ../../config:/app/config
- ./data:/data
entrypoint:
- "/bin/bash"
- "-c"
- "./primihub-node --service_port=50052 --node_id=node2 --config=/app/config/primihub_node2.yaml"
depends_on:
- simple-bootstrap-node
networks:
primihub_net:
ipv4_address: 172.28.1.12
simple-bootstrap-node:
image: primihub/simple-bootstrap-node:1.0
container_name: bootstrap-node
restart: "always"
# ports:
# - "4001:4001"
entrypoint:
- "/app/simple-bootstrap-node"
depends_on:
- mysql
networks:
primihub_net:
ipv4_address: 172.28.1.13
networks:
primihub_net:
driver: bridge
ipam:
config:
- subnet: 172.28.0.0/16
gateway: 172.28.0.1
\ No newline at end of file
"""
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https: //www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import primihub as ph
from primihub.client import primihub_cli as cli
from primihub import dataset, context
from primihub.channel.zmq_channel import IOService, Session
from primihub.FL.model.xgboost.xgb_guest import XGB_GUEST
from primihub.FL.model.xgboost.xgb_host import XGB_HOST
from primihub.FL.model.evaluation.evaluation import Regression_eva
from primihub.FL.model.evaluation.evaluation import Classification_eva
import pandas as pd
import numpy as np
import pickle
import logging
def get_logger(name):
LOG_FORMAT = "[%(asctime)s][%(filename)s:%(lineno)d][%(levelname)s] %(message)s"
DATE_FORMAT = "%m/%d/%Y %H:%M:%S %p"
logging.basicConfig(level=logging.DEBUG,
format=LOG_FORMAT, datefmt=DATE_FORMAT)
logger = logging.getLogger(name)
return logger
logger = get_logger("hetero_xgb")
# client init
# cli.init(config={"node": "127.0.0.1:50050", "cert": ""})
# cli.init(config={"node": "192.168.99.23:8050", "cert": ""})
cli.init(config={"node": "192.168.99.26:50050", "cert": ""})
# Number of tree to fit.
num_tree = 1
# Max depth of each tree.
max_depth = 1
dataset.define("guest_dataset")
dataset.define("label_dataset")
ph.context.Context.func_params_map = {
"xgb_host_logic": ("paillier",),
"xgb_guest_logic": ("paillier",)
}
@ph.context.function(role='host', protocol='xgboost', datasets=['label_dataset'], port='8001', task_type="regression")
def xgb_host_logic(cry_pri="paillier"):
    """Host-side hetero-XGBoost training entry point.

    Reads the labelled dataset, validates that every cell is numeric,
    splits rows 80/20 into train/test, trains ``num_tree`` boosting
    rounds with the guest party over a ZMQ channel, then writes the
    model, lookup tables, predictions, and evaluation indicators to the
    paths supplied by the task context.

    Args:
        cry_pri: privacy mode — "paillier" (homomorphic encryption) or
            "plaintext".
    """
    logger.info("start xgb host logic...")
    logger.info(ph.context.Context.dataset_map)
    logger.info(ph.context.Context.node_addr_map)
    logger.info(ph.context.Context.role_nodeid_map)
    logger.info(ph.context.Context.params_map)
    # Resolve the evaluation type; default to regression when unset.
    eva_type = ph.context.Context.params_map.get("taskType", None)
    if eva_type is None:
        logger.info(
            "taskType is not specified, set to default value 'regression'.")
        eva_type = "regression"
    eva_type = eva_type.lower()
    if eva_type != "classification" and eva_type != "regression":
        logger.error("Invalid value of taskType, possible value is 'regression', 'classification'.")
        return
    logger.info("Current task type is {}.".format(eva_type))
    # Reject the dataset if any cell fails to parse as a number.
    data = ph.dataset.read(dataset_key="label_dataset").df_data
    columns_label_data = data.columns.tolist()
    for index, row in data.iterrows():
        for name in columns_label_data:
            temp = row[name]
            try:
                float(temp)
            except ValueError:
                logger.error(
                    "Find illegal string '{}', it's not a digit string.".format(temp))
                return
    # Get host's ip address.
    role_node_map = ph.context.Context.get_role_node_map()
    node_addr_map = ph.context.Context.get_node_addr_map()
    if len(role_node_map["host"]) != 1:
        logger.error("Current node of host party: {}".format(
            role_node_map["host"]))
        logger.error("In hetero XGB, only dataset of host party has label, "
                     "so host party must have one, make sure it.")
        return
    host_node = role_node_map["host"][0]
    next_peer = node_addr_map[host_node]
    ip, port = next_peer.split(":")
    ios = IOService()
    server = Session(ios, ip, port, "server")
    channel = server.addChannel()
    # 80/20 train/test split. BUGFIX: the split point must be an integer;
    # the former "dim / 10 * 8" produced a float, which is not a valid
    # label for pandas .loc slicing on the default integer index.
    dim = data.shape[0]
    dim_train = dim * 8 // 10
    data_train = data.loc[:dim_train, :].reset_index(drop=True)
    data_test = data.loc[dim_train:dim, :].reset_index(drop=True)
    # NOTE(review): the label column name 'Class' is hard-coded — confirm
    # it matches the registered dataset schema.
    label_true = ['Class']
    y_true = data_test['Class'].values
    data_test = data_test[
        [x for x in data_test.columns if x not in label_true]
    ]
    logger.info(data_test.head())
    labels = ['Class']  # noqa
    X_host = data_train[
        [x for x in data.columns if x not in labels]
    ]
    Y = data_train['Class'].values
    if cry_pri == "paillier":
        from primihub.primitive.opt_paillier_c2py_warpper import opt_paillier_encrypt_crt, opt_paillier_decrypt_crt
        from primihub.FL.model.xgboost.xgb_guest_en import XGB_GUEST_EN
        from primihub.FL.model.xgboost.xgb_host_en import XGB_HOST_EN
        xgb_host = XGB_HOST_EN(n_estimators=num_tree, max_depth=max_depth, reg_lambda=1,
                               sid=0, min_child_weight=1, objective='linear', channel=channel)
        # Handshake: wait for guest, publish the Paillier public key.
        channel.recv()
        xgb_host.channel.send(xgb_host.pub)
        print(xgb_host.channel.recv())
        y_hat = np.array([0.5] * Y.shape[0])
        for t in range(xgb_host.n_estimators):
            logger.info("Begin to train tree {}.".format(t + 1))
            xgb_host.record = 0
            xgb_host.lookup_table = pd.DataFrame(
                columns=['record_id', 'feature_id', 'threshold_value'])
            f_t = pd.Series([0] * Y.shape[0])
            gh = xgb_host.get_gh(y_hat, Y)
            # Encrypt per-sample gradients/hessians before sending to guest.
            gh_en = pd.DataFrame(columns=['g', 'h'])
            for item in gh.columns:
                for index in gh.index:
                    gh_en.loc[index, item] = opt_paillier_encrypt_crt(xgb_host.pub, xgb_host.prv,
                                                                      int(gh.loc[index, item]))
            logger.info("Encrypt finish.")
            xgb_host.channel.send(gh_en)
            GH_guest_en = xgb_host.channel.recv()
            # Decrypt the guest's aggregated split statistics.
            GH_guest = pd.DataFrame(
                columns=['G_left', 'G_right', 'H_left', 'H_right', 'var', 'cut'])
            for item in [x for x in GH_guest_en.columns if x not in ['cut', 'var']]:
                for index in GH_guest_en.index:
                    if GH_guest_en.loc[index, item] == 0:
                        GH_guest.loc[index, item] = 0
                    else:
                        GH_guest.loc[index, item] = opt_paillier_decrypt_crt(xgb_host.pub, xgb_host.prv,
                                                                             GH_guest_en.loc[index, item])
            logger.info("Decrypt finish.")
            for item in [x for x in GH_guest_en.columns if x not in ['G_left', 'G_right', 'H_left', 'H_right']]:
                for index in GH_guest_en.index:
                    GH_guest.loc[index, item] = GH_guest_en.loc[index, item]
            xgb_host.tree_structure[t + 1], f_t = xgb_host.xgb_tree(X_host, GH_guest, gh, f_t, 0)  # noqa
            xgb_host.lookup_table_sum[t + 1] = xgb_host.lookup_table
            y_hat = y_hat + xgb_host.learning_rate * f_t
            logger.info("Finish to train tree {}.".format(t + 1))
        predict_file_path = ph.context.Context.get_predict_file_path()
        indicator_file_path = ph.context.Context.get_indicator_file_path()
        model_file_path = ph.context.Context.get_model_file_path()
        lookup_file_path = ph.context.Context.get_host_lookup_file_path()
        # Persist model and lookup tables for later inference.
        with open(model_file_path, 'wb') as fm:
            pickle.dump(xgb_host.tree_structure, fm)
        with open(lookup_file_path, 'wb') as fl:
            pickle.dump(xgb_host.lookup_table_sum, fl)
        y_pre = xgb_host.predict_prob(data_test)
        y_train_pre = xgb_host.predict_prob(X_host)
        y_train_pre.to_csv(predict_file_path)
        y_train_true = Y
        Y_true = {"train": y_train_true, "test": y_true}
        Y_pre = {"train": y_train_pre, "test": y_pre}
        if eva_type == 'regression':
            Regression_eva.get_result(Y_true, Y_pre, indicator_file_path)
        elif eva_type == 'classification':
            Classification_eva.get_result(Y_true, Y_pre, indicator_file_path)
    elif cry_pri == "plaintext":
        xgb_host = XGB_HOST(n_estimators=num_tree, max_depth=max_depth, reg_lambda=1,
                            sid=0, min_child_weight=1, objective='linear', channel=channel)
        channel.recv()
        y_hat = np.array([0.5] * Y.shape[0])
        for t in range(xgb_host.n_estimators):
            # Log 1-based tree index, consistent with the paillier branch.
            logger.info("Begin to train tree {}.".format(t + 1))
            xgb_host.record = 0
            xgb_host.lookup_table = pd.DataFrame(
                columns=['record_id', 'feature_id', 'threshold_value'])
            f_t = pd.Series([0] * Y.shape[0])
            gh = xgb_host.get_gh(y_hat, Y)
            xgb_host.channel.send(gh)
            GH_guest = xgb_host.channel.recv()
            xgb_host.tree_structure[t + 1], f_t = xgb_host.xgb_tree(X_host, GH_guest, gh, f_t, 0)  # noqa
            xgb_host.lookup_table_sum[t + 1] = xgb_host.lookup_table
            y_hat = y_hat + xgb_host.learning_rate * f_t
            logger.info("Finish to train tree {}.".format(t + 1))
        predict_file_path = ph.context.Context.get_predict_file_path()
        indicator_file_path = ph.context.Context.get_indicator_file_path()
        model_file_path = ph.context.Context.get_model_file_path()
        lookup_file_path = ph.context.Context.get_host_lookup_file_path()
        with open(model_file_path, 'wb') as fm:
            pickle.dump(xgb_host.tree_structure, fm)
        with open(lookup_file_path, 'wb') as fl:
            pickle.dump(xgb_host.lookup_table_sum, fl)
        y_pre = xgb_host.predict_prob(data_test)
        if eva_type == 'regression':
            Regression_eva.get_result(y_true, y_pre, indicator_file_path)
        elif eva_type == 'classification':
            Classification_eva.get_result(y_true, y_pre, indicator_file_path)
        xgb_host.predict_prob(data_test).to_csv(predict_file_path)
@ph.context.function(role='guest', protocol='xgboost', datasets=['guest_dataset'], port='9002', task_type="regression")
def xgb_guest_logic(cry_pri="paillier"):
    """Guest-side hetero-XGBoost training entry point.

    Reads the (unlabelled) guest dataset, validates it is numeric,
    splits rows 80/20 into train/test, connects to the host party over
    a ZMQ channel, and participates in ``num_tree`` boosting rounds,
    persisting this party's lookup tables afterwards.

    Args:
        cry_pri: privacy mode — "paillier" (homomorphic encryption) or
            "plaintext".
    """
    logger.info("start xgb guest logic...")
    ios = IOService()
    logger.info(ph.context.Context.dataset_map)
    logger.info(ph.context.Context.node_addr_map)
    logger.info(ph.context.Context.role_nodeid_map)
    # Resolve the evaluation type; default to regression when unset.
    eva_type = ph.context.Context.params_map.get("taskType", None)
    if eva_type is None:
        logger.info(
            "taskType is not specified, set to default value 'regression'.")
        eva_type = "regression"
    eva_type = eva_type.lower()
    if eva_type != "classification" and eva_type != "regression":
        logger.error("Invalid value of taskType, possible value is 'regression', 'classification'.")
        return
    logger.info("Current task type is {}.".format(eva_type))
    # Check dataset.
    data = ph.dataset.read(dataset_key="guest_dataset").df_data
    columns_label_data = data.columns.tolist()
    for index, row in data.iterrows():
        for name in columns_label_data:
            temp = row[name]
            try:
                float(temp)
            except ValueError:
                logger.error(
                    "Find illegal string '{}', it's not a digit string.".format(temp))
                return
    # Get host's ip address.
    role_node_map = ph.context.Context.get_role_node_map()
    node_addr_map = ph.context.Context.get_node_addr_map()
    if len(role_node_map["host"]) != 1:
        logger.error("Current node of host party: {}".format(
            role_node_map["host"]))
        logger.error("In hetero XGB, only dataset of host party has label,"
                     "so host party must have one, make sure it.")
        return
    host_node = role_node_map["host"][0]
    next_peer = node_addr_map[host_node]
    ip, port = next_peer.split(":")
    client = Session(ios, ip, port, "client")
    channel = client.addChannel()
    # 80/20 train/test split. BUGFIX: the split point must be an integer;
    # the former "dim / 10 * 8" produced a float, which is not a valid
    # label for pandas .loc slicing on the default integer index.
    dim = data.shape[0]
    dim_train = dim * 8 // 10
    X_guest = data.loc[:dim_train, :].reset_index(drop=True)
    data_test = data.loc[dim_train:dim, :].reset_index(drop=True)
    if cry_pri == "paillier":
        from primihub.primitive.opt_paillier_c2py_warpper import opt_paillier_encrypt_crt, opt_paillier_decrypt_crt
        from primihub.FL.model.xgboost.xgb_guest_en import XGB_GUEST_EN
        from primihub.FL.model.xgboost.xgb_host_en import XGB_HOST_EN
        xgb_guest = XGB_GUEST_EN(n_estimators=num_tree, max_depth=max_depth, reg_lambda=1, min_child_weight=1,
                                 objective='linear',
                                 sid=1, channel=channel)  # noqa
        # Handshake: announce readiness, receive the host's public key.
        channel.send(b'guest ready')
        pub = xgb_guest.channel.recv()
        xgb_guest.channel.send(b'recved pub')
        for t in range(xgb_guest.n_estimators):
            xgb_guest.record = 0
            xgb_guest.lookup_table = pd.DataFrame(
                columns=['record_id', 'feature_id', 'threshold_value'])
            gh_host = xgb_guest.channel.recv()
            X_guest_gh = pd.concat([X_guest, gh_host], axis=1)
            print(X_guest_gh)
            # Aggregate encrypted split statistics and return them to host.
            gh_sum = xgb_guest.get_GH(X_guest_gh, pub)
            xgb_guest.channel.send(gh_sum)
            xgb_guest.cart_tree(X_guest_gh, 0, pub)
            xgb_guest.lookup_table_sum[t + 1] = xgb_guest.lookup_table
        lookup_file_path = ph.context.Context.get_guest_lookup_file_path()
        with open(lookup_file_path, 'wb') as fl:
            pickle.dump(xgb_guest.lookup_table_sum, fl)
        xgb_guest.predict(data_test)
        xgb_guest.predict(X_guest)
    elif cry_pri == "plaintext":
        xgb_guest = XGB_GUEST(n_estimators=num_tree, max_depth=max_depth, reg_lambda=1, min_child_weight=1,
                              objective='linear',
                              sid=1, channel=channel)  # noqa
        channel.send(b'guest ready')
        for t in range(xgb_guest.n_estimators):
            xgb_guest.record = 0
            xgb_guest.lookup_table = pd.DataFrame(
                columns=['record_id', 'feature_id', 'threshold_value'])
            gh_host = xgb_guest.channel.recv()
            X_guest_gh = pd.concat([X_guest, gh_host], axis=1)
            print(X_guest_gh)
            gh_sum = xgb_guest.get_GH(X_guest_gh)
            xgb_guest.channel.send(gh_sum)
            xgb_guest.cart_tree(X_guest_gh, 0)
            xgb_guest.lookup_table_sum[t + 1] = xgb_guest.lookup_table
        lookup_file_path = ph.context.Context.get_guest_lookup_file_path()
        with open(lookup_file_path, 'wb') as fl:
            pickle.dump(xgb_guest.lookup_table_sum, fl)
        xgb_guest.predict(data_test)
        xgb_guest.predict(X_guest)
# context
# Privacy mode shared by both parties; switch to "paillier" to run the
# homomorphically-encrypted protocol instead of plaintext exchange.
cry_pri = "plaintext"
# cry_pri = "paillier"
# Submit the host and guest entry points for asynchronous remote execution.
cli.async_remote_execute((xgb_host_logic, cry_pri), (xgb_guest_logic, cry_pri))
# cli.start()
"""
Copyright 2022 Primihub
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https: //www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import primihub as ph
from primihub import context
from primihub.FL.model.logistic_regression.homo_lr import run_homo_lr_host, run_homo_lr_guest, run_homo_lr_arbiter
from primihub.client import primihub_cli as cli
# client init
# cli.init(config={"node": "127.0.0.1:50050", "cert": ""})
# cli.init(config={"node": "192.168.99.23:8050", "cert": ""})
# NOTE(review): hard-coded cluster address — confirm it matches the target deployment.
cli.init(config={"node": "192.168.99.26:50050", "cert": ""})
def get_logger(name):
    """Create a DEBUG-level logger using the project-wide message layout."""
    # basicConfig only takes effect the first time the root logger is set up.
    logging.basicConfig(
        level=logging.DEBUG,
        format="[%(asctime)s][%(filename)s:%(lineno)d][%(levelname)s] %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S %p",
    )
    return logging.getLogger(name)
# arbiter_info, guest_info, host_info, task_type, task_params = load_info()
# Extra task parameters; currently unused placeholder.
task_params = {}
logger = get_logger("Homo-LR")
@ph.context.function(role='arbiter', protocol='lr', datasets=['breast_0'], port='9010', task_type="lr-train")
def run_arbiter_party():
    """Arbiter entry point: dump context maps, then run homo-LR arbiter logic."""
    nodes_by_role = ph.context.Context.get_role_node_map()
    addr_by_node = ph.context.Context.get_node_addr_map()
    datasets = ph.context.Context.dataset_map
    # The arbiter reads whichever dataset key was registered first.
    data_key = list(datasets.keys())[0]

    logger.debug(
        "role_nodeid_map {}".format(nodes_by_role))
    logger.debug(
        "dataset_map {}".format(datasets))
    logger.debug(
        "node_addr_map {}".format(addr_by_node))

    run_homo_lr_arbiter(nodes_by_role, addr_by_node, data_key)
    logger.info("Finish homo-LR arbiter logic.")
@ph.context.function(role='host', protocol='lr', datasets=['breast_1'], port='9020', task_type="lr-train")
def run_host_party():
    """Host entry point: dump context maps, then run homo-LR host logic."""
    nodes_by_role = ph.context.Context.get_role_node_map()
    addr_by_node = ph.context.Context.get_node_addr_map()
    datasets = ph.context.Context.dataset_map

    logger.debug(
        "dataset_map {}".format(datasets))
    # The host trains on whichever dataset key was registered first.
    data_key = list(datasets.keys())[0]
    logger.debug(
        "role_nodeid_map {}".format(nodes_by_role))
    logger.debug(
        "node_addr_map {}".format(addr_by_node))

    logger.info("Start homo-LR host logic.")
    run_homo_lr_host(nodes_by_role, addr_by_node, data_key)
    logger.info("Finish homo-LR host logic.")
@ph.context.function(role='guest', protocol='lr', datasets=['breast_2'], port='9030', task_type="lr-train")
def run_guest_party():
    """Guest entry point: dump context maps, then run homo-LR guest logic."""
    nodes_by_role = ph.context.Context.get_role_node_map()
    addr_by_node = ph.context.Context.get_node_addr_map()
    datasets = ph.context.Context.dataset_map

    logger.debug(
        "dataset_map {}".format(datasets))
    # The guest trains on whichever dataset key was registered first.
    data_key = list(datasets.keys())[0]
    logger.debug(
        "role_nodeid_map {}".format(nodes_by_role))
    logger.debug(
        "node_addr_map {}".format(addr_by_node))

    logger.info("Start homo-LR guest logic.")
    # NOTE(review): keyword name 'datakey' differs from the positional
    # 'data_key' used by the host/arbiter calls — confirm against the
    # run_homo_lr_guest signature.
    run_homo_lr_guest(nodes_by_role, addr_by_node, datakey=data_key)
    logger.info("Finish homo-LR guest logic.")
# Submit the host and guest entry points for asynchronous remote execution.
# (The arbiter party is expected to be scheduled by the cluster itself.)
cli.async_remote_execute((run_host_party, ), (run_guest_party, ))
# cli.start()
#include "src/primihub/algorithm/arithmetic.h" #include "src/primihub/algorithm/arithmetic.h"
#include "src/primihub/data_store/csv/csv_driver.h"
#include "src/primihub/data_store/factory.h"
#include <arrow/api.h> #include <arrow/api.h>
#include <arrow/array.h> #include <arrow/array.h>
#include <arrow/result.h> #include <arrow/result.h>
#include "src/primihub/data_store/csv/csv_driver.h"
#include "src/primihub/data_store/factory.h"
using arrow::Array; using arrow::Array;
using arrow::DoubleArray; using arrow::DoubleArray;
using arrow::Int64Array; using arrow::Int64Array;
...@@ -25,13 +27,14 @@ void spiltStr(string str, const string &split, std::vector<string> &strlist) { ...@@ -25,13 +27,14 @@ void spiltStr(string str, const string &split, std::vector<string> &strlist) {
} }
} }
ArithmeticExecutor::ArithmeticExecutor( template <Decimal Dbit>
ArithmeticExecutor<Dbit>::ArithmeticExecutor(
PartyConfig &config, std::shared_ptr<DatasetService> dataset_service) PartyConfig &config, std::shared_ptr<DatasetService> dataset_service)
: AlgorithmBase(dataset_service) { : AlgorithmBase(dataset_service) {
this->algorithm_name_ = "arithmetic"; this->algorithm_name_ = "arithmetic";
std::map<std::string, Node> &node_map = config.node_map; std::map<std::string, Node> &node_map = config.node_map;
LOG(INFO) << node_map.size(); // LOG(INFO) << node_map.size();
std::map<uint16_t, rpc::Node> party_id_node_map; std::map<uint16_t, rpc::Node> party_id_node_map;
for (auto iter = node_map.begin(); iter != node_map.end(); iter++) { for (auto iter = node_map.begin(); iter != node_map.end(); iter++) {
rpc::Node &node = iter->second; rpc::Node &node = iter->second;
...@@ -91,7 +94,8 @@ ArithmeticExecutor::ArithmeticExecutor( ...@@ -91,7 +94,8 @@ ArithmeticExecutor::ArithmeticExecutor(
} }
} }
int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { template <Decimal Dbit>
int ArithmeticExecutor<Dbit>::loadParams(primihub::rpc::Task &task) {
auto param_map = task.params().param_map(); auto param_map = task.params().param_map();
try { try {
data_file_path_ = param_map["Data_File"].value_string(); data_file_path_ = param_map["Data_File"].value_string();
...@@ -104,7 +108,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { ...@@ -104,7 +108,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) {
std::string col = itr->substr(0, pos); std::string col = itr->substr(0, pos);
int owner = std::atoi((itr->substr(pos + 1, itr->size())).c_str()); int owner = std::atoi((itr->substr(pos + 1, itr->size())).c_str());
col_and_owner_.insert(make_pair(col, owner)); col_and_owner_.insert(make_pair(col, owner));
LOG(INFO) << col << ":" << owner; // LOG(INFO) << col << ":" << owner;
} }
// LOG(INFO) << col_and_owner; // LOG(INFO) << col_and_owner;
...@@ -115,7 +119,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { ...@@ -115,7 +119,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) {
std::string col = itr->substr(0, pos); std::string col = itr->substr(0, pos);
int dtype = std::atoi((itr->substr(pos + 1, itr->size())).c_str()); int dtype = std::atoi((itr->substr(pos + 1, itr->size())).c_str());
col_and_dtype_.insert(make_pair(col, dtype)); col_and_dtype_.insert(make_pair(col, dtype));
LOG(INFO) << col << ":" << dtype; // LOG(INFO) << col << ":" << dtype;
} }
// LOG(INFO) << col_and_dtype; // LOG(INFO) << col_and_dtype;
...@@ -138,17 +142,17 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { ...@@ -138,17 +142,17 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) {
} }
mpc_op_exec_ = new MPCOperator(party_id_, next_name, prev_name); mpc_op_exec_ = new MPCOperator(party_id_, next_name, prev_name);
} else { } else {
mpc_exec_ = new MPCExpressExecutor(); mpc_exec_ = new MPCExpressExecutor<Dbit>();
} }
LOG(INFO) << expr_; // LOG(INFO) << expr_;
std::string parties = param_map["Parties"].value_string(); std::string parties = param_map["Parties"].value_string();
spiltStr(parties, ";", tmp3); spiltStr(parties, ";", tmp3);
for (auto itr = tmp3.begin(); itr != tmp3.end(); itr++) { for (auto itr = tmp3.begin(); itr != tmp3.end(); itr++) {
uint32_t party = std::atoi((*itr).c_str()); uint32_t party = std::atoi((*itr).c_str());
parties_.push_back(party); parties_.push_back(party);
LOG(INFO) << party; // LOG(INFO) << party;
} }
LOG(INFO) << parties; // LOG(INFO) << parties;
res_name_ = param_map["ResFileName"].value_string(); res_name_ = param_map["ResFileName"].value_string();
} catch (std::exception &e) { } catch (std::exception &e) {
...@@ -159,7 +163,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) { ...@@ -159,7 +163,7 @@ int ArithmeticExecutor::loadParams(primihub::rpc::Task &task) {
return 0; return 0;
} }
int ArithmeticExecutor::loadDataset() { template <Decimal Dbit> int ArithmeticExecutor<Dbit>::loadDataset() {
int ret = _LoadDatasetFromCSV(data_file_path_); int ret = _LoadDatasetFromCSV(data_file_path_);
// file reading error or file empty // file reading error or file empty
if (ret <= 0) { if (ret <= 0) {
...@@ -188,7 +192,7 @@ int ArithmeticExecutor::loadDataset() { ...@@ -188,7 +192,7 @@ int ArithmeticExecutor::loadDataset() {
return 0; return 0;
} }
int ArithmeticExecutor::initPartyComm(void) { template <Decimal Dbit> int ArithmeticExecutor<Dbit>::initPartyComm(void) {
if (is_cmp) { if (is_cmp) {
mpc_op_exec_->setup(next_ip_, prev_ip_, next_port_, prev_port_); mpc_op_exec_->setup(next_ip_, prev_ip_, next_port_, prev_port_);
return 0; return 0;
...@@ -199,14 +203,11 @@ int ArithmeticExecutor::initPartyComm(void) { ...@@ -199,14 +203,11 @@ int ArithmeticExecutor::initPartyComm(void) {
return 0; return 0;
} }
int ArithmeticExecutor::execute() { template <Decimal Dbit> int ArithmeticExecutor<Dbit>::execute() {
if (is_cmp) { if (is_cmp) {
try { try {
sbMatrix sh_res; sbMatrix sh_res;
f64Matrix<D16> m; f64Matrix<Dbit> m;
LOG(INFO) << expr_;
LOG(INFO) << expr_.substr(6, 1);
LOG(INFO) << expr_.substr(4, 1);
if (col_and_owner_[expr_.substr(4, 1)] == party_id_) { if (col_and_owner_[expr_.substr(4, 1)] == party_id_) {
m.resize(1, col_and_val_double[expr_.substr(4, 1)].size()); m.resize(1, col_and_val_double[expr_.substr(4, 1)].size());
for (size_t i = 0; i < col_and_val_double[expr_.substr(4, 1)].size(); for (size_t i = 0; i < col_and_val_double[expr_.substr(4, 1)].size();
...@@ -222,13 +223,14 @@ int ArithmeticExecutor::execute() { ...@@ -222,13 +223,14 @@ int ArithmeticExecutor::execute() {
} else } else
mpc_op_exec_->MPC_Compare(sh_res); mpc_op_exec_->MPC_Compare(sh_res);
// reveal // reveal
if (party_id_ == 0) { for (const auto& party : parties_) {
i64Matrix tmp; if (party_id_ == party) {
tmp = mpc_op_exec_->reveal(sh_res); i64Matrix tmp = mpc_op_exec_->reveal(sh_res);
for (size_t i = 0; i < tmp.rows(); i++) for (size_t i = 0; i < tmp.rows(); i++)
cmp_res_.emplace_back(static_cast<bool>(tmp(i, 0))); cmp_res_.emplace_back(static_cast<bool>(tmp(i, 0)));
} else { } else {
mpc_op_exec_->reveal(sh_res, 0); mpc_op_exec_->reveal(sh_res, party);
}
} }
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "In party " << party_id_ << ":\n" << e.what() << "."; LOG(ERROR) << "In party " << party_id_ << ":\n" << e.what() << ".";
...@@ -239,14 +241,16 @@ int ArithmeticExecutor::execute() { ...@@ -239,14 +241,16 @@ int ArithmeticExecutor::execute() {
mpc_exec_->runMPCEvaluate(); mpc_exec_->runMPCEvaluate();
if (mpc_exec_->isFP64RunMode()) { if (mpc_exec_->isFP64RunMode()) {
mpc_exec_->revealMPCResult(parties_, final_val_double_); mpc_exec_->revealMPCResult(parties_, final_val_double_);
for (auto itr = final_val_double_.begin(); itr != final_val_double_.end(); // for (auto itr = final_val_double_.begin(); itr !=
itr++) // final_val_double_.end();
LOG(INFO) << *itr; // itr++)
// LOG(INFO) << *itr;
} else { } else {
mpc_exec_->revealMPCResult(parties_, final_val_int64_); mpc_exec_->revealMPCResult(parties_, final_val_int64_);
for (auto itr = final_val_int64_.begin(); itr != final_val_int64_.end(); // for (auto itr = final_val_int64_.begin(); itr !=
itr++) // final_val_int64_.end();
LOG(INFO) << *itr; // itr++)
// LOG(INFO) << *itr;
} }
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::string msg = "In party 0, "; std::string msg = "In party 0, ";
...@@ -256,7 +260,7 @@ int ArithmeticExecutor::execute() { ...@@ -256,7 +260,7 @@ int ArithmeticExecutor::execute() {
return 0; return 0;
} }
int ArithmeticExecutor::finishPartyComm(void) { template <Decimal Dbit> int ArithmeticExecutor<Dbit>::finishPartyComm(void) {
if (is_cmp) { if (is_cmp) {
mpc_op_exec_->fini(); mpc_op_exec_->fini();
delete mpc_op_exec_; delete mpc_op_exec_;
...@@ -266,7 +270,17 @@ int ArithmeticExecutor::finishPartyComm(void) { ...@@ -266,7 +270,17 @@ int ArithmeticExecutor::finishPartyComm(void) {
return 0; return 0;
} }
int ArithmeticExecutor::saveModel(void) { template <Decimal Dbit> int ArithmeticExecutor<Dbit>::saveModel(void) {
bool is_reveal = false;
for (auto party : parties_) {
if (party == party_id_) {
is_reveal = true;
break;
}
}
if (!is_reveal) {
return 0;
}
arrow::MemoryPool *pool = arrow::default_memory_pool(); arrow::MemoryPool *pool = arrow::default_memory_pool();
arrow::DoubleBuilder builder(pool); arrow::DoubleBuilder builder(pool);
if (final_val_double_.size() != 0) if (final_val_double_.size() != 0)
...@@ -318,7 +332,8 @@ int ArithmeticExecutor::saveModel(void) { ...@@ -318,7 +332,8 @@ int ArithmeticExecutor::saveModel(void) {
return 0; return 0;
} }
int ArithmeticExecutor::_LoadDatasetFromCSV(std::string &filename) { template <Decimal Dbit>
int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &filename) {
std::string nodeaddr("test address"); // TODO std::string nodeaddr("test address"); // TODO
std::shared_ptr<DataDriver> driver = std::shared_ptr<DataDriver> driver =
DataDirverFactory::getDriver("CSV", nodeaddr); DataDirverFactory::getDriver("CSV", nodeaddr);
...@@ -333,68 +348,109 @@ int ArithmeticExecutor::_LoadDatasetFromCSV(std::string &filename) { ...@@ -333,68 +348,109 @@ int ArithmeticExecutor::_LoadDatasetFromCSV(std::string &filename) {
// } // }
bool errors = false; bool errors = false;
int num_col = table->num_columns(); int num_col = table->num_columns();
// 'array' include values in a column of csv file. // 'array' include values in a column of csv file.
auto array = std::static_pointer_cast<DoubleArray>( int chunk_num = table->column(num_col - 1)->chunks().size();
table->column(num_col - 1)->chunk(0)); int64_t array_len = 0;
int64_t array_len = array->length(); for (int k = 0; k < chunk_num; k++) {
auto array = std::static_pointer_cast<DoubleArray>(
table->column(num_col - 1)->chunk(k));
array_len += array->length();
}
LOG(INFO) << "Label column '" << col_names[num_col - 1] << "' has " LOG(INFO) << "Label column '" << col_names[num_col - 1] << "' has "
<< array_len << " values."; << array_len << " values.";
// Force the same value count in every column. // Force the same value count in every column.
for (int i = 0; i < num_col; i++) { for (int i = 0; i < num_col; i++) {
int chunk_num = table->column(i)->chunks().size();
if (col_and_dtype_[col_names[i]] == 0) { if (col_and_dtype_[col_names[i]] == 0) {
auto array = if (table->schema()->GetFieldByName(col_names[i])->type()->id() != 9) {
std::static_pointer_cast<Int64Array>(table->column(i)->chunk(0)); LOG(ERROR) << "Local data type is inconsistent with the demand data "
"type!Demand data type is int,but local data type is "
"double!Please input consistent data type!";
return -1;
}
std::vector<int64_t> tmp_data; std::vector<int64_t> tmp_data;
for (int64_t j = 0; j < array->length(); j++) { int64_t tmp_len = 0;
tmp_data.push_back(array->Value(j)); for (int k = 0; k < chunk_num; k++) {
LOG(INFO) << array->Value(j); auto array =
std::static_pointer_cast<Int64Array>(table->column(i)->chunk(k));
tmp_len += array->length();
for (int64_t j = 0; j < array->length(); j++) {
tmp_data.push_back(array->Value(j));
// LOG(INFO) << array->Value(j);
}
} }
if (array->length() != array_len) { if (tmp_len != array_len) {
LOG(ERROR) << "Column " << col_names[i] << " has " << array->length() LOG(ERROR) << "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len << " value."; << " value, but other column has " << array_len << " value.";
errors = true; errors = true;
break; break;
} }
col_and_val_int.insert( col_and_val_int.insert(
pair<string, std::vector<int64_t>>(col_names[i], tmp_data)); pair<string, std::vector<int64_t>>(col_names[i], tmp_data));
for (auto itr = col_and_val_int.begin(); itr != col_and_val_int.end(); // for (auto itr = col_and_val_int.begin(); itr != col_and_val_int.end();
itr++) { // itr++) {
LOG(INFO) << itr->first; // LOG(INFO) << itr->first;
auto tmp_vec = itr->second; // auto tmp_vec = itr->second;
for (auto iter = tmp_vec.begin(); iter != tmp_vec.end(); iter++) // for (auto iter = tmp_vec.begin(); iter != tmp_vec.end(); iter++)
LOG(INFO) << *iter; // LOG(INFO) << *iter;
} // }
} else { } else {
auto array =
std::static_pointer_cast<DoubleArray>(table->column(i)->chunk(0));
std::vector<double> tmp_data; std::vector<double> tmp_data;
for (int64_t j = 0; j < array->length(); j++) { int64_t tmp_len = 0;
tmp_data.push_back(array->Value(j)); if (table->schema()->GetFieldByName(col_names[i])->type()->id() == 9) {
LOG(INFO) << array->Value(j); for (int k = 0; k < chunk_num; k++) {
} auto array =
if (array->length() != array_len) { std::static_pointer_cast<Int64Array>(table->column(i)->chunk(k));
LOG(ERROR) << "Column " << col_names[i] << " has " << array->length() tmp_len += array->length();
<< " value, but other column has " << array_len << " value."; for (int64_t j = 0; j < array->length(); j++) {
errors = true; tmp_data.push_back(array->Value(j));
break; // LOG(INFO) << array->Value(j);
}
}
if (tmp_len != array_len) {
LOG(ERROR) << "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len
<< " value.";
errors = true;
break;
}
} else {
for (int k = 0; k < chunk_num; k++) {
auto array =
std::static_pointer_cast<DoubleArray>(table->column(i)->chunk(k));
tmp_len += array->length();
for (int64_t j = 0; j < array->length(); j++) {
tmp_data.push_back(array->Value(j));
// LOG(INFO) << array->Value(j);
}
}
if (tmp_len != array_len) {
LOG(ERROR) << "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len
<< " value.";
errors = true;
break;
}
} }
col_and_val_double.insert( col_and_val_double.insert(
pair<string, std::vector<double>>(col_names[i], tmp_data)); pair<string, std::vector<double>>(col_names[i], tmp_data));
for (auto itr = col_and_val_double.begin(); // for (auto itr = col_and_val_double.begin();
itr != col_and_val_double.end(); itr++) { // itr != col_and_val_double.end(); itr++) {
LOG(INFO) << itr->first; // LOG(INFO) << itr->first;
auto tmp_vec = itr->second; // auto tmp_vec = itr->second;
for (auto iter = tmp_vec.begin(); iter != tmp_vec.end(); iter++) // for (auto iter = tmp_vec.begin(); iter != tmp_vec.end(); iter++)
LOG(INFO) << *iter; // LOG(INFO) << *iter;
} // }
} }
} }
if (errors) if (errors)
return -1; return -1;
return array->length(); return array_len;
} }
template class ArithmeticExecutor<D32>;
template class ArithmeticExecutor<D16>;
} // namespace primihub } // namespace primihub
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include "src/primihub/data_store/driver.h" #include "src/primihub/data_store/driver.h"
namespace primihub { namespace primihub {
template <Decimal Dbit>
class ArithmeticExecutor : public AlgorithmBase { class ArithmeticExecutor : public AlgorithmBase {
public: public:
explicit ArithmeticExecutor(PartyConfig &config, explicit ArithmeticExecutor(PartyConfig &config,
...@@ -35,7 +37,7 @@ private: ...@@ -35,7 +37,7 @@ private:
int _LoadDatasetFromCSV(std::string &filename); int _LoadDatasetFromCSV(std::string &filename);
bool is_cmp; bool is_cmp;
MPCExpressExecutor *mpc_exec_; MPCExpressExecutor<Dbit> *mpc_exec_;
MPCOperator *mpc_op_exec_; MPCOperator *mpc_op_exec_;
std::string res_name_; std::string res_name_;
uint16_t local_id_; uint16_t local_id_;
......
...@@ -20,15 +20,15 @@ ...@@ -20,15 +20,15 @@
#include <glog/logging.h> #include <glog/logging.h>
#include "src/primihub/algorithm/logistic.h" #include "src/primihub/algorithm/logistic.h"
#include "src/primihub/data_store/dataset.h"
#include "src/primihub/data_store/factory.h" #include "src/primihub/data_store/factory.h"
#include "src/primihub/service/dataset/model.h" #include "src/primihub/service/dataset/model.h"
#include "src/primihub/data_store/dataset.h"
using namespace std; using namespace std;
using namespace Eigen; using namespace Eigen;
using arrow::Array; using arrow::Array;
using arrow::DoubleArray; using arrow::DoubleArray;
using arrow::Int64Array;
using arrow::Table; using arrow::Table;
namespace primihub { namespace primihub {
...@@ -160,32 +160,27 @@ LogisticRegressionExecutor::LogisticRegressionExecutor( ...@@ -160,32 +160,27 @@ LogisticRegressionExecutor::LogisticRegressionExecutor(
model_name_ = ss.str(); model_name_ = ss.str();
} }
int LogisticRegressionExecutor::loadParams(primihub::rpc::Task &task){ int LogisticRegressionExecutor::loadParams(primihub::rpc::Task &task) {
auto param_map = task.params().param_map(); auto param_map = task.params().param_map();
try{ try {
train_input_filepath_ = param_map["TrainData"].value_string(); train_input_filepath_ = param_map["Data_File"].value_string();
test_input_filepath_ = param_map["TestData"].value_string(); // test_input_filepath_ = param_map["TestData"].value_string();
batch_size_ = param_map["BatchSize"].value_int32(); batch_size_ = param_map["BatchSize"].value_int32();
num_iter_ = param_map["NumIters"].value_int32(); num_iter_ = param_map["NumIters"].value_int32();
model_file_name_ = param_map["modelName"].value_string(); model_file_name_ = param_map["modelName"].value_string();
if(model_file_name_ == "") if (model_file_name_ == "")
model_file_name_="./" + model_name_ + ".csv"; model_file_name_ = "./" + model_name_ + ".csv";
} } catch (std::exception &e) {
catch (std::exception &e)
{
LOG(ERROR) << "Failed to load params: " << e.what(); LOG(ERROR) << "Failed to load params: " << e.what();
return -1; return -1;
} }
LOG(INFO) << "Train data " << train_input_filepath_ << ", test data " LOG(INFO) << "Train data " << train_input_filepath_ << ", test data "
<< test_input_filepath_ << "."; << test_input_filepath_ << ".";
return 0; return 0;
} }
int LogisticRegressionExecutor::_LoadDatasetFromCSV(std::string &filename) {
int LogisticRegressionExecutor::_LoadDatasetFromCSV(std::string &filename,
eMatrix<double> &m) {
std::string nodeaddr("test address"); // TODO std::string nodeaddr("test address"); // TODO
std::shared_ptr<DataDriver> driver = std::shared_ptr<DataDriver> driver =
DataDirverFactory::getDriver("CSV", nodeaddr); DataDirverFactory::getDriver("CSV", nodeaddr);
...@@ -219,36 +214,54 @@ int LogisticRegressionExecutor::_LoadDatasetFromCSV(std::string &filename, ...@@ -219,36 +214,54 @@ int LogisticRegressionExecutor::_LoadDatasetFromCSV(std::string &filename,
if (errors) if (errors)
return -1; return -1;
int64_t train_length = floor(array_len * 0.8);
m.resize(array_len, num_col); int64_t test_length = array_len - train_length;
for (int i = 0; i < num_col - 1; i++) { // LOG(INFO)<<"array_len: "<<array_len;
auto array = // LOG(INFO)<<"train_length: "<<train_length;
std::static_pointer_cast<DoubleArray>(table->column(i)->chunk(0)); // LOG(INFO)<<"test_length: "<<test_length;
for (int64_t j = 0; j < array->length(); j++) train_input_.resize(train_length, num_col);
m(j, i) = array->Value(j); test_input_.resize(test_length, num_col);
} // m.resize(array_len, num_col);
auto array_lastCol = std::static_pointer_cast<arrow::Int64Array>( for (int i = 0; i < num_col; i++) {
table->column(num_col - 1)->chunk(0)); if (table->schema()->GetFieldByName(col_names[i])->type()->id() == 9) {
for (int64_t j = 0; j < array_lastCol->length(); j++) { auto array =
m(j, num_col - 1) = array_lastCol->Value(j); std::static_pointer_cast<Int64Array>(table->column(i)->chunk(0));
for (int64_t j = 0; j < array->length(); j++) {
if (j < train_length)
train_input_(j, i) = array->Value(j);
else
test_input_(j - train_length, i) = array->Value(j);
// m(j, i) = array->Value(j);
}
} else {
auto array =
std::static_pointer_cast<DoubleArray>(table->column(i)->chunk(0));
for (int64_t j = 0; j < array->length(); j++) {
if (j < train_length)
train_input_(j, i) = array->Value(j);
else
test_input_(j - train_length, i) = array->Value(j);
// m(j, i) = array->Value(j);
}
}
} }
return array->length(); return array->length();
} }
int LogisticRegressionExecutor::loadDataset() { int LogisticRegressionExecutor::loadDataset() {
int ret = _LoadDatasetFromCSV(train_input_filepath_, train_input_); int ret = _LoadDatasetFromCSV(train_input_filepath_);
// file reading error or file empty // file reading error or file empty
if (ret <= 0) { if (ret <= 0) {
LOG(ERROR) << "Load dataset for train failed."; LOG(ERROR) << "Load dataset failed.";
return -1; return -1;
} }
ret = _LoadDatasetFromCSV(test_input_filepath_, test_input_); // ret = _LoadDatasetFromCSV(test_input_filepath_, test_input_);
// file reading error or file empty // // file reading error or file empty
if (ret <= 0) { // if (ret <= 0) {
LOG(ERROR) << "Load dataset for test failed."; // LOG(ERROR) << "Load dataset for test failed.";
return -2; // return -2;
} // }
if (train_input_.cols() != test_input_.cols()) { if (train_input_.cols() != test_input_.cols()) {
LOG(ERROR) LOG(ERROR)
...@@ -536,7 +549,7 @@ int LogisticRegressionExecutor::execute() { ...@@ -536,7 +549,7 @@ int LogisticRegressionExecutor::execute() {
return 0; return 0;
} }
int LogisticRegressionExecutor::saveModel(void){ int LogisticRegressionExecutor::saveModel(void) {
arrow::MemoryPool *pool = arrow::default_memory_pool(); arrow::MemoryPool *pool = arrow::default_memory_pool();
arrow::DoubleBuilder builder(pool); arrow::DoubleBuilder builder(pool);
...@@ -550,15 +563,14 @@ int LogisticRegressionExecutor::saveModel(void){ ...@@ -550,15 +563,14 @@ int LogisticRegressionExecutor::saveModel(void){
arrow::field("w", arrow::float64())}; arrow::field("w", arrow::float64())};
auto schema = std::make_shared<arrow::Schema>(schema_vector); auto schema = std::make_shared<arrow::Schema>(schema_vector);
std::shared_ptr<arrow::Table> table = arrow::Table::Make(schema, {array}); std::shared_ptr<arrow::Table> table = arrow::Table::Make(schema, {array});
std::shared_ptr<DataDriver> driver = std::shared_ptr<DataDriver> driver =
DataDirverFactory::getDriver("CSV", dataset_service_->getNodeletAddr()); DataDirverFactory::getDriver("CSV", dataset_service_->getNodeletAddr());
auto cursor = driver->initCursor(model_file_name_); auto cursor = driver->initCursor(model_file_name_);
auto dataset = std::make_shared<primihub::Dataset>(table, driver); auto dataset = std::make_shared<primihub::Dataset>(table, driver);
int ret = cursor->write(dataset); int ret = cursor->write(dataset);
if (ret != 0) if (ret != 0) {
{
LOG(ERROR) << "Save LR model to file " << model_file_name_ << " failed."; LOG(ERROR) << "Save LR model to file " << model_file_name_ << " failed.";
return -1; return -1;
} }
......
...@@ -68,7 +68,7 @@ private: ...@@ -68,7 +68,7 @@ private:
sf64Matrix<D> &train_label, sf64Matrix<D> &test_data, sf64Matrix<D> &train_label, sf64Matrix<D> &test_data,
sf64Matrix<D> &test_label); sf64Matrix<D> &test_label);
int _LoadDatasetFromCSV(std::string &filename, eMatrix<double> &m); int _LoadDatasetFromCSV(std::string &filename);
std::string model_file_name_; std::string model_file_name_;
std::string model_name_; std::string model_name_;
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
#include <arrow/api.h> #include <arrow/api.h>
#include <arrow/array.h> #include <arrow/array.h>
#include <arrow/csv/api.h>
#include <arrow/csv/writer.h>
#include <arrow/filesystem/localfs.h>
#include <arrow/io/api.h> #include <arrow/io/api.h>
#include <arrow/io/file.h> #include <arrow/io/file.h>
#include <arrow/result.h> #include <arrow/result.h>
...@@ -19,7 +22,7 @@ ...@@ -19,7 +22,7 @@
#include "src/primihub/data_store/dataset.h" #include "src/primihub/data_store/dataset.h"
#include "src/primihub/data_store/driver.h" #include "src/primihub/data_store/driver.h"
#include "src/primihub/data_store/factory.h" #include "src/primihub/data_store/factory.h"
#include <arrow/pretty_print.h>
using arrow::Array; using arrow::Array;
using arrow::DoubleArray; using arrow::DoubleArray;
using arrow::Int64Array; using arrow::Int64Array;
...@@ -210,39 +213,43 @@ int MissingProcess::execute() { ...@@ -210,39 +213,43 @@ int MissingProcess::execute() {
std::find(local_col_names.begin(), local_col_names.end(), itr->first); std::find(local_col_names.begin(), local_col_names.end(), itr->first);
double double_sum = 0; double double_sum = 0;
i64 int_sum = 0; i64 int_sum = 0;
int null_num = 0;
if (t != local_col_names.end()) { if (t != local_col_names.end()) {
int tmp_index = std::distance(local_col_names.begin(), t); int tmp_index = std::distance(local_col_names.begin(), t);
if (itr->second == 1) { if (itr->second == 1) {
auto array = std::static_pointer_cast<Int64Array>( int chunk_num = table->column(tmp_index)->chunks().size();
table->column(tmp_index)->chunk(0)); for (int k = 0; k < chunk_num; k++) {
for (int64_t j = 0; j < array->length(); j++) {
int_sum += array->Value(j);
}
auto tmp_array = table->column(tmp_index)->chunk(0);
int_sum =
int_sum / (array->length() - tmp_array->data()->GetNullCount());
} else if (itr->second == 2) {
// check schema
if (table->schema()->GetFieldByName(itr->first)->type()->id() == 9) {
auto array = std::static_pointer_cast<Int64Array>( auto array = std::static_pointer_cast<Int64Array>(
table->column(tmp_index)->chunk(0)); table->column(tmp_index)->chunk(k));
null_num +=
table->column(tmp_index)->chunk(k)->data()->GetNullCount();
for (int64_t j = 0; j < array->length(); j++) { for (int64_t j = 0; j < array->length(); j++) {
double_sum += array->Value(j); int_sum += array->Value(j);
} }
auto tmp_array = table->column(tmp_index)->chunk(0); }
double_sum = double_sum / int_sum = int_sum / (table->num_rows() - null_num);
(array->length() - tmp_array->data()->GetNullCount()); } else if (itr->second == 2) {
} else { // check schema
auto array = std::static_pointer_cast<DoubleArray>( int chunk_num = table->column(tmp_index)->chunks().size();
table->column(tmp_index)->chunk(0)); for (int k = 0; k < chunk_num; k++) {
for (int64_t j = 0; j < array->length(); j++) { if (table->schema()->GetFieldByName(itr->first)->type()->id() ==
double_sum += array->Value(j); 9) {
auto array = std::static_pointer_cast<Int64Array>(
table->column(tmp_index)->chunk(k));
for (int64_t j = 0; j < array->length(); j++) {
double_sum += array->Value(j);
}
} else {
auto array = std::static_pointer_cast<DoubleArray>(
table->column(tmp_index)->chunk(k));
for (int64_t j = 0; j < array->length(); j++) {
double_sum += array->Value(j);
}
} }
auto tmp_array = table->column(tmp_index)->chunk(0); null_num +=
double_sum = double_sum / table->column(tmp_index)->chunk(k)->data()->GetNullCount();
(array->length() - tmp_array->data()->GetNullCount());
} }
double_sum = double_sum / (table->num_rows() - null_num);
} }
} }
if (itr->second == 1) { if (itr->second == 1) {
...@@ -255,20 +262,28 @@ int MissingProcess::execute() { ...@@ -255,20 +262,28 @@ int MissingProcess::execute() {
new_sum = new_sum / 3; new_sum = new_sum / 3;
if (t != local_col_names.end()) { if (t != local_col_names.end()) {
int tmp_index = std::distance(local_col_names.begin(), t); int tmp_index = std::distance(local_col_names.begin(), t);
auto csv_array = std::static_pointer_cast<Int64Array>( int chunk_num = table->column(tmp_index)->chunks().size();
table->column(tmp_index)->chunk(0));
std::vector<int> null_index;
auto tmp_array = table->column(tmp_index)->chunk(0);
for (int i = 0; i < tmp_array->length(); i++) {
if (tmp_array->IsNull(i))
null_index.push_back(i);
}
std::vector<i64> new_col; std::vector<i64> new_col;
for (int64_t i = 0; i < csv_array->length(); i++) { for (int k = 0; k < chunk_num; k++) {
new_col.push_back(csv_array->Value(i));
} auto csv_array = std::static_pointer_cast<Int64Array>(
for (auto itr = null_index.begin(); itr != null_index.end(); itr++) { table->column(tmp_index)->chunk(k));
new_col[*itr] = new_sum; std::vector<int> null_index;
auto tmp_array = table->column(tmp_index)->chunk(k);
for (int i = 0; i < tmp_array->length(); i++) {
if (tmp_array->IsNull(i))
null_index.push_back(i);
}
std::vector<i64> tmp_new_col;
for (int64_t i = 0; i < csv_array->length(); i++) {
tmp_new_col.push_back(csv_array->Value(i));
}
for (auto itr = null_index.begin(); itr != null_index.end();
itr++) {
tmp_new_col[*itr] = new_sum;
}
new_col.insert(new_col.end(), tmp_new_col.begin(),
tmp_new_col.end());
} }
arrow::Int64Builder int64_builder; arrow::Int64Builder int64_builder;
int64_builder.AppendValues(new_col); int64_builder.AppendValues(new_col);
...@@ -292,32 +307,41 @@ int MissingProcess::execute() { ...@@ -292,32 +307,41 @@ int MissingProcess::execute() {
new_sum = new_sum / 3; new_sum = new_sum / 3;
if (t != local_col_names.end()) { if (t != local_col_names.end()) {
std::vector<double> new_col;
int tmp_index = std::distance(local_col_names.begin(), t); int tmp_index = std::distance(local_col_names.begin(), t);
std::vector<double> new_col;
if (table->schema()->GetFieldByName(itr->first)->type()->id() == 9) { int chunk_num = table->column(tmp_index)->chunks().size();
auto csv_array = std::static_pointer_cast<Int64Array>( for (int k = 0; k < chunk_num; k++) {
table->column(tmp_index)->chunk(0)); std::vector<double> tmp_new_col;
for (int64_t j = 0; j < csv_array->length(); j++) { if (table->schema()->GetFieldByName(itr->first)->type()->id() ==
new_col.push_back(csv_array->Value(j)); 9) {
auto csv_array = std::static_pointer_cast<Int64Array>(
table->column(tmp_index)->chunk(k));
for (int64_t j = 0; j < csv_array->length(); j++) {
tmp_new_col.push_back(csv_array->Value(j));
}
} else {
auto csv_array = std::static_pointer_cast<DoubleArray>(
table->column(tmp_index)->chunk(k));
for (int64_t i = 0; i < csv_array->length(); i++) {
tmp_new_col.push_back(csv_array->Value(i));
}
} }
} else { std::vector<int> null_index;
auto csv_array = std::static_pointer_cast<DoubleArray>( auto tmp_array = table->column(tmp_index)->chunk(k);
table->column(tmp_index)->chunk(0)); for (int i = 0; i < tmp_array->length(); i++) {
if (tmp_array->IsNull(i))
for (int64_t i = 0; i < csv_array->length(); i++) { null_index.push_back(i);
new_col.push_back(csv_array->Value(i));
} }
} for (auto itr = null_index.begin(); itr != null_index.end();
std::vector<int> null_index; itr++) {
auto tmp_array = table->column(tmp_index)->chunk(0); tmp_new_col[*itr] = new_sum;
for (int i = 0; i < tmp_array->length(); i++) { }
if (tmp_array->IsNull(i)) new_col.insert(new_col.end(), tmp_new_col.begin(),
null_index.push_back(i); tmp_new_col.end());
}
for (auto itr = null_index.begin(); itr != null_index.end(); itr++) {
new_col[*itr] = new_sum;
} }
arrow::DoubleBuilder double_builder; arrow::DoubleBuilder double_builder;
double_builder.AppendValues(new_col); double_builder.AppendValues(new_col);
...@@ -336,7 +360,7 @@ int MissingProcess::execute() { ...@@ -336,7 +360,7 @@ int MissingProcess::execute() {
LOG(ERROR) << "In party " << party_id_ << ":\n" << e.what() << "."; LOG(ERROR) << "In party " << party_id_ << ":\n" << e.what() << ".";
} }
return 0; return 0;
} } // namespace primihub
int MissingProcess::finishPartyComm(void) { int MissingProcess::finishPartyComm(void) {
si64 tmp_share0, tmp_share1, tmp_share2; si64 tmp_share0, tmp_share1, tmp_share2;
...@@ -380,9 +404,11 @@ int MissingProcess::_LoadDatasetFromCSV(std::string &filename) { ...@@ -380,9 +404,11 @@ int MissingProcess::_LoadDatasetFromCSV(std::string &filename) {
DataDirverFactory::getDriver("CSV", nodeaddr); DataDirverFactory::getDriver("CSV", nodeaddr);
std::shared_ptr<Cursor> &cursor = driver->read(filename); std::shared_ptr<Cursor> &cursor = driver->read(filename);
std::shared_ptr<Dataset> ds = cursor->read(); std::shared_ptr<Dataset> ds = cursor->read();
table = std::get<std::shared_ptr<Table>>(ds->data); table = std::get<std::shared_ptr<Table>>(ds->data);
bool errors = false; bool errors = false;
std::vector<std::string> col_names = table->ColumnNames(); std::vector<std::string> col_names = table->ColumnNames();
int num_col = table->num_columns(); int num_col = table->num_columns();
LOG(INFO) << "Loaded " << table->num_rows() << " rows in " LOG(INFO) << "Loaded " << table->num_rows() << " rows in "
...@@ -390,26 +416,33 @@ int MissingProcess::_LoadDatasetFromCSV(std::string &filename) { ...@@ -390,26 +416,33 @@ int MissingProcess::_LoadDatasetFromCSV(std::string &filename) {
local_col_names = table->ColumnNames(); local_col_names = table->ColumnNames();
// 'array' include values in a column of csv file. // 'array' include values in a column of csv file.
auto array = std::static_pointer_cast<DoubleArray>( int chunk_num = table->column(num_col - 1)->chunks().size();
table->column(num_col - 1)->chunk(0)); int array_len = 0;
int64_t array_len = array->length(); for (int k = 0; k < chunk_num; k++) {
auto array = std::static_pointer_cast<DoubleArray>(
table->column(num_col - 1)->chunk(k));
array_len += array->length();
}
LOG(INFO) << "Label column '" << local_col_names[num_col - 1] << "' has " LOG(INFO) << "Label column '" << local_col_names[num_col - 1] << "' has "
<< array_len << " values."; << array_len << " values.";
// Force the same value count in every column. // Force the same value count in every column.
for (int i = 0; i < num_col; i++) { for (int i = 0; i < num_col; i++) {
auto array = int chunk_num = table->column(i)->chunks().size();
std::static_pointer_cast<DoubleArray>(table->column(i)->chunk(0));
std::vector<double> tmp_data; std::vector<double> tmp_data;
for (int64_t j = 0; j < array->length(); j++) { int tmp_len = 0;
tmp_data.push_back(array->Value(j));
for (int k = 0; k < chunk_num; k++) {
auto array =
std::static_pointer_cast<DoubleArray>(table->column(i)->chunk(k));
tmp_len += array->length();
for (int64_t j = 0; j < array->length(); j++) {
tmp_data.push_back(array->Value(j));
}
} }
if (tmp_len != array_len) {
if (array->length() != array_len) { LOG(ERROR) << "Column " << local_col_names[i] << " has " << tmp_len
LOG(ERROR) << "Column " << local_col_names[i] << " has " << " value, but other column has " << array_len << " value.";
<< array->length() << " value, but other column has "
<< array_len << " value.";
errors = true; errors = true;
break; break;
} }
...@@ -417,7 +450,7 @@ int MissingProcess::_LoadDatasetFromCSV(std::string &filename) { ...@@ -417,7 +450,7 @@ int MissingProcess::_LoadDatasetFromCSV(std::string &filename) {
if (errors) if (errors)
return -1; return -1;
return array->length(); return array_len;
} }
} }
} // namespace primihub } // namespace primihub
\ No newline at end of file
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "src/primihub/cli/cli.h" #include "src/primihub/cli/cli.h"
#include <fstream> // std::ifstream #include <fstream> // std::ifstream
#include <string> #include <string>
#include <chrono>
using primihub::rpc::ParamValue; using primihub::rpc::ParamValue;
using primihub::rpc::string_array; using primihub::rpc::string_array;
...@@ -135,7 +136,7 @@ int SDKClient::SubmitTask() { ...@@ -135,7 +136,7 @@ int SDKClient::SubmitTask() {
input_datasets[i]); input_datasets[i]);
} }
} }
// TEE task // TEE task
if ( absl::GetFlag(FLAGS_task_type) == 6 ) { if ( absl::GetFlag(FLAGS_task_type) == 6 ) {
...@@ -194,9 +195,15 @@ int main(int argc, char** argv) { ...@@ -194,9 +195,15 @@ int main(int argc, char** argv) {
LOG(INFO) << "SDK SubmitTask to: " << peer; LOG(INFO) << "SDK SubmitTask to: " << peer;
primihub::SDKClient client( primihub::SDKClient client(
grpc::CreateChannel(peer, grpc::InsecureChannelCredentials())); grpc::CreateChannel(peer, grpc::InsecureChannelCredentials()));
if (!client.SubmitTask()) { auto _start = std::chrono::high_resolution_clock::now();
auto ret = client.SubmitTask();
auto _end = std::chrono::high_resolution_clock::now();
auto time_cost = std::chrono::duration_cast<std::chrono::milliseconds>(_end - _start).count();
LOG(INFO) << "SubmitTask time cost(ms): " << time_cost;
if (!ret) {
break; break;
} }
} }
return 0; return 0;
......
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
*/ */
#include <variant> #include <variant>
#include <sys/stat.h>
#include "src/primihub/data_store/csv/csv_driver.h" #include "src/primihub/data_store/csv/csv_driver.h"
#include "src/primihub/data_store/driver.h" #include "src/primihub/data_store/driver.h"
#include <arrow/api.h> #include <arrow/api.h>
#include <arrow/csv/api.h> #include <arrow/csv/api.h>
#include <arrow/csv/writer.h> #include <arrow/csv/writer.h>
...@@ -45,7 +45,9 @@ void CSVCursor::close() { ...@@ -45,7 +45,9 @@ void CSVCursor::close() {
// read all data from csv file // read all data from csv file
std::shared_ptr<primihub::Dataset> CSVCursor::read() { std::shared_ptr<primihub::Dataset> CSVCursor::read() {
struct stat file_info;
::stat(filePath.c_str(), &file_info);
size_t file_size = file_info.st_size;
arrow::io::IOContext io_context = arrow::io::default_io_context(); arrow::io::IOContext io_context = arrow::io::default_io_context();
arrow::fs::LocalFileSystem local_fs( arrow::fs::LocalFileSystem local_fs(
arrow::fs::LocalFileSystemOptions::Defaults()); arrow::fs::LocalFileSystemOptions::Defaults());
...@@ -57,6 +59,7 @@ std::shared_ptr<primihub::Dataset> CSVCursor::read() { ...@@ -57,6 +59,7 @@ std::shared_ptr<primihub::Dataset> CSVCursor::read() {
std::shared_ptr<arrow::io::InputStream> input = result_ifstream.ValueOrDie(); std::shared_ptr<arrow::io::InputStream> input = result_ifstream.ValueOrDie();
auto read_options = arrow::csv::ReadOptions::Defaults(); auto read_options = arrow::csv::ReadOptions::Defaults();
read_options.block_size = file_size;
auto parse_options = arrow::csv::ParseOptions::Defaults(); auto parse_options = arrow::csv::ParseOptions::Defaults();
auto convert_options = arrow::csv::ConvertOptions::Defaults(); auto convert_options = arrow::csv::ConvertOptions::Defaults();
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -54,7 +54,7 @@ message TaskResponse { ...@@ -54,7 +54,7 @@ message TaskResponse {
service VMNode { service VMNode {
rpc SubmitTask(PushTaskRequest) returns (PushTaskReply); rpc SubmitTask(PushTaskRequest) returns (PushTaskReply);
rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskResponse); rpc ExecuteTask(stream ExecuteTaskRequest) returns (stream ExecuteTaskResponse);
rpc Send(stream TaskRequest) returns (TaskResponse); rpc Send(stream TaskRequest) returns (TaskResponse);
} }
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册