未验证 提交 64cb8600 编写于 作者: PhoenixTree2013's avatar PhoenixTree2013 提交者: GitHub

Merge pull request #304 from primihub/develop

1.6.2
# Linux
build:linux --cxxopt=-std=c++17
build:linux --host_cxxopt=-std=c++17
build:linux --copt=-w
build:linux --copt=-DENABLE_SSE
build:linux --linkopt=-lstdc++fs
build:linux --define microsoft-apsi=false
build:linux_x86_64 --cxxopt=-std=c++17
build:linux_x86_64 --host_cxxopt=-std=c++17
build:linux_x86_64 --copt=-w
build:linux_x86_64 --linkopt=-lstdc++fs
build:linux_x86_64 --copt=-DENABLE_SSE
build:linux_x86_64 --define cpu=amd64
build:linux_x86_64 --define cpu_arch=x86_64
build:linux_x86_64 --define microsoft-apsi=true
build:linux_x86_64 --define enable_mysql_driver=true
build:linux_aarch64 --cxxopt=-std=c++17
build:linux_aarch64 --host_cxxopt=-std=c++17
build:linux_aarch64 --copt=-w
build:linux_aarch64 --linkopt=-lstdc++fs
build:linux_aarch64 --define cpu=arm64
build:linux_aarch64 --define cpu_arch=aarch64
build:linux_aarch64 --define microsoft-apsi=true
build:linux_aarch64 --define enable_mysql_driver=true
#build:linux --strip=never
#build:linux --copt -fno-sanitize-recover=all
......@@ -47,12 +59,16 @@ build:darwin_x86_64 --macos_minimum_os=10.16
build:darwin_x86_64 --cpu=darwin_x86_64
build:darwin_x86_64 --copt=-DENABLE_SSE
build:darwin_x86_64 --define macos-build=true
build:darwin_x86_64 --define cpu_arch=darwin_x86_64
build:darwin_x86_64 --define enable_mysql_driver=false
# MacOS Big Sur with Apple Silicon M1
build:darwin_arm64 --apple_platform_type=macos
build:darwin_arm64 --macos_minimum_os=10.16
build:darwin_arm64 --cpu=darwin_arm64
build:darwin_arm64 --define macos-build=true
build:darwin_arm64 --define cpu_arch=darwin_arm64
build:darwin_arm64 --define enable_mysql_driver=false
# MacOS Monterey with Apple M1
build:darwin --apple_platform_type=macos
......
......@@ -63,7 +63,7 @@ jobs:
- name: bazel test
run: |
./pre_build.sh
bazel build --config=linux logistic_test maxpool_test falcon_lenet_test common_test network_test
bazel build --config=linux_x86_64 logistic_test maxpool_test falcon_lenet_test common_test network_test
./bazel-bin/logistic_test
./bazel-bin/maxpool_test
./bazel-bin/falcon_lenet_test
......@@ -143,7 +143,7 @@ jobs:
- name: bazel build
run: |
# cc_binary
bazel build --config=linux :node :cli :opt_paillier_c2py :linkcontext
bazel build --config=linux_x86_64 :node :cli :opt_paillier_c2py :linkcontext
build-on-mac_x86_64:
......
......@@ -36,6 +36,25 @@ config_setting(
values = {"define": "microsoft-apsi=true"},
)
config_setting(
name = "aarch64",
values = {"define": "cpu_arch=aarch64"},
)
config_setting(
name = "x86_64",
values = {"define": "cpu_arch=x86_64"},
)
config_setting(
name = "darwin_x86_64",
values = {"define": "cpu_arch=darwin_x86_64"},
)
config_setting(
name = "enable_mysql_driver",
values = {"define": "enable_mysql_driver=true"},
)
DEFAULT_LINK_OPTS = [
"-pthread",
"-ldl",
......@@ -66,10 +85,12 @@ C_OPT = [
"-Wall",
"-ggdb",
"-rdynamic",
"-maes",
"-mpclmul",
"-Wno-reserved-user-defined-literal",
]
] + select({
":x86_64": ["-maes", "-mpclmul"],
":aarch64": [],
"//conditions:default": [],
})
## start of primihub
load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
......@@ -91,7 +112,6 @@ alias(
actual = "@soralog//:soralog",
)
# proto
grpc_proto_library(
name = "route_guide",
......@@ -130,27 +150,21 @@ grpc_proto_library(
)
cc_library(
name = "common_lib",
srcs = glob([
"src/primihub/common/defines.cc",
"src/primihub/common/clp.cc",
"src/primihub/common/config/config.cc",
"src/primihub/common/type/type.cc",
"src/primihub/common/type/fixed_point.cc",
]),
hdrs = glob([
"src/primihub/common/defines.h",
"src/primihub/common/finally.h",
"src/primihub/common/clp.h",
"src/primihub/common/config/config.h",
"src/primihub/common/type/type.h",
"src/primihub/common/type/fixed_point.h",
"src/primihub/common/type/matrix.h",
"src/primihub/common/type/matrix_view.h",
"src/primihub/common/eventbus/eventbus.hpp",
"src/primihub/common/eventbus/function_traits.hpp",
name = "server_config",
hdrs = [
"src/primihub/node/server_config.h",
],
srcs = [
"src/primihub/node/server_config.cc",
],
deps = [
":config_lib",
"@com_github_jbeder_yaml_cpp//:yaml-cpp",
],
)
]),
cc_library(
name = "microsoft_gsl",
textual_hdrs = [
"src/primihub/common/gsl/span",
"src/primihub/common/gsl/gsl_assert",
......@@ -166,11 +180,114 @@ cc_library(
"--std=c++17",
],
linkopts = LINK_OPTS,
)
cc_library(
name = "common_defination",
hdrs = [
"src/primihub/common/common.h",
],
)
cc_library(
name = "config_lib",
hdrs = [
"src/primihub/common/config/config.h",
],
srcs = [
"src/primihub/common/config/config.cc",
],
deps = [
":common_defination",
"@com_github_jbeder_yaml_cpp//:yaml-cpp",
"@com_github_glog_glog//:glog",
]
)
cc_library(
name = "clp_lib",
hdrs = [
"src/primihub/common/clp.h",
],
srcs = [
"src/primihub/common/clp.cc",
],
deps = [
":common_defination",
":microsoft_gsl",
],
)
cc_library(
name = "eventbus_lib",
hdrs = [
"src/primihub/common/eventbus/eventbus.hpp",
"src/primihub/common/eventbus/function_traits.hpp",
],
)
cc_library(
name = "data_type_defination",
hdrs = [
"src/primihub/common/type/type.h",
"src/primihub/common/type/fixed_point.h",
"src/primihub/common/type/matrix.h",
"src/primihub/common/type/matrix_view.h",
],
srcs = [
"src/primihub/common/type/type.cc",
"src/primihub/common/type/fixed_point.cc",
],
copts = C_OPT + [
"--std=c++17",
],
linkstatic = False,
deps = [
":eigen",
"@boost//:multiprecision",
":microsoft_gsl",
":common_util",
],
)
cc_library(
name = "common_util",
hdrs = [
"src/primihub/common/defines.h",
],
srcs = [
"src/primihub/common/defines.cc",
],
linkopts = LINK_OPTS,
linkstatic = False,
deps = [
":microsoft_gsl",
":common_defination",
],
)
cc_library (
name = "finally_tool",
hdrs = [
"src/primihub/common/finally.h",
],
)
cc_library(
name = "common_lib",
copts = C_OPT + [
"--std=c++17",
],
linkopts = LINK_OPTS,
linkstatic = False,
deps = [
":common_defination",
":config_lib",
":common_util",
":clp_lib",
":eventbus_lib",
":finally_tool",
":data_type_defination",
"@com_github_jbeder_yaml_cpp//:yaml-cpp",
"@com_github_glog_glog//:glog",
],
)
......@@ -229,7 +346,7 @@ cc_library(
deps = [
"@com_github_glog_glog//:glog",
"@com_github_grpc_grpc//:grpc++",
":common_lib",
":config_lib",
":worker_proto",
],
)
......@@ -389,12 +506,12 @@ cc_library(
"-I src/primihub/protocol/cryptflow2/OT/",
"-I src/primihub/protocol/cryptflow2/utils/",
"-D SCI_OT",
"-maes",
"-msse4.1",
"-mavx2",
"-pthread",
"-mrdseed",
],
] + select({
":x86_64": ["-maes", "-msse4.1", "-mavx2", "-mrdseed",],
":aarch64": [],
"//conditions:default": [],
}),
linkopts = ["-fopenmp"],
deps = [
"@com_microsoft_seal//:seal",
......@@ -421,11 +538,14 @@ cc_library(
# TODO: Consider to remove -I flag.
"-I src/primihub/protocol/falcon-public/",
"-I src/primihub/protocol/falcon-public/util/",
"-mpclmul",
"-maes",
"-fpic",
"-Wno-narrowing",
],
] + select({
":x86_64": ["-maes", "-mpclmul"],
":darwin_x86_64": ["-maes", "-mpclmul"],
":aarch64": [],
"//conditions:default": [],
}),
linkstatic = False,
deps = [
":eigen",
......@@ -476,11 +596,11 @@ cc_library(
"src/primihub/service/dataset/storage_backend.h",
"src/primihub/util/cpu_check.h"
]),
copts = C_OPT + [
"-maes",
"-mavx2",
"-mrdseed",
],
copts = C_OPT + select({
":x86_64": ["-maes", "-mrdseed", "-mavx2"],
":aarch64": [],
"//conditions:default": [],
}),
linkopts = LINK_OPTS,
linkstatic = False,
deps = [
......@@ -522,11 +642,11 @@ cc_library(
"src/primihub/operator/aby3_operator.h",
"src/primihub/algorithm/missing_val_processing.h",
]),
copts = C_OPT + [
"-maes",
"-msse4.1",
"-mrdseed",
],
copts = C_OPT + select({
":x86_64": ["-maes", "-mrdseed", "-msse4.1"],
":aarch64": [],
"//conditions:default": [],
}),
linkopts = LINK_OPTS,
linkstatic = False,
deps = [
......@@ -608,8 +728,10 @@ PIR_LIB_DEPS = select({
LINUX_X86_BUILD_LIBS = select({
":macos-build": [] ,
"//conditions:default": [":cryptflow2_algorithm_lib", "@osu_libpsi//:libpsi"],
":aarch64": [],
":x86_64": [":cryptflow2_algorithm_lib", "@osu_libpsi//:libpsi"],
":darwin_x86_64": [],
"//conditions:default": [],
})
TASK_LIB_DEPS = DEFAULT_TASK_LIB_DEPS + PIR_LIB_DEPS + LINUX_X86_BUILD_LIBS
......@@ -741,6 +863,7 @@ cc_library(
deps = TASK_LIB_DEPS + [
":communication_lib",
":endian_util",
":server_config",
],
)
......@@ -748,12 +871,10 @@ cc_library(
name = "node_lib",
srcs = glob([
"src/primihub/node/worker/worker.cc",
"src/primihub/algorithm/dataload.cpp",
]),
hdrs = glob([
"src/primihub/node/worker/worker.h",
]),
copts = C_OPT,
linkopts = LINK_OPTS,
......@@ -909,7 +1030,12 @@ cc_library(
"src/primihub/data_store/csv/csv_driver.cc",
# "src/primihub/data_store/hdfs/hdfs_driver.cc",
"src/primihub/data_store/sqlite/sqlite_driver.cc",
] + select({
"enable_mysql_driver": [
"src/primihub/data_store/mysql/mysql_driver.cc",
],
"//conditions:default": []
}),
hdrs = [
"src/primihub/data_store/factory.h",
"src/primihub/data_store/dataset.h",
......@@ -918,8 +1044,21 @@ cc_library(
"src/primihub/data_store/csv/csv_driver.h",
#"src/primihub/data_store/hdfs/hdfs_driver.h",
"src/primihub/data_store/sqlite/sqlite_driver.h",
] + select({
"enable_mysql_driver": [
"src/primihub/data_store/mysql/mysql_driver.h",
],
linkopts = LINK_OPTS,
"//conditions:default": []
}),
defines = select({
"enable_mysql_driver": ["ENABLE_MYSQL_DRIVER"],
"//conditions:default": []
}),
linkopts = LINK_OPTS + select({
"enable_mysql_driver": ["-lmysqlclient"],
"//conditions:default": []
}),
deps = [
"@arrow",
"@com_github_glog_glog//:glog",
......@@ -930,26 +1069,23 @@ cc_library(
":util_lib",
"@nlohmann_json",
"@com_github_sqlite_wrapper//:sqlite_wrapper",
"@com_github_jbeder_yaml_cpp//:yaml-cpp",
],
linkstatic = True,
visibility = ["//visibility:public"],
)
cc_library(
name = "new_data_store_lib",
srcs = [
"src/primihub/data_store/driver.cc",
"src/primihub/data_store/csv/csv_driver.cc",
# "src/primihub/data_store/hdfs/hdfs_driver.cc",
],
hdrs = [
"src/primihub/data_store/factory.h",
"src/primihub/data_store/dataset.h",
"src/primihub/data_store/driver.h",
"src/primihub/data_store/csv/csv_driver.h",
"src/primihub/data_store/hdfs/hdfs_driver.h",
......@@ -963,8 +1099,6 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "p2p_lib",
srcs = [
......@@ -1042,7 +1176,8 @@ cc_library(
":p2p_lib",
":data_store_lib",
":util_lib",
"service_base",
":service_base",
":server_config",
],
linkstatic = True,
visibility = ["//visibility:public"],
......@@ -1077,7 +1212,6 @@ cc_library(
visibility = ["//visibility:public"],
)
cc_library(
name = "nodelet",
srcs = [
......@@ -1213,10 +1347,10 @@ cc_binary(
"src/primihub/node/node.h",
"src/primihub/node/ds.cc",
"src/primihub/node/ds.h",
"src/primihub/node/main.cc",
]),
copts = C_OPT,
includes = [
],
includes = [],
linkstatic = True,
linkopts = LINK_OPTS ,
deps = [
......@@ -1235,7 +1369,7 @@ cc_binary(
":nodelet",
":algorithm_lib",
":util_lib",
":server_config",
],
)
......@@ -1264,9 +1398,34 @@ cc_binary(
":util_lib",
"@openssl",
"@com_github_stduuid//:stduuid_lib",
":communication_lib",
],
)
cc_binary(
name = "reg_cli_test",
srcs = [
"src/primihub/cli/reg_cli.cc",
"src/primihub/cli/reg_cli.h",
],
copts = C_OPT,
includes = [
],
linkstatic = True,
linkopts = LINK_OPTS,
deps = [
"@com_github_grpc_grpc//:grpc++",
"@com_google_absl//absl/base",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/memory",
"@com_github_glog_glog//:glog",
":worker_proto",
":service_proto",
"@openssl",
"@com_github_stduuid//:stduuid_lib",
],
)
## end of primihub
## unit tests
......
......@@ -12,15 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
FROM ubuntu:20.04 as builder
FROM primihub/primihub-node:build as builder
ENV LANG C.UTF-8
ENV DEBIAN_FRONTEND=noninteractive
# Install dependencies
RUN apt update \
&& apt install -y python3 python3-dev gcc-8 g++-8 python-dev libgmp-dev cmake libmysqlclient-dev\
&& apt install -y automake ca-certificates git libtool m4 patch pkg-config unzip make wget curl zip ninja-build npm \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 800 --slave /usr/bin/g++ g++ /usr/bin/g++-8 \
&& rm -rf /var/lib/apt/lists/*
# install bazelisk
RUN npm install -g @bazel/bazelisk
WORKDIR /src
ADD . /src
# Bazel build primihub-node & primihub-cli & paillier shared library
RUN bash pre_build.sh \
&& bazel build --config=linux --define cpu=amd64 --define microsoft-apsi=true :node :cli :opt_paillier_c2py :linkcontext
&& ARCH=`arch` \
&& bazel build --config=linux_$ARCH :node :cli :opt_paillier_c2py :linkcontext
FROM ubuntu:20.04 as runner
......@@ -53,10 +66,9 @@ RUN mkdir -p src/primihub/protos data log
COPY --from=builder /src/python ./python
COPY --from=builder /src/src/primihub/protos/ ./src/primihub/protos/
# Copy opt_paillier_c2py.so to /app/python, this enable setup.py find it.
RUN cp $TARGET_PATH/opt_paillier_c2py.so /app/python/
# Copy linkcontext.so to /app/python, this enable setup.py find it.
RUN cp $TARGET_PATH/linkcontext.so /app/python/
# Copy opt_paillier_c2py.so linkcontext.so to /app/python, this enable setup.py find it.
RUN cp $TARGET_PATH/opt_paillier_c2py.so /app/python/ \
&& cp $TARGET_PATH/linkcontext.so /app/python/
# The setup.py will copy opt_paillier_c2py.so to python library path.
WORKDIR /app/python
......@@ -64,8 +76,8 @@ RUN python3 -m pip install --upgrade pip \
&& python3 -m pip install -r requirements.txt \
&& python3 setup.py install
RUN rm -rf /app/python/opt_paillier_c2py.so
RUN rm -rf /app/python/linkcontext.so
RUN rm -rf /app/python/opt_paillier_c2py.so \
&& rm -rf /app/python/linkcontext.so
WORKDIR /app
# gRPC server port
......
# Copyright 2022 Primihub
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM ubuntu:20.04
ENV LANG C.UTF-8
ENV DEBIAN_FRONTEND=noninteractive
# Install dependencies
RUN apt update \
&& apt install -y python3 python3-dev gcc-8 g++-8 python-dev libgmp-dev cmake \
&& apt install -y automake ca-certificates git libtool m4 patch pkg-config unzip make wget curl zip ninja-build npm \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 800 --slave /usr/bin/g++ g++ /usr/bin/g++-8 \
&& rm -rf /var/lib/apt/lists/*
# install bazelisk
RUN npm install -g @bazel/bazelisk
# Install keyword PIR dependencies
WORKDIR /opt
RUN wget https://github.com/zeromq/libzmq/archive/refs/tags/v4.3.4.tar.gz \
&& tar -zxf v4.3.4.tar.gz && mkdir libzmq-4.3.4/build && cd libzmq-4.3.4/build \
&& cmake .. && make -j 8 && make install
RUN wget https://github.com/zeromq/cppzmq/archive/refs/tags/v4.9.0.tar.gz \
&& tar -zxf v4.9.0.tar.gz && mkdir cppzmq-4.9.0/build && cd cppzmq-4.9.0/build \
&& cmake .. && make -j 8 && make install
RUN wget https://github.com/google/flatbuffers/archive/refs/tags/v2.0.0.tar.gz \
&& tar -zxf v2.0.0.tar.gz && mkdir flatbuffers-2.0.0/build && cd flatbuffers-2.0.0/build \
&& cmake .. && make -j 8 && make install
RUN wget https://sourceforge.net/projects/tclap/files/tclap-1.2.5.tar.gz \
&& tar -zxvf tclap-1.2.5.tar.gz && cd tclap-1.2.5 && ./configure \
&& make -j 8 && make install
\ No newline at end of file
......@@ -6,7 +6,7 @@ ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
RUN apt-get update \
&& apt-get install -y python3 python3-dev libgmp-dev python3-pip libzmq5 tzdata \
&& apt-get install -y python3 python3-dev libgmp-dev python3-pip libzmq5 tzdata libmysqlclient-dev\
&& rm -rf /var/lib/apt/lists/*
ARG TARGET_PATH=/root/.cache/bazel/_bazel_root/17a1cd4fb136f9bc7469e0db6305b35a/execroot/__main__/bazel-out/k8-fastbuild/bin
......
linux_x86_64:
#bazel build --config=linux --define cpu=amd64 --define microsoft-apsi=true //:node //:cli
bazel build --config=linux --define cpu=amd64 --define microsoft-apsi=true //:node //:cli //:linkcontext //:opt_paillier_c2py
bazel build --config=linux_x86_64 //:node //:cli //:linkcontext //:opt_paillier_c2py //:linkcontext
linux_aarch64:
bazel build --config=linux_aarch64 //:node //:cli //:linkcontext //:opt_paillier_c2py //:linkcontext
macos_arm64:
bazel build --config=macos --define cpu=arm64 --define microsoft-apsi=true //:node //:cli //:opt_paillier_c2py //:linkcontext
......@@ -538,7 +538,8 @@ http_archive(
# APSI
git_repository(
name = "mircrosoft_apsi",
branch = "bazel_version",
#branch = "bazel_version",
commit = "44243c1a85435c04ca858279757ca5524dd3c9aa",
remote = "https://gitee.com/primihub/APSI.git",
)
......
......@@ -562,7 +562,8 @@ http_archive(
# APSI
git_repository(
name = "mircrosoft_apsi",
branch = "bazel_version",
#branch = "bazel_version",
commit = "44243c1a85435c04ca858279757ca5524dd3c9aa",
remote = "https://github.com/primihub/APSI.git",
)
......
......@@ -9,7 +9,9 @@ fi
bash pre_build.sh
bazel build --config=linux --define cpu=amd64 --define microsoft-apsi=true :node :cli :opt_paillier_c2py :linkcontext
ARCH=`arch`
bazel build --config=linux_$ARCH :node :cli :opt_paillier_c2py :linkcontext
if [ $? -ne 0 ]; then
echo "Build failed!!!"
......
......@@ -16,19 +16,24 @@
version: 1.0
node: "node0"
#location: "www.primihub.server.com"
#use_tls: true
location: "127.0.0.1"
use_tls: false
grpc_port: 50050
#certificate:
# root_ca: "data/cert/ca.crt"
# key: "data/cert/node0.key"
# cert: "data/cert/node0.crt"
# Use redis by default, set `use_redis: False` to use libp2p
redis_meta_service:
redis_addr: "127.0.0.1:6379"
use_redis: True
redis_password: "primihub"
# node_keypair:
# public_key:
# private_key:
# load datasets
datasets:
# ABY3 LR test case datasets
......@@ -50,10 +55,28 @@ datasets:
model: "csv"
source: "data/falcon/dataset/MNIST/input_1"
# FL homo lr test case datasets
- description: "homo_lr_data"
model: "csv"
source: "data/FL/homo_lr/breast_cancer.csv"
- description: "train_homo_lr"
model: "csv"
source: "data/FL/homo_lr/train_breast_cancer.csv"
# PSI test case datasets for sqlite database
- description: "psi_client_data_db"
model: "sqlite"
table_name: "psi_client_data"
source: "data/client_e.db3"
# Dataset authorization
# authorization:
# - node:
# task:
# PSI test caset datasets
- description: "psi_client_data"
model: "csv"
source: "data/client_e.csv"
localkv:
model: "leveldb"
......
......@@ -16,7 +16,12 @@
version: 1.0
node: "node1"
#location: "www.primihub.server.com"
#use_tls: true
location: "127.0.0.1"
use_tls: false
grpc_port: 50051
# Use redis by default, set `use_redis: False` to use libp2p
......@@ -25,6 +30,11 @@ redis_meta_service:
use_redis: True
redis_password: "primihub"
#certificate:
# root_ca: "data/cert/ca.crt"
# key: "data/cert/node0.key"
# cert: "data/cert/node0.crt"
# load datasets
datasets:
# ABY3 LR test case datasets
......@@ -77,6 +87,21 @@ datasets:
model: "csv"
source: "data/FL/homo_lr_test.data"
# FL homo lr test case datasets
- description: "train_homo_lr_host"
model: "csv"
source: "data/FL/homo_lr/train/train_breast_cancer_host.csv"
- description: "test_homo_lr"
model: "csv"
source: "data/FL/homo_lr/test_breast_cancer.csv"
- description: "train_hetero_xgb_host"
model: "csv"
source: "data/FL/hetero_xgb/train/train_breast_cancer_host.csv"
- description: "test_hetero_xgb_host"
model: "csv"
source: "data/FL/hetero_xgb/test/test_breast_cancer_host.csv"
localkv:
model: "leveldb"
......
......@@ -16,9 +16,19 @@
version: 1.0
node: "node2"
#location: "www.primihub.server.com"
#use_tls: true
location: "127.0.0.1"
use_tls: false
grpc_port: 50052
#certificate:
# root_ca: "data/cert/ca.crt"
# key: "data/cert/node0.key"
# cert: "data/cert/node0.crt"
# Use redis by default, set `use_redis: False` to use libp2p
redis_meta_service:
redis_addr: "127.0.0.1:6379"
......@@ -54,16 +64,30 @@ datasets:
model: "csv"
source: "data/FL/wisconsin.data"
# PSI test caset datasets
- description: "psi_client_data"
# FL homo lr test case datasets
- description: "train_homo_lr_guest"
model: "csv"
source: "data/client_e.csv"
source: "data/FL/homo_lr/train/train_breast_cancer_guest.csv"
## PSI test caset datasets
#- description: "psi_client_data"
# model: "csv"
# source: "data/client_e.csv"
# PSI test case datasets for sqlite database
- description: "psi_client_data_db"
model: "sqlite"
table_name: "psi_client_data"
source: "data/client_e.db3"
#- description: "psi_client_data_db"
# model: "sqlite"
# table_name: "psi_client_data"
# source: "data/client_e.db3"
- description: "train_hetero_xgb_guest"
model: "csv"
source: "data/FL/hetero_xgb/train/train_breast_cancer_guest.csv"
- description: "test_hetero_xgb_guest"
model: "csv"
source: "data/FL/hetero_xgb/test/test_breast_cancer_guest.csv"
localkv:
model: "leveldb"
......
......@@ -2,8 +2,15 @@ version: 1.0
node: "node0"
location: "172.28.1.10"
use_tls: false
grpc_port: 50050
#certificate:
# root_ca: "data/cert/ca.crt"
# key: "data/cert/node0.key"
# cert: "data/cert/node0.crt"
# Use redis by default, set `use_redis: False` to use libp2p
redis_meta_service:
redis_addr: "redis:6379"
......@@ -30,6 +37,14 @@ datasets:
model: "csv"
source: "/tmp/falcon/dataset/MNIST/input_1"
# FL homo lr test case datasets
- description: "homo_lr_data"
model: "csv"
source: "/tmp/FL/homo_lr/breast_cancer.csv"
- description: "train_homo_lr"
model: "csv"
source: "/tmp/FL/homo_lr/train_breast_cancer.csv"
localkv:
model: "leveldb"
path: "/data/localdb0"
......
......@@ -2,8 +2,15 @@ version: 1.0
node: "node1"
location: "172.28.1.11"
use_tls: false
grpc_port: 50051
#certificate:
# root_ca: "data/cert/ca.crt"
# key: "data/cert/node0.key"
# cert: "data/cert/node0.crt"
# Use redis by default, set `use_redis: False` to use libp2p
redis_meta_service:
redis_addr: "redis:6379"
......@@ -64,6 +71,14 @@ datasets:
model: "csv"
source: "/tmp/FL/hetero_xgb/test/test_breast_cancer_host.csv"
# FL homo lr test case datasets
- description: "train_homo_lr_host"
model: "csv"
source: "/tmp/FL/homo_lr/train/train_breast_cancer_host.csv"
- description: "test_homo_lr"
model: "csv"
source: "/tmp/FL/homo_lr/test_breast_cancer.csv"
localkv:
model: "leveldb"
path: "/data/localdb1"
......
......@@ -3,6 +3,12 @@ version: 1.0
node: "node2"
location: "172.28.1.12"
grpc_port: 50052
use_tls: false
#certificate:
# root_ca: "data/cert/ca.crt"
# key: "data/cert/node0.key"
# cert: "data/cert/node0.crt"
# Use redis by default, set `use_redis: False` to use libp2p
redis_meta_service:
......@@ -35,6 +41,11 @@ datasets:
model: "csv"
source: "/tmp/FL/wisconsin.data"
# FL homo lr test case datasets
- description: "train_homo_lr_guest"
model: "csv"
source: "/tmp/FL/homo_lr/train/train_breast_cancer_guest.csv"
# PSI test case datasets
- description: "psi_client_data"
model: "csv"
......
此差异已折叠。
id,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,radius error,texture error,perimeter error,area error,smoothness error,compactness error,concavity error,concave points error,symmetry error,fractal dimension error,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,y
301,12.46,19.89,80.43,471.3,0.08451,0.1014,0.0683,0.03099,0.1781,0.06249,0.3642,1.04,2.579,28.32,0.00653,0.03369,0.04712,0.01403,0.0274,0.004651,13.46,23.07,88.13,551.3,0.105,0.2158,0.1904,0.07625,0.2685,0.07764,1
233,20.51,27.81,134.4,1319.0,0.09159,0.1074,0.1554,0.0834,0.1448,0.05592,0.524,1.189,3.767,70.01,0.00502,0.02062,0.03457,0.01091,0.01298,0.002887,24.47,37.38,162.7,1872.0,0.1223,0.2761,0.4146,0.1563,0.2437,0.08328,0
506,12.22,20.04,79.47,453.1,0.1096,0.1152,0.08175,0.02166,0.2124,0.06894,0.1811,0.7959,0.9857,12.58,0.006272,0.02198,0.03966,0.009894,0.0132,0.003813,13.16,24.17,85.13,515.3,0.1402,0.2315,0.3535,0.08088,0.2709,0.08839,1
478,11.49,14.59,73.99,404.9,0.1046,0.08228,0.05308,0.01969,0.1779,0.06574,0.2034,1.166,1.567,14.34,0.004957,0.02114,0.04156,0.008038,0.01843,0.003614,12.4,21.9,82.04,467.6,0.1352,0.201,0.2596,0.07431,0.2941,0.0918,1
512,13.4,20.52,88.64,556.7,0.1106,0.1469,0.1445,0.08172,0.2116,0.07325,0.3906,0.9306,3.093,33.67,0.005414,0.02265,0.03452,0.01334,0.01705,0.004005,16.41,29.66,113.3,844.4,0.1574,0.3856,0.5106,0.2051,0.3585,0.1109,0
466,13.14,20.74,85.98,536.9,0.08675,0.1089,0.1085,0.0351,0.1562,0.0602,0.3152,0.7884,2.312,27.4,0.007295,0.03179,0.04615,0.01254,0.01561,0.00323,14.8,25.46,100.9,689.1,0.1351,0.3549,0.4504,0.1181,0.2563,0.08174,1
462,14.4,26.99,92.25,646.1,0.06995,0.05223,0.03476,0.01737,0.1707,0.05433,0.2315,0.9112,1.727,20.52,0.005356,0.01679,0.01971,0.00637,0.01414,0.001892,15.4,31.98,100.4,734.6,0.1017,0.146,0.1472,0.05563,0.2345,0.06464,1
557,9.423,27.88,59.26,271.3,0.08123,0.04971,0.0,0.0,0.1742,0.06059,0.5375,2.927,3.618,29.11,0.01159,0.01124,0.0,0.0,0.03004,0.003324,10.49,34.24,66.5,330.6,0.1073,0.07158,0.0,0.0,0.2475,0.06969,1
192,9.72,18.22,60.73,288.1,0.0695,0.02344,0.0,0.0,0.1653,0.06447,0.3539,4.885,2.23,21.69,0.001713,0.006736,0.0,0.0,0.03799,0.001688,9.968,20.83,62.25,303.8,0.07117,0.02729,0.0,0.0,0.1909,0.06559,1
489,16.69,20.2,107.1,857.6,0.07497,0.07112,0.03649,0.02307,0.1846,0.05325,0.2473,0.5679,1.775,22.95,0.002667,0.01446,0.01423,0.005297,0.01961,0.0017,19.18,26.56,127.3,1084.0,0.1009,0.292,0.2477,0.08737,0.4677,0.07623,0
555,10.29,27.61,65.67,321.4,0.0903,0.07658,0.05999,0.02738,0.1593,0.06127,0.2199,2.239,1.437,14.46,0.01205,0.02736,0.04804,0.01721,0.01843,0.004938,10.84,34.91,69.57,357.6,0.1384,0.171,0.2,0.09127,0.2226,0.08283,1
249,11.52,14.93,73.87,406.3,0.1013,0.07808,0.04328,0.02929,0.1883,0.06168,0.2562,1.038,1.686,18.62,0.006662,0.01228,0.02105,0.01006,0.01677,0.002784,12.65,21.19,80.88,491.8,0.1389,0.1582,0.1804,0.09608,0.2664,0.07809,1
493,12.46,12.83,78.83,477.3,0.07372,0.04043,0.007173,0.01149,0.1613,0.06013,0.3276,1.486,2.108,24.6,0.01039,0.01003,0.006416,0.007895,0.02869,0.004821,13.19,16.36,83.24,534.0,0.09439,0.06477,0.01674,0.0268,0.228,0.07028,1
425,10.03,21.28,63.19,307.3,0.08117,0.03912,0.00247,0.005159,0.163,0.06439,0.1851,1.341,1.184,11.6,0.005724,0.005697,0.002074,0.003527,0.01445,0.002411,11.11,28.94,69.92,376.3,0.1126,0.07094,0.01235,0.02579,0.2349,0.08061,1
385,14.6,23.29,93.97,664.7,0.08682,0.06636,0.0839,0.05271,0.1627,0.05416,0.4157,1.627,2.914,33.01,0.008312,0.01742,0.03389,0.01576,0.0174,0.002871,15.79,31.71,102.2,758.2,0.1312,0.1581,0.2675,0.1359,0.2477,0.06836,0
482,13.47,14.06,87.32,546.3,0.1071,0.1155,0.05786,0.05266,0.1779,0.06639,0.1588,0.5733,1.102,12.84,0.00445,0.01452,0.01334,0.008791,0.01698,0.002787,14.83,18.32,94.94,660.2,0.1393,0.2499,0.1848,0.1335,0.3227,0.09326,1
532,13.68,16.33,87.76,575.5,0.09277,0.07255,0.01752,0.0188,0.1631,0.06155,0.2047,0.4801,1.373,17.25,0.003828,0.007228,0.007078,0.005077,0.01054,0.001697,15.85,20.2,101.6,773.4,0.1264,0.1564,0.1206,0.08704,0.2806,0.07782,1
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,0.5435,0.7339,3.398,74.08,0.005225,0.01308,0.0186,0.0134,0.01389,0.003532,24.99,23.41,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902,0
286,11.94,20.76,77.87,441.0,0.08605,0.1011,0.06574,0.03791,0.1588,0.06766,0.2742,1.39,3.198,21.91,0.006719,0.05156,0.04387,0.01633,0.01872,0.008015,13.24,27.29,92.2,546.1,0.1116,0.2813,0.2365,0.1155,0.2465,0.09981,1
329,16.26,21.88,107.5,826.8,0.1165,0.1283,0.1799,0.07981,0.1869,0.06532,0.5706,1.457,2.961,57.72,0.01056,0.03756,0.05839,0.01186,0.04022,0.006187,17.73,25.21,113.7,975.2,0.1426,0.2116,0.3344,0.1047,0.2736,0.07953,0
70,18.94,21.31,123.6,1130.0,0.09009,0.1029,0.108,0.07951,0.1582,0.05461,0.7888,0.7975,5.486,96.05,0.004444,0.01652,0.02269,0.0137,0.01386,0.001698,24.86,26.58,165.9,1866.0,0.1193,0.2336,0.2687,0.1789,0.2551,0.06589,0
6,18.25,19.98,119.6,1040.0,0.09463,0.109,0.1127,0.074,0.1794,0.05742,0.4467,0.7732,3.18,53.91,0.004314,0.01382,0.02254,0.01039,0.01369,0.002179,22.88,27.66,153.2,1606.0,0.1442,0.2576,0.3784,0.1932,0.3063,0.08368,0
102,12.18,20.52,77.22,458.7,0.08013,0.04038,0.02383,0.0177,0.1739,0.05677,0.1924,1.571,1.183,14.68,0.00508,0.006098,0.01069,0.006797,0.01447,0.001532,13.34,32.84,84.58,547.8,0.1123,0.08862,0.1145,0.07431,0.2694,0.06878,1
547,10.26,16.58,65.85,320.8,0.08877,0.08066,0.04358,0.02438,0.1669,0.06714,0.1144,1.023,0.9887,7.326,0.01027,0.03084,0.02613,0.01097,0.02277,0.00589,10.83,22.04,71.08,357.4,0.1461,0.2246,0.1783,0.08333,0.2691,0.09479,1
362,12.76,18.84,81.87,496.6,0.09676,0.07952,0.02688,0.01781,0.1759,0.06183,0.2213,1.285,1.535,17.26,0.005608,0.01646,0.01529,0.009997,0.01909,0.002133,13.75,25.99,87.82,579.7,0.1298,0.1839,0.1255,0.08312,0.2744,0.07238,1
278,13.59,17.84,86.24,572.3,0.07948,0.04052,0.01997,0.01238,0.1573,0.0552,0.258,1.166,1.683,22.22,0.003741,0.005274,0.01065,0.005044,0.01344,0.001126,15.5,26.1,98.91,739.1,0.105,0.07622,0.106,0.05185,0.2335,0.06263,1
195,12.91,16.33,82.53,516.4,0.07941,0.05366,0.03873,0.02377,0.1829,0.05667,0.1942,0.9086,1.493,15.75,0.005298,0.01587,0.02321,0.00842,0.01853,0.002152,13.88,22.0,90.81,600.6,0.1097,0.1506,0.1764,0.08235,0.3024,0.06949,1
47,13.17,18.66,85.98,534.6,0.1158,0.1231,0.1226,0.0734,0.2128,0.06777,0.2871,0.8937,1.897,24.25,0.006532,0.02336,0.02905,0.01215,0.01743,0.003643,15.67,27.95,102.8,759.4,0.1786,0.4166,0.5006,0.2088,0.39,0.1179,0
29,17.57,15.05,115.0,955.1,0.09847,0.1157,0.09875,0.07953,0.1739,0.06149,0.6003,0.8225,4.655,61.1,0.005627,0.03033,0.03407,0.01354,0.01925,0.003742,20.01,19.52,134.9,1227.0,0.1255,0.2812,0.2489,0.1456,0.2756,0.07919,0
65,14.78,23.94,97.4,668.3,0.1172,0.1479,0.1267,0.09029,0.1953,0.06654,0.3577,1.281,2.45,35.24,0.006703,0.0231,0.02315,0.01184,0.019,0.003224,17.31,33.39,114.6,925.1,0.1648,0.3416,0.3024,0.1614,0.3321,0.08911,0
508,16.3,15.7,104.7,819.8,0.09427,0.06712,0.05526,0.04563,0.1711,0.05657,0.2067,0.4706,1.146,20.67,0.007394,0.01203,0.0247,0.01431,0.01344,0.002569,17.32,17.76,109.8,928.2,0.1354,0.1361,0.1947,0.1357,0.23,0.0723,1
69,12.78,16.49,81.37,502.5,0.09831,0.05234,0.03653,0.02864,0.159,0.05653,0.2368,0.8732,1.471,18.33,0.007962,0.005612,0.01585,0.008662,0.02254,0.001906,13.46,19.76,85.67,554.9,0.1296,0.07061,0.1039,0.05882,0.2383,0.0641,1
498,18.49,17.52,121.3,1068.0,0.1012,0.1317,0.1491,0.09183,0.1832,0.06697,0.7923,1.045,4.851,95.77,0.007974,0.03214,0.04435,0.01573,0.01617,0.005255,22.75,22.88,146.4,1600.0,0.1412,0.3089,0.3533,0.1663,0.251,0.09445,0
556,10.16,19.59,64.73,311.7,0.1003,0.07504,0.005025,0.01116,0.1791,0.06331,0.2441,2.09,1.648,16.8,0.01291,0.02222,0.004174,0.007082,0.02572,0.002278,10.65,22.88,67.88,347.3,0.1265,0.12,0.01005,0.02232,0.2262,0.06742,1
426,10.48,14.98,67.49,333.6,0.09816,0.1013,0.06335,0.02218,0.1925,0.06915,0.3276,1.127,2.564,20.77,0.007364,0.03867,0.05263,0.01264,0.02161,0.00483,12.13,21.57,81.41,440.4,0.1327,0.2996,0.2939,0.0931,0.302,0.09646,1
412,9.397,21.68,59.75,268.8,0.07969,0.06053,0.03735,0.005128,0.1274,0.06724,0.1186,1.182,1.174,6.802,0.005515,0.02674,0.03735,0.005128,0.01951,0.004583,9.965,27.99,66.61,301.0,0.1086,0.1887,0.1868,0.02564,0.2376,0.09206,1
402,12.96,18.29,84.18,525.2,0.07351,0.07899,0.04057,0.01883,0.1874,0.05899,0.2357,1.299,2.397,20.21,0.003629,0.03713,0.03452,0.01065,0.02632,0.003705,14.13,24.61,96.31,621.9,0.09329,0.2318,0.1604,0.06608,0.3207,0.07247,1
507,11.06,17.12,71.25,366.5,0.1194,0.1071,0.04063,0.04268,0.1954,0.07976,0.1779,1.03,1.318,12.3,0.01262,0.02348,0.018,0.01285,0.0222,0.008313,11.69,20.74,76.08,411.1,0.1662,0.2031,0.1256,0.09514,0.278,0.1168,1
279,13.85,15.18,88.99,587.4,0.09516,0.07688,0.04479,0.03711,0.211,0.05853,0.2479,0.9195,1.83,19.41,0.004235,0.01541,0.01457,0.01043,0.01528,0.001593,14.98,21.74,98.37,670.0,0.1185,0.1724,0.1456,0.09993,0.2955,0.06912,1
330,16.03,15.51,105.8,793.2,0.09491,0.1371,0.1204,0.07041,0.1782,0.05976,0.3371,0.7476,2.629,33.27,0.005839,0.03245,0.03715,0.01459,0.01467,0.003121,18.76,21.98,124.3,1070.0,0.1435,0.4478,0.4956,0.1981,0.3019,0.09124,0
545,13.62,23.23,87.19,573.2,0.09246,0.06747,0.02974,0.02443,0.1664,0.05801,0.346,1.336,2.066,31.24,0.005868,0.02099,0.02021,0.009064,0.02087,0.002583,15.35,29.09,97.58,729.8,0.1216,0.1517,0.1049,0.07174,0.2642,0.06953,1
232,11.22,33.81,70.79,386.8,0.0778,0.03574,0.004967,0.006434,0.1845,0.05828,0.2239,1.647,1.489,15.46,0.004359,0.006813,0.003223,0.003419,0.01916,0.002534,12.36,41.78,78.44,470.9,0.09994,0.06885,0.02318,0.03002,0.2911,0.07307,1
333,11.25,14.78,71.38,390.0,0.08306,0.04458,0.0009737,0.002941,0.1773,0.06081,0.2144,0.9961,1.529,15.07,0.005617,0.007124,0.0009737,0.002941,0.017,0.00203,12.76,22.06,82.08,492.7,0.1166,0.09794,0.005518,0.01667,0.2815,0.07418,1
290,14.41,19.73,96.03,651.0,0.08757,0.1676,0.1362,0.06602,0.1714,0.07192,0.8811,1.77,4.36,77.11,0.007762,0.1064,0.0996,0.02771,0.04077,0.02286,15.77,22.13,101.7,767.3,0.09983,0.2472,0.222,0.1021,0.2272,0.08799,1
299,10.51,23.09,66.85,334.2,0.1015,0.06797,0.02495,0.01875,0.1695,0.06556,0.2868,1.143,2.289,20.56,0.01017,0.01443,0.01861,0.0125,0.03464,0.001971,10.93,24.22,70.1,362.7,0.1143,0.08614,0.04158,0.03125,0.2227,0.06777,1
87,19.02,24.59,122.0,1076.0,0.09029,0.1206,0.1468,0.08271,0.1953,0.05629,0.5495,0.6636,3.055,57.65,0.003872,0.01842,0.0371,0.012,0.01964,0.003337,24.56,30.41,152.9,1623.0,0.1249,0.3206,0.5755,0.1956,0.3956,0.09288,0
294,12.72,13.78,81.78,492.1,0.09667,0.08393,0.01288,0.01924,0.1638,0.061,0.1807,0.6931,1.34,13.38,0.006064,0.0118,0.006564,0.007978,0.01374,0.001392,13.5,17.48,88.54,553.7,0.1298,0.1472,0.05233,0.06343,0.2369,0.06922,1
477,13.9,16.62,88.97,599.4,0.06828,0.05319,0.02224,0.01339,0.1813,0.05536,0.1555,0.5762,1.392,14.03,0.003308,0.01315,0.009904,0.004832,0.01316,0.002095,15.14,21.8,101.2,718.9,0.09384,0.2006,0.1384,0.06222,0.2679,0.07698,1
27,18.61,20.25,122.1,1094.0,0.0944,0.1066,0.149,0.07731,0.1697,0.05699,0.8529,1.849,5.632,93.54,0.01075,0.02722,0.05081,0.01911,0.02293,0.004217,21.31,27.26,139.9,1403.0,0.1338,0.2117,0.3446,0.149,0.2341,0.07421,0
84,12.0,15.65,76.95,443.3,0.09723,0.07165,0.04151,0.01863,0.2079,0.05968,0.2271,1.255,1.441,16.16,0.005969,0.01812,0.02007,0.007027,0.01972,0.002607,13.67,24.9,87.78,567.9,0.1377,0.2003,0.2267,0.07632,0.3379,0.07924,1
234,9.567,15.91,60.21,279.6,0.08464,0.04087,0.01652,0.01667,0.1551,0.06403,0.2152,0.8301,1.215,12.64,0.01164,0.0104,0.01186,0.009623,0.02383,0.00354,10.51,19.16,65.74,335.9,0.1504,0.09515,0.07161,0.07222,0.2757,0.08178,1
368,21.71,17.25,140.9,1546.0,0.09384,0.08562,0.1168,0.08465,0.1717,0.05054,1.207,1.051,7.733,224.1,0.005568,0.01112,0.02096,0.01197,0.01263,0.001803,30.75,26.44,199.5,3143.0,0.1363,0.1628,0.2861,0.182,0.251,0.06494,0
305,11.6,24.49,74.23,417.2,0.07474,0.05688,0.01974,0.01313,0.1935,0.05878,0.2512,1.786,1.961,18.21,0.006122,0.02337,0.01596,0.006998,0.03194,0.002211,12.44,31.62,81.39,476.5,0.09545,0.1361,0.07239,0.04815,0.3244,0.06745,1
5,12.45,15.7,82.57,477.1,0.1278,0.17,0.1578,0.08089,0.2087,0.07613,0.3345,0.8902,2.217,27.19,0.00751,0.03345,0.03672,0.01137,0.02165,0.005082,15.47,23.75,103.4,741.6,0.1791,0.5249,0.5355,0.1741,0.3985,0.1244,0
408,17.99,20.66,117.8,991.7,0.1036,0.1304,0.1201,0.08824,0.1992,0.06069,0.4537,0.8733,3.061,49.81,0.007231,0.02772,0.02509,0.0148,0.01414,0.003336,21.08,25.41,138.1,1349.0,0.1482,0.3735,0.3301,0.1974,0.306,0.08503,0
238,14.22,27.85,92.55,623.9,0.08223,0.1039,0.1103,0.04408,0.1342,0.06129,0.3354,2.324,2.105,29.96,0.006307,0.02845,0.0385,0.01011,0.01185,0.003589,15.75,40.54,102.5,764.0,0.1081,0.2426,0.3064,0.08219,0.189,0.07796,1
242,11.3,18.19,73.93,389.4,0.09592,0.1325,0.1548,0.02854,0.2054,0.07669,0.2428,1.642,2.369,16.39,0.006663,0.05914,0.0888,0.01314,0.01995,0.008675,12.58,27.96,87.16,472.9,0.1347,0.4848,0.7436,0.1218,0.3308,0.1297,1
id,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,radius error,texture error,perimeter error,area error,smoothness error,compactness error,concavity error,concave points error,symmetry error,fractal dimension error,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,y
231,11.32,27.08,71.76,395.7,0.06883,0.03813,0.01633,0.003125,0.1869,0.05628,0.121,0.8927,1.059,8.605,0.003653,0.01647,0.01633,0.003125,0.01537,0.002052,12.08,33.75,79.82,452.3,0.09203,0.1432,0.1089,0.02083,0.2849,0.07087,1
110,9.777,16.99,62.5,290.2,0.1037,0.08404,0.04334,0.01778,0.1584,0.07065,0.403,1.424,2.747,22.87,0.01385,0.02932,0.02722,0.01023,0.03281,0.004638,11.05,21.47,71.68,367.0,0.1467,0.1765,0.13,0.05334,0.2533,0.08468,1
327,12.03,17.93,76.09,446.0,0.07683,0.03892,0.001546,0.005592,0.1382,0.0607,0.2335,0.9097,1.466,16.97,0.004729,0.006887,0.001184,0.003951,0.01466,0.001755,13.07,22.25,82.74,523.4,0.1013,0.0739,0.007732,0.02796,0.2171,0.07037,1
374,13.69,16.07,87.84,579.1,0.08302,0.06374,0.02556,0.02031,0.1872,0.05669,0.1705,0.5066,1.372,14.0,0.00423,0.01587,0.01169,0.006335,0.01943,0.002177,14.84,20.21,99.16,670.6,0.1105,0.2096,0.1346,0.06987,0.3323,0.07701,1
511,14.81,14.7,94.66,680.7,0.08472,0.05016,0.03416,0.02541,0.1659,0.05348,0.2182,0.6232,1.677,20.72,0.006708,0.01197,0.01482,0.01056,0.0158,0.001779,15.61,17.58,101.7,760.2,0.1139,0.1011,0.1101,0.07955,0.2334,0.06142,1
259,15.53,33.56,103.7,744.9,0.1063,0.1639,0.1751,0.08399,0.2091,0.0665,0.2419,1.278,1.903,23.02,0.005345,0.02556,0.02889,0.01022,0.009947,0.003359,18.49,49.54,126.3,1035.0,0.1883,0.5564,0.5703,0.2014,0.3512,0.1204,0
514,15.05,19.07,97.26,701.9,0.09215,0.08597,0.07486,0.04335,0.1561,0.05915,0.386,1.198,2.63,38.49,0.004952,0.0163,0.02967,0.009423,0.01152,0.001718,17.58,28.06,113.8,967.0,0.1246,0.2101,0.2866,0.112,0.2282,0.06954,0
201,17.54,19.32,115.1,951.6,0.08968,0.1198,0.1036,0.07488,0.1506,0.05491,0.3971,0.8282,3.088,40.73,0.00609,0.02569,0.02713,0.01345,0.01594,0.002658,20.42,25.84,139.5,1239.0,0.1381,0.342,0.3508,0.1939,0.2928,0.07867,0
528,13.94,13.17,90.31,594.2,0.1248,0.09755,0.101,0.06615,0.1976,0.06457,0.5461,2.635,4.091,44.74,0.01004,0.03247,0.04763,0.02853,0.01715,0.005528,14.62,15.38,94.52,653.3,0.1394,0.1364,0.1559,0.1015,0.216,0.07253,1
390,10.26,12.22,65.75,321.6,0.09996,0.07542,0.01923,0.01968,0.18,0.06569,0.1911,0.5477,1.348,11.88,0.005682,0.01365,0.008496,0.006929,0.01938,0.002371,11.38,15.65,73.23,394.5,0.1343,0.165,0.08615,0.06696,0.2937,0.07722,1
28,15.3,25.27,102.4,732.4,0.1082,0.1697,0.1683,0.08751,0.1926,0.0654,0.439,1.012,3.498,43.5,0.005233,0.03057,0.03576,0.01083,0.01768,0.002967,20.27,36.71,149.3,1269.0,0.1641,0.611,0.6335,0.2024,0.4027,0.09876,0
346,12.06,18.9,76.66,445.3,0.08386,0.05794,0.00751,0.008488,0.1555,0.06048,0.243,1.152,1.559,18.02,0.00718,0.01096,0.005832,0.005495,0.01982,0.002754,13.64,27.06,86.54,562.6,0.1289,0.1352,0.04506,0.05093,0.288,0.08083,1
206,9.876,17.27,62.92,295.4,0.1089,0.07232,0.01756,0.01952,0.1934,0.06285,0.2137,1.342,1.517,12.33,0.009719,0.01249,0.007975,0.007527,0.0221,0.002472,10.42,23.22,67.08,331.6,0.1415,0.1247,0.06213,0.05588,0.2989,0.0738,1
428,11.13,16.62,70.47,381.1,0.08151,0.03834,0.01369,0.0137,0.1511,0.06148,0.1415,0.9671,0.968,9.704,0.005883,0.006263,0.009398,0.006189,0.02009,0.002377,11.68,20.29,74.35,421.1,0.103,0.06219,0.0458,0.04044,0.2383,0.07083,1
277,18.81,19.98,120.9,1102.0,0.08923,0.05884,0.0802,0.05843,0.155,0.04996,0.3283,0.828,2.363,36.74,0.007571,0.01114,0.02623,0.01463,0.0193,0.001676,19.96,24.3,129.0,1236.0,0.1243,0.116,0.221,0.1294,0.2567,0.05737,0
224,13.27,17.02,84.55,546.4,0.08445,0.04994,0.03554,0.02456,0.1496,0.05674,0.2927,0.8907,2.044,24.68,0.006032,0.01104,0.02259,0.009057,0.01482,0.002496,15.14,23.6,98.84,708.8,0.1276,0.1311,0.1786,0.09678,0.2506,0.07623,1
443,10.57,18.32,66.82,340.9,0.08142,0.04462,0.01993,0.01111,0.2372,0.05768,0.1818,2.542,1.277,13.12,0.01072,0.01331,0.01993,0.01111,0.01717,0.004492,10.94,23.31,69.35,366.3,0.09794,0.06542,0.03986,0.02222,0.2699,0.06736,1
11,15.78,17.89,103.6,781.0,0.0971,0.1292,0.09954,0.06606,0.1842,0.06082,0.5058,0.9849,3.564,54.16,0.005771,0.04061,0.02791,0.01282,0.02008,0.004144,20.42,27.28,136.5,1299.0,0.1396,0.5609,0.3965,0.181,0.3792,0.1048,0
56,19.21,18.57,125.5,1152.0,0.1053,0.1267,0.1323,0.08994,0.1917,0.05961,0.7275,1.193,4.837,102.5,0.006458,0.02306,0.02945,0.01538,0.01852,0.002608,26.14,28.14,170.1,2145.0,0.1624,0.3511,0.3879,0.2091,0.3537,0.08294,0
497,12.47,17.31,80.45,480.1,0.08928,0.0763,0.03609,0.02369,0.1526,0.06046,0.1532,0.781,1.253,11.91,0.003796,0.01371,0.01346,0.007096,0.01536,0.001541,14.06,24.34,92.82,607.3,0.1276,0.2506,0.2028,0.1053,0.3035,0.07661,1
345,10.26,14.71,66.2,321.6,0.09882,0.09159,0.03581,0.02037,0.1633,0.07005,0.338,2.509,2.394,19.33,0.01736,0.04671,0.02611,0.01296,0.03675,0.006758,10.88,19.48,70.89,357.1,0.136,0.1636,0.07162,0.04074,0.2434,0.08488,1
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,0.7572,0.7813,5.438,94.44,0.01149,0.02461,0.05688,0.01885,0.01756,0.005115,22.54,16.67,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678,0
99,14.42,19.77,94.48,642.5,0.09752,0.1141,0.09388,0.05839,0.1879,0.0639,0.2895,1.851,2.376,26.85,0.008005,0.02895,0.03321,0.01424,0.01462,0.004452,16.33,30.86,109.5,826.4,0.1431,0.3026,0.3194,0.1565,0.2718,0.09353,0
86,14.48,21.46,94.25,648.2,0.09444,0.09947,0.1204,0.04938,0.2075,0.05636,0.4204,2.22,3.301,38.87,0.009369,0.02983,0.05371,0.01761,0.02418,0.003249,16.21,29.25,108.4,808.9,0.1306,0.1976,0.3349,0.1225,0.302,0.06846,0
122,24.25,20.2,166.2,1761.0,0.1447,0.2867,0.4268,0.2012,0.2655,0.06877,1.509,3.12,9.807,233.0,0.02333,0.09806,0.1278,0.01822,0.04547,0.009875,26.02,23.99,180.9,2073.0,0.1696,0.4244,0.5803,0.2248,0.3222,0.08009,0
145,11.9,14.65,78.11,432.8,0.1152,0.1296,0.0371,0.03003,0.1995,0.07839,0.3962,0.6538,3.021,25.03,0.01017,0.04741,0.02789,0.0111,0.03127,0.009423,13.15,16.51,86.26,509.6,0.1424,0.2517,0.0942,0.06042,0.2727,0.1036,1
401,11.93,10.91,76.14,442.7,0.08872,0.05242,0.02606,0.01796,0.1601,0.05541,0.2522,1.045,1.649,18.95,0.006175,0.01204,0.01376,0.005832,0.01096,0.001857,13.8,20.14,87.64,589.5,0.1374,0.1575,0.1514,0.06876,0.246,0.07262,1
409,12.27,17.92,78.41,466.1,0.08685,0.06526,0.03211,0.02653,0.1966,0.05597,0.3342,1.781,2.079,25.79,0.005888,0.0231,0.02059,0.01075,0.02578,0.002267,14.1,28.88,89.0,610.2,0.124,0.1795,0.1377,0.09532,0.3455,0.06896,1
338,10.05,17.53,64.41,310.8,0.1007,0.07326,0.02511,0.01775,0.189,0.06331,0.2619,2.015,1.778,16.85,0.007803,0.01449,0.0169,0.008043,0.021,0.002778,11.16,26.84,71.98,384.0,0.1402,0.1402,0.1055,0.06499,0.2894,0.07664,1
15,14.54,27.54,96.73,658.8,0.1139,0.1595,0.1639,0.07364,0.2303,0.07077,0.37,1.033,2.879,32.55,0.005607,0.0424,0.04741,0.0109,0.01857,0.005466,17.46,37.13,124.1,943.2,0.1678,0.6577,0.7026,0.1712,0.4218,0.1341,0
71,8.888,14.64,58.79,244.0,0.09783,0.1531,0.08606,0.02872,0.1902,0.0898,0.5262,0.8522,3.168,25.44,0.01721,0.09368,0.05671,0.01766,0.02541,0.02193,9.733,15.67,62.56,284.4,0.1207,0.2436,0.1434,0.04786,0.2254,0.1084,1
119,17.95,20.01,114.2,982.0,0.08402,0.06722,0.07293,0.05596,0.2129,0.05025,0.5506,1.214,3.357,54.04,0.004024,0.008422,0.02291,0.009863,0.05014,0.001902,20.58,27.83,129.2,1261.0,0.1072,0.1202,0.2249,0.1185,0.4882,0.06111,0
458,13.0,25.13,82.61,520.2,0.08369,0.05073,0.01206,0.01762,0.1667,0.05449,0.2621,1.232,1.657,21.19,0.006054,0.008974,0.005681,0.006336,0.01215,0.001514,14.34,31.88,91.06,628.5,0.1218,0.1093,0.04462,0.05921,0.2306,0.06291,1
51,13.64,16.34,87.21,571.8,0.07685,0.06059,0.01857,0.01723,0.1353,0.05953,0.1872,0.9234,1.449,14.55,0.004477,0.01177,0.01079,0.007956,0.01325,0.002551,14.67,23.19,96.08,656.7,0.1089,0.1582,0.105,0.08586,0.2346,0.08025,1
257,15.32,17.27,103.2,713.3,0.1335,0.2284,0.2448,0.1242,0.2398,0.07596,0.6592,1.059,4.061,59.46,0.01015,0.04588,0.04983,0.02127,0.01884,0.00866,17.73,22.66,119.8,928.8,0.1765,0.4503,0.4429,0.2229,0.3258,0.1191,0
378,13.66,15.15,88.27,580.6,0.08268,0.07548,0.04249,0.02471,0.1792,0.05897,0.1402,0.5417,1.101,11.35,0.005212,0.02984,0.02443,0.008356,0.01818,0.004868,14.54,19.64,97.96,657.0,0.1275,0.3104,0.2569,0.1054,0.3387,0.09638,1
63,9.173,13.86,59.2,260.9,0.07721,0.08751,0.05988,0.0218,0.2341,0.06963,0.4098,2.265,2.608,23.52,0.008738,0.03938,0.04312,0.0156,0.04192,0.005822,10.01,19.23,65.59,310.1,0.09836,0.1678,0.1397,0.05087,0.3282,0.0849,1
475,12.83,15.73,82.89,506.9,0.0904,0.08269,0.05835,0.03078,0.1705,0.05913,0.1499,0.4875,1.195,11.64,0.004873,0.01796,0.03318,0.00836,0.01601,0.002289,14.09,19.35,93.22,605.8,0.1326,0.261,0.3476,0.09783,0.3006,0.07802,1
407,12.85,21.37,82.63,514.5,0.07551,0.08316,0.06126,0.01867,0.158,0.06114,0.4993,1.798,2.552,41.24,0.006011,0.0448,0.05175,0.01341,0.02669,0.007731,14.4,27.01,91.63,645.8,0.09402,0.1936,0.1838,0.05601,0.2488,0.08151,1
220,13.65,13.16,87.88,568.9,0.09646,0.08711,0.03888,0.02563,0.136,0.06344,0.2102,0.4336,1.391,17.4,0.004133,0.01695,0.01652,0.006659,0.01371,0.002735,15.34,16.35,99.71,706.2,0.1311,0.2474,0.1759,0.08056,0.238,0.08718,1
413,14.99,22.11,97.53,693.7,0.08515,0.1025,0.06859,0.03876,0.1944,0.05913,0.3186,1.336,2.31,28.51,0.004449,0.02808,0.03312,0.01196,0.01906,0.004015,16.76,31.55,110.2,867.1,0.1077,0.3345,0.3114,0.1308,0.3163,0.09251,1
424,9.742,19.12,61.93,289.7,0.1075,0.08333,0.008934,0.01967,0.2538,0.07029,0.6965,1.747,4.607,43.52,0.01307,0.01885,0.006021,0.01052,0.031,0.004225,11.21,23.17,71.79,380.9,0.1398,0.1352,0.02085,0.04589,0.3196,0.08009,1
441,17.27,25.42,112.4,928.8,0.08331,0.1109,0.1204,0.05736,0.1467,0.05407,0.51,1.679,3.283,58.38,0.008109,0.04308,0.04942,0.01742,0.01594,0.003739,20.38,35.46,132.8,1284.0,0.1436,0.4122,0.5036,0.1739,0.25,0.07944,0
18,19.81,22.15,130.0,1260.0,0.09831,0.1027,0.1479,0.09498,0.1582,0.05395,0.7582,1.017,5.865,112.4,0.006494,0.01893,0.03391,0.01521,0.01356,0.001997,27.32,30.88,186.8,2398.0,0.1512,0.315,0.5372,0.2388,0.2768,0.07615,0
315,12.49,16.85,79.19,481.6,0.08511,0.03834,0.004473,0.006423,0.1215,0.05673,0.1716,0.7151,1.047,12.69,0.004928,0.003012,0.00262,0.00339,0.01393,0.001344,13.34,19.71,84.48,544.2,0.1104,0.04953,0.01938,0.02784,0.1917,0.06174,1
225,14.34,13.47,92.51,641.2,0.09906,0.07624,0.05724,0.04603,0.2075,0.05448,0.522,0.8121,3.763,48.29,0.007089,0.01428,0.0236,0.01286,0.02266,0.001463,16.77,16.9,110.4,873.2,0.1297,0.1525,0.1632,0.1087,0.3062,0.06072,1
470,9.667,18.49,61.49,289.1,0.08946,0.06258,0.02948,0.01514,0.2238,0.06413,0.3776,1.35,2.569,22.73,0.007501,0.01989,0.02714,0.009883,0.0196,0.003913,11.14,25.62,70.88,385.2,0.1234,0.1542,0.1277,0.0656,0.3174,0.08524,1
451,19.59,25.0,127.7,1191.0,0.1032,0.09871,0.1655,0.09063,0.1663,0.05391,0.4674,1.375,2.916,56.18,0.0119,0.01929,0.04907,0.01499,0.01641,0.001807,21.44,30.96,139.8,1421.0,0.1528,0.1845,0.3977,0.1466,0.2293,0.06091,0
152,9.731,15.34,63.78,300.2,0.1072,0.1599,0.4108,0.07857,0.2548,0.09296,0.8245,2.664,4.073,49.85,0.01097,0.09586,0.396,0.05279,0.03546,0.02984,11.02,19.49,71.04,380.5,0.1292,0.2772,0.8216,0.1571,0.3108,0.1259,1
222,10.18,17.53,65.12,313.1,0.1061,0.08502,0.01768,0.01915,0.191,0.06908,0.2467,1.217,1.641,15.05,0.007899,0.014,0.008534,0.007624,0.02637,0.003761,11.17,22.84,71.94,375.6,0.1406,0.144,0.06572,0.05575,0.3055,0.08797,1
487,19.44,18.82,128.1,1167.0,0.1089,0.1448,0.2256,0.1194,0.1823,0.06115,0.5659,1.408,3.631,67.74,0.005288,0.02833,0.04256,0.01176,0.01717,0.003211,23.96,30.39,153.9,1740.0,0.1514,0.3725,0.5936,0.206,0.3266,0.09009,0
364,13.4,16.95,85.48,552.4,0.07937,0.05696,0.02181,0.01473,0.165,0.05701,0.1584,0.6124,1.036,13.22,0.004394,0.0125,0.01451,0.005484,0.01291,0.002074,14.73,21.7,93.76,663.5,0.1213,0.1676,0.1364,0.06987,0.2741,0.07582,1
564,21.56,22.39,142.0,1479.0,0.111,0.1159,0.2439,0.1389,0.1726,0.05623,1.176,1.256,7.673,158.7,0.0103,0.02891,0.05198,0.02454,0.01114,0.004239,25.45,26.4,166.1,2027.0,0.141,0.2113,0.4107,0.2216,0.206,0.07115,0
193,12.34,26.86,81.15,477.4,0.1034,0.1353,0.1085,0.04562,0.1943,0.06937,0.4053,1.809,2.642,34.44,0.009098,0.03845,0.03763,0.01321,0.01878,0.005672,15.65,39.34,101.7,768.9,0.1785,0.4706,0.4425,0.1459,0.3215,0.1205,0
376,10.57,20.22,70.15,338.3,0.09073,0.166,0.228,0.05941,0.2188,0.0845,0.1115,1.231,2.363,7.228,0.008499,0.07643,0.1535,0.02919,0.01617,0.0122,10.85,22.82,76.51,351.9,0.1143,0.3619,0.603,0.1465,0.2597,0.12,1
174,10.66,15.15,67.49,349.6,0.08792,0.04302,0.0,0.0,0.1928,0.05975,0.3309,1.925,2.155,21.98,0.008713,0.01017,0.0,0.0,0.03265,0.001002,11.54,19.2,73.2,408.3,0.1076,0.06791,0.0,0.0,0.271,0.06164,1
566,16.6,28.08,108.3,858.1,0.08455,0.1023,0.09251,0.05302,0.159,0.05648,0.4564,1.075,3.425,48.55,0.005903,0.03731,0.0473,0.01557,0.01318,0.003892,18.98,34.12,126.7,1124.0,0.1139,0.3094,0.3403,0.1418,0.2218,0.0782,0
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
#!/bin/bash
bash pre_build.sh
bazel build --config=linux :libpsi_test
bazel build --config=linux_x86_64 :libpsi_test
cd ./bazel-bin
......
import numpy as np
from sklearn import metrics
from collections import Iterable
def dloss(p, y):
    """Derivative of the log loss with respect to the margin prediction p.

    Uses the same numerical guards as scikit-learn's SGD implementation
    to avoid overflow in exp() for large |p * y|.
    """
    margin = p * y
    # Large positive margin: exp(margin) would overflow, use exp(-margin).
    if margin > 18.0:
        return -y * np.exp(-margin)
    # Large negative margin: the sigmoid saturates, derivative is just -y.
    if margin < -18.0:
        return -y
    return -y / (1.0 + np.exp(margin))
def batch_yield(x, y, batch_size):
    """Yield successive (x, y) mini-batches of at most `batch_size` rows."""
    total = x.shape[0]
    for start in range(0, total, batch_size):
        stop = start + batch_size
        yield x[start:stop], y[start:stop]
def trucate_geometric_thres(x, clip_thres, variation, times=2):
    """Norm-clip `x` and add averaged Gaussian noise (simple DP mechanism).

    The input (scalar or array) is scaled down so its L2 norm is at most
    `clip_thres`, then the average of 2*times independent Gaussian draws
    (scale proportional to the clip factor and `variation`) is added.

    :param x: scalar or iterable of values to perturb.
    :param clip_thres: norm threshold used for clipping.
    :param variation: noise scale multiplier.
    :param times: half the number of Gaussian draws to average (default 2).
    :return: numpy array of the clipped, noise-perturbed values.
    """
    # `collections.Iterable` was removed in Python 3.10; import the abc form
    # locally so this module keeps working on modern interpreters.
    from collections.abc import Iterable

    if isinstance(x, Iterable):
        norm_x = np.sqrt(sum(x * x))
        n = len(x)
    else:
        norm_x = abs(x)
        n = 1

    # Scale down so the clipped value has norm at most `clip_thres`
    # (factor is 1 when the input is already within the threshold).
    clip_factor = np.max([1, norm_x / clip_thres])
    clip_x = x / clip_factor

    # Sum 2*times independent draws, then normalize by sqrt(2*times) so the
    # resulting noise keeps standard deviation clip_factor * variation.
    dp_noise = np.zeros(n)
    for _ in range(2 * times):
        dp_noise += np.random.normal(0, clip_factor * variation, n)
    dp_noise /= np.sqrt(2 * times)
    return clip_x + dp_noise
class HeteroLrBase:
    """Base class holding shared hyper-parameters for hetero-LR roles.

    Concrete host/guest subclasses override `fit`/`predict`; `loss` relies
    on a `sigmoid` method that subclasses are expected to provide (this
    base class does not define one).
    """

    def __init__(self,
                 learning_rate=0.01,
                 alpha=0.0001,
                 epochs=10,
                 penalty="l2",
                 batch_size=64,
                 optimal_method=None,
                 update_type=None,
                 loss_type='log',
                 random_state=2023,
                 clip_thres=1.0,
                 noise_variation=1.0):
        # Optimization hyper-parameters.
        self.learning_rate = learning_rate
        self.alpha = alpha  # regularization strength
        self.epochs = epochs
        self.batch_size = batch_size
        self.penalty = penalty
        self.optimal_method = optimal_method
        self.random_state = random_state
        self.update_type = update_type
        self.loss_type = loss_type
        # Noise parameters used by trucate_geometric_thres.
        self.clip_thres = clip_thres
        self.noise_variation = noise_variation
        # Model weights; replaced by an ndarray once training starts.
        self.theta = 0

    def fit(self):
        """Train the model; implemented by subclasses."""
        pass

    def predict(self):
        """Predict with the model; implemented by subclasses."""
        pass

    def loss(self, y_hat, y_true):
        """Compute the training loss from raw predictions `y_hat`."""
        if self.loss_type == 'log':
            y_prob = self.sigmoid(y_hat)
            return metrics.log_loss(y_true, y_prob)
        # Fixed: previously compared the bound method `self.loss` with a
        # string, so the 'squarederror' branch was unreachable and always
        # fell through to KeyError.
        elif self.loss_type == "squarederror":
            return metrics.mean_squared_error(
                y_true, y_hat)  # mse don't activate inputs
        else:
            raise KeyError('The type is not implemented!')
class PlainLR:
    """Plain (non-federated) logistic regression trained by gradient descent.

    Serves as the centralized baseline for the hetero-LR host/guest classes.
    """

    def __init__(self,
                 learning_rate=0.01,
                 alpha=0.0001,
                 epochs=10,
                 penalty="l2",
                 batch_size=64,
                 optimal_method=None,
                 update_type=None,
                 loss_type='log',
                 random_state=2023):
        self.learning_rate = learning_rate
        self.alpha = alpha  # regularization strength
        self.epochs = epochs
        self.batch_size = batch_size
        self.penalty = penalty
        self.optimal_method = optimal_method
        self.random_state = random_state
        self.update_type = update_type
        self.loss_type = loss_type
        # Weights (including bias); replaced by an ndarray in fit().
        self.theta = 0

    def add_intercept(self, x):
        """Prepend a column of ones so theta[0] acts as the bias term."""
        intercept = np.ones((x.shape[0], 1))
        return np.concatenate((intercept, x), axis=1)

    def sigmoid(self, x):
        """Logistic function."""
        return 1.0 / (1 + np.exp(-x))

    def predict_prob(self, x):
        """Probability of the positive class for each row of x."""
        return self.sigmoid(np.dot(x, self.theta))

    def predict(self, x):
        """Hard 0/1 labels using a 0.5 decision threshold."""
        preds = self.predict_prob(x)
        preds[preds <= 0.5] = 0
        preds[preds > 0.5] = 1
        return preds

    def gradient(self, x, y):
        """Mean log-loss gradient, with optional l2 regularization."""
        h = self.predict_prob(x)
        if self.penalty == "l2":
            grad = (np.dot(x.T, (h - y)) / x.shape[0] + self.alpha * self.theta
                    )  # / x.shape[0]
        elif self.penalty == "l1":
            raise ValueError("It's not implemented now!")
        else:
            grad = np.dot(x.T, (h - y)) / x.shape[0]  # / x.shape[0]
        return grad

    def update_lr(self, current_epoch, type="sqrt"):
        """Decay the learning rate: 'sqrt' schedule, else sklearn-style 'optimal'."""
        if type == "sqrt":
            self.learning_rate /= np.sqrt(current_epoch + 1)
        else:
            typw = np.sqrt(1.0 / np.sqrt(self.alpha))
            initial_eta0 = typw / max(1.0, dloss(-typw, 1.0))
            optimal_init = 1.0 / (initial_eta0 * self.alpha)
            self.learning_rate = 1.0 / (self.alpha *
                                        (optimal_init + current_epoch + 1))

    def simple_gd(self, x, y):
        """One full-batch gradient step."""
        grad = self.gradient(x, y)
        self.theta -= self.learning_rate * grad

    def batch_gd(self, x, y):
        """One epoch of mini-batch gradient descent."""
        # Fixed local typo: 'bathc_y' -> 'batch_y'.
        for batch_x, batch_y in batch_yield(x, y, self.batch_size):
            grad = self.gradient(batch_x, batch_y)
            self.theta -= self.learning_rate * grad

    def fit(self, x, y):
        """Train on (x, y); batch_size < 0 means full-batch training."""
        x = self.add_intercept(x)
        if self.batch_size < 0:
            self.batch_size = x.shape[0]
        np.random.seed(self.random_state)
        self.theta = np.random.rand(x.shape[1])
        total_loss = []
        for i in range(self.epochs):
            self.update_lr(i, type=self.update_type)
            if self.optimal_method == "simple":
                self.simple_gd(x, y)
            else:
                self.batch_gd(x, y)
            print("current iteration and theta", i, self.theta)
            y_hat = np.dot(x, self.theta)
            current_loss = self.loss(y_hat, y)
            total_loss.append(current_loss)
        print("loss", total_loss)

    def loss(self, y_hat, y_true):
        """Loss on raw predictions `y_hat`: 'log' or 'squarederror'."""
        if self.loss_type == 'log':
            y_prob = self.sigmoid(y_hat)
            return metrics.log_loss(y_true, y_prob)
        # Fixed: previously compared the bound method `self.loss` with a
        # string, making the 'squarederror' branch unreachable.
        elif self.loss_type == "squarederror":
            return metrics.mean_squared_error(
                y_true, y_hat)  # mse don't activate inputs
        else:
            raise KeyError('The type is not implemented!')
import numpy as np
from primihub.FL.model.logistic_regression.hetero_lr_base import HeteroLrBase, batch_yield, trucate_geometric_thres
class HeterLrGuest(HeteroLrBase):
    """Guest role of vertically-partitioned (hetero) logistic regression.

    Holds only its own feature slice; exchanges partial predictions and the
    host-computed error signal over `guest_channel`.
    """

    def __init__(self,
                 learning_rate=0.01,
                 alpha=0.0001,
                 epochs=10,
                 penalty="l2",
                 batch_size=64,
                 optimal_method=None,
                 update_type=None,
                 loss_type='log',
                 random_state=2023,
                 guest_channel=None,
                 add_noise=True):
        super().__init__(learning_rate, alpha, epochs, penalty, batch_size,
                         optimal_method, update_type, loss_type, random_state)
        self.channel = guest_channel
        self.add_noise = add_noise

    def predict(self, x):
        """Send this party's share of the linear predictor to the host."""
        partial_pred = np.dot(x, self.theta)
        # Noise injection is currently disabled:
        # if self.add_noise:
        #     partial_pred = trucate_geometric_thres(
        #         partial_pred, self.clip_thres, self.noise_variation)
        self.channel.sender("guest_part", partial_pred)

    def gradient(self, x):
        """Gradient of the guest weights from the host-supplied error."""
        residual = self.channel.recv('error')
        if self.penalty == "l1":
            raise ValueError("It's not implemented now!")
        base_grad = np.dot(x.T, residual) / x.shape[0]  # / x.shape[0]
        if self.penalty == "l2":
            return base_grad + self.alpha * self.theta
        return base_grad

    def batch_gd(self, x):
        """One epoch of mini-batch descent driven by the host."""
        for sub_x, _ in batch_yield(x, x, self.batch_size):
            self.predict(sub_x)
            self.theta -= self.learning_rate * self.gradient(sub_x)

    def simple_gd(self, x):
        """One full-batch descent step driven by the host."""
        self.predict(x)
        self.theta -= self.learning_rate * self.gradient(x)

    def fit(self, x):
        """Run the guest side of training until the host signals convergence."""
        if self.batch_size < 0:
            self.batch_size = x.shape[0]
        self.theta = np.zeros(x.shape[1])
        for _ in range(self.epochs):
            # The host owns the schedule; adopt its learning rate each epoch.
            self.learning_rate = self.channel.recv("learning_rate")
            if self.optimal_method == "simple":
                self.simple_gd(x)
            else:
                self.batch_gd(x)
            self.predict(x)
            if self.channel.recv('is_converged'):
                break
import numpy as np
from primihub.FL.model.logistic_regression.hetero_lr_base import HeteroLrBase, batch_yield, dloss, trucate_geometric_thres
class HeterLrHost(HeteroLrBase):
    """Host-side (label-holding) party of vertical/hetero logistic regression.

    Per-epoch protocol (must mirror HeterLrGuest.fit exactly):
      1. send the decayed learning rate,
      2. run (mini-)batch gradient steps, receiving the guest's partial
         scores and sending back (optionally noised) residuals,
      3. evaluate the loss on the full data and send a convergence flag.
    """
    def __init__(self,
                 learning_rate=0.01,
                 alpha=0.0001,
                 epochs=10,
                 penalty="l2",
                 batch_size=64,
                 optimal_method=None,
                 update_type=None,
                 loss_type='log',
                 random_state=2023,
                 host_channel=None,
                 add_noise=True,
                 tol=0.001):
        super().__init__(learning_rate, alpha, epochs, penalty, batch_size,
                         optimal_method, update_type, loss_type, random_state)
        self.channel = host_channel  # grpc channel to the guest party
        self.add_noise = add_noise   # whether residuals sent out are noised
        self.tol = tol               # loss-plateau threshold for convergence
    def add_intercept(self, x):
        # Prepend a constant column of ones so theta[0] acts as the intercept.
        intercept = np.ones((x.shape[0], 1))
        return np.concatenate((intercept, x), axis=1)
    def sigmoid(self, x):
        # Logistic activation.
        return 1.0 / (1 + np.exp(-x))
    def predict_raw(self, x):
        # Full linear score = host partial score + guest partial score;
        # blocks until the guest sends its part for the same x.
        host_part = np.dot(x, self.theta)
        guest_part = self.channel.recv("guest_part")
        h = host_part + guest_part
        return h
    def predict(self, x):
        # Hard labels at the 0.5 probability threshold (in place on preds).
        preds = self.sigmoid(self.predict_raw(x))
        preds[preds <= 0.5] = 0
        preds[preds > 0.5] = 1
        return preds
    def update_lr(self, current_epoch, type="sqrt"):
        # "sqrt": divides the CURRENT rate by sqrt(epoch+1) every call,
        # i.e. a compounding decay — presumably intentional; confirm, since
        # a non-compounding schedule would divide the initial rate instead.
        if type == "sqrt":
            self.learning_rate /= np.sqrt(current_epoch + 1)
        else:
            # sklearn SGDClassifier-style 'optimal' schedule:
            # eta_t = 1 / (alpha * (optimal_init + t)).
            typw = np.sqrt(1.0 / np.sqrt(self.alpha))
            initial_eta0 = typw / max(1.0, dloss(-typw, 1.0))
            optimal_init = 1.0 / (initial_eta0 * self.alpha)
            self.learning_rate = 1.0 / (self.alpha *
                                        (optimal_init + current_epoch + 1))
    def gradient(self, x, y):
        """Compute the local gradient and ship residuals to the guest.

        The guest receives (optionally noised) residuals to build its own
        gradient, while the host's gradient uses the exact residuals.
        """
        h = self.sigmoid(self.predict_raw(x))
        error = h - y
        # self.channel.sender('error', error)
        if self.add_noise:
            # nois_error = trucate_geometric_thres(error,
            #                                      clip_thres=self.clip_thres,
            #                                      variation=self.noise_variation)
            # add adaptive-noise for error
            # Gaussian noise scaled to the residuals' own std dev.
            error_std = np.std(error)
            noise = np.random.normal(0, error_std, error.shape)
            nois_error = error + noise
        else:
            nois_error = error
        self.channel.sender('error', nois_error)
        if self.penalty == "l2":
            grad = (np.dot(x.T, error) / x.shape[0] + self.alpha * self.theta
                   ) #/ x.shape[0]
        elif self.penalty == "l1":
            raise ValueError("It's not implemented now!")
        else:
            grad = np.dot(x.T, error) / x.shape[0] #/ x.shape[0]
        return grad
    def batch_gd(self, x, y):
        # Mini-batch gradient descent over one epoch.
        # ("bathc_y" is a typo for "batch_y"; kept byte-identical here.)
        for batch_x, bathc_y in batch_yield(x, y, self.batch_size):
            grad = self.gradient(batch_x, bathc_y)
            self.theta -= self.learning_rate * grad
    def simple_gd(self, x, y):
        # Full-batch gradient descent: one step per epoch.
        grad = self.gradient(x, y)
        self.theta -= self.learning_rate * grad
    def fit(self, x, y):
        """Train the joint model; labels y are held only by the host.

        Convergence: once the absolute change in loss between consecutive
        epochs drops below ``tol``, the host flags convergence to the guest
        and both sides stop.
        """
        x = self.add_intercept(x)
        if self.batch_size < 0:
            self.batch_size = x.shape[0]
        self.theta = np.zeros(x.shape[1])
        pre_loss = None
        is_converged = False
        for i in range(self.epochs):
            self.update_lr(i, type=self.update_type)
            self.channel.sender("learning_rate", self.learning_rate)
            if self.optimal_method == "simple":
                self.simple_gd(x, y)
            else:
                self.batch_gd(x, y)
            # Full-data raw scores (requires one more guest exchange).
            y_hat = self.predict_raw(x)
            cur_loss = self.loss(y_hat, y)
            if pre_loss is None:
                pre_loss = cur_loss
            else:
                loss_diff = abs(pre_loss - cur_loss)
                pre_loss = cur_loss
                if loss_diff < self.tol:
                    is_converged = True
            # Tell the guest whether to stop after this epoch.
            self.channel.sender('is_converged', is_converged)
            if is_converged:
                break
        # Report training accuracy from the last computed raw scores.
        pred_prob = self.sigmoid(y_hat)
        preds = (pred_prob > 0.5).astype('int')
        acc = sum((preds == y).astype('int')) / len(y)
        print("acc: ", acc)
import primihub as ph
import pandas as pd
from primihub import dataset, context
from primihub.utils.net_worker import GrpcServer
from primihub.FL.model.logistic_regression.hetero_lr_host import HeterLrHost
from primihub.FL.model.logistic_regression.hetero_lr_guest import HeterLrGuest
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# Shared hyper-parameters for the host/guest hetero-LR drivers below.
config = {
    "learning_rate": 0.01,  # initial step size (decayed per epoch by the host)
    'alpha': 0.0001,  # l2 penalty strength
    "epochs": 50,  # maximum training epochs
    "penalty": "l2",
    "optimal_method": "Complex",  # anything but "simple" -> mini-batch GD
    "random_state": 2023,
    "host_columns": None,  # optional column subset for the host dataset
    "guest_columns": None,  # optional column subset for the guest dataset
    "scale_type": 'z-score',  # 'z-score' -> StandardScaler, else MinMaxScaler
    "batch_size": 512
}
@ph.context.function(
    role='host',
    protocol='hetero_lr',
    datasets=['train_hetero_xgb_host'
             ],  # ['train_hetero_xgb_host'], #, 'test_hetero_xgb_host'],
    port='8000',
    task_type="classification")
def lr_host_logic():
    """Driver for the host (label-owning) party of hetero-LR training.

    Reads the host dataset from the task context, opens a gRPC channel to
    the guest, pops the label column 'y', optionally scales the features,
    then fits a HeterLrHost model.
    """
    role_node_map = ph.context.Context.get_role_node_map()
    node_addr_map = ph.context.Context.get_node_addr_map()
    dataset_map = ph.context.Context.dataset_map
    taskId = ph.context.Context.params_map['taskid']
    jobId = ph.context.Context.params_map['jobid']
    # Resolve our own endpoint and the guest's endpoint ("ip:port").
    host_nodes = role_node_map["host"]
    host_ip, host_port = node_addr_map[host_nodes[0]].split(":")
    guest_nodes = role_node_map["guest"]
    guest_ip, guest_port = node_addr_map[guest_nodes[0]].split(":")
    data_key = list(dataset_map.keys())[0]
    data = ph.dataset.read(dataset_key=data_key).df_data
    print("ports: ", guest_port, host_port)
    # Optionally restrict to a configured column subset; drop any id column.
    host_cols = config['host_columns']
    if host_cols is not None:
        data = data[host_cols]
    if 'id' in data.columns:
        data.pop('id')
    Y = data.pop('y').values
    X_host = data.copy()
    # grpc server initialization
    host_channel = GrpcServer(remote_ip=guest_ip,
                              local_ip=host_ip,
                              remote_port=guest_port,
                              local_port=host_port,
                              context=ph.context.Context)
    lr_host = HeterLrHost(learning_rate=config['learning_rate'],
                          alpha=config['alpha'],
                          epochs=config['epochs'],
                          optimal_method=config['optimal_method'],
                          random_state=config['random_state'],
                          host_channel=host_channel,
                          add_noise=False,
                          batch_size=config['batch_size'])
    # BUG FIX: 'scale_type' was read from config twice in a row; one read.
    scale_type = config['scale_type']
    if scale_type is not None:
        std = StandardScaler() if scale_type == "z-score" else MinMaxScaler()
        scale_x = std.fit_transform(X_host)
    else:
        scale_x = X_host.copy()
    lr_host.fit(scale_x, Y)
@ph.context.function(
    role='guest',
    # NOTE(review): the host side above registers protocol='hetero_lr'
    # while this side registers 'heter_lr' — confirm the mismatch is
    # intentional (scheduler may match roles pairwise by protocol name).
    protocol='heter_lr',
    datasets=[
        'train_hetero_xgb_guest' #'five_thous_guest'
    ],  #['train_hetero_xgb_guest'], #, 'test_hetero_xgb_guest'],
    port='9000',
    task_type="classification")
def lr_guest_logic(cry_pri="paillier"):
    """Driver for the guest (feature-only) party of hetero-LR training.

    Mirrors lr_host_logic: reads the guest dataset, connects to the host
    over gRPC, optionally scales features, and runs HeterLrGuest.fit.
    The ``cry_pri`` argument is currently unused.
    """
    role_node_map = ph.context.Context.get_role_node_map()
    node_addr_map = ph.context.Context.get_node_addr_map()
    dataset_map = ph.context.Context.dataset_map
    taskId = ph.context.Context.params_map['taskid']
    jobId = ph.context.Context.params_map['jobid']
    # Resolve our own endpoint and the host's endpoint ("ip:port").
    guest_nodes = role_node_map["guest"]
    guest_port = node_addr_map[guest_nodes[0]].split(":")[1]
    guest_ip = node_addr_map[guest_nodes[0]].split(":")[0]
    host_nodes = role_node_map["host"]
    host_ip, host_port = node_addr_map[host_nodes[0]].split(":")
    data_key = list(dataset_map.keys())[0]
    data = ph.dataset.read(dataset_key=data_key).df_data
    print("ports: ", host_port, guest_port)
    # data = pd.read_csv("/home/xusong/data/epsilon_normalized.guest", header=0)
    # Optionally restrict to a configured column subset; drop any id column.
    guest_cols = config['guest_columns']
    if guest_cols is not None:
        data = data[guest_cols]
    if 'id' in data.columns:
        data.pop('id')
    X_guest = data
    guest_channel = GrpcServer(remote_ip=host_ip,
                               remote_port=host_port,
                               local_ip=guest_ip,
                               local_port=guest_port,
                               context=ph.context.Context)
    lr_guest = HeterLrGuest(learning_rate=config['learning_rate'],
                            alpha=config['alpha'],
                            epochs=config['epochs'],
                            optimal_method=config['optimal_method'],
                            random_state=config['random_state'],
                            guest_channel=guest_channel,
                            batch_size=config['batch_size'])
    # Feature scaling must match the host's configuration.
    scale_type = config['scale_type']
    if scale_type is not None:
        if scale_type == "z-score":
            std = StandardScaler()
        else:
            std = MinMaxScaler()
        scale_x = std.fit_transform(X_guest)
    else:
        scale_x = X_guest.copy()
    lr_guest.fit(scale_x)
# -*- coding:utf-8
import numpy as np
from primihub.FL.feature_engineer.onehot_encode import HorOneHotEncoder
from sklearn.preprocessing import MinMaxScaler
class LRModel:
    def __init__(self, X, y, category, w=None):
        # NOTE(review): this constructor is shadowed by the second
        # ``__init__`` defined below — Python keeps only the last definition
        # in a class body, so this one is dead code. Confirm which signature
        # is meant to survive and delete the other.
        self.w_size = X.shape[1] + 1  # feature count plus intercept
        self.coef = None
        self.intercept = None
        self.theta = None
        self.one_vs_rest_theta = np.random.uniform(-0.5, 0.5, (category, self.w_size))
        if w is not None:
            self.theta = w
# l2 regularization by default, alpha is the penalty parameter
def __init__(self, X, y, category, learning_rate=0.2, alpha=0.0001):
self.learning_rate = learning_rate
self.alpha = alpha # regularization parameter
self.t = 0 # iteration number, used for learning rate decay
if category == 2:
self.theta = np.random.uniform(-0.5, 0.5, (X.shape[1] + 1,))
self.multi_class = False
else:
# init model parameters
self.theta = np.random.uniform(-0.5, 0.5, (self.w_size,))
self.one_vs_rest_theta = np.random.uniform(-0.5, 0.5, (category, X.shape[1] + 1))
self.multi_class = True
# 'optimal' learning rate refer to sklearn SGDClassifier
def dloss(p, y):
z = p * y
if z > 18.0:
return np.exp(-z) * -y
if z < -18.0:
return -y
return -y / (np.exp(z) + 1.0)
typw = np.sqrt(1.0 / np.sqrt(alpha))
# computing eta0, the initial learning rate
initial_eta0 = typw / max(1.0, dloss(-typw, 1.0))
# initialize t such that eta at first sample equals eta0
self.optimal_init = 1.0 / (initial_eta0 * alpha)
# if encrypted == True:
# self.theta = self.utils.encrypt_vector(public_key, self.theta)
@staticmethod
def normalization(x):
"""
data normalization
"""
scaler = MinMaxScaler()
scaler = scaler.fit(x)
x = scaler.transform(x)
return x
def sigmoid(self, x):
x = np.array(x, dtype=np.float64)
y = 1.0 / (1.0 + np.exp(-x))
return y
    def loss_func(self, theta, x_b, y):
        """
        loss function
        :param theta: intercept and coef
        :param x_b: training data
        :param y: label
        :return:
        """
        # NOTE(review): this method computes the predicted probabilities but
        # has no return statement — the ``return -np.sum(...)`` that belongs
        # here appears to have been displaced into ``loss`` below by a bad
        # merge; confirm and restore it here.
        p_predict = self.sigmoid(x_b.dot(theta))
    def get_theta(self):
        # Return the current parameter vector (intercept stored first).
        return self.theta
def set_theta(self, theta):
if not isinstance(theta, np.ndarray):
theta = np.array(theta)
self.theta = theta
def loss(self, x, y):
temp = x.dot(self.theta[1:]) + self.theta[0]
try:
return -np.sum(y * np.log(p_predict) + (1 - y) * np.log(1 - p_predict))
return (np.maximum(temp, 0.).sum() - y.dot(temp) +
np.log(1 + np.exp(-np.abs(temp))).sum() +
0.5 * self.alpha * self.theta.dot(self.theta)) / x.shape[0]
except:
return float('inf')
def d_loss_func(self, theta, x_b, y):
out = self.sigmoid(x_b.dot(theta))
return x_b.T.dot(out - y) / len(x_b)
    def compute_grad(self, x, y):
        # Gradient of the l2-regularized log-loss: first component is the
        # intercept term (sum of residuals), the rest is x^T residual; the
        # whole vector is regularized and averaged over the batch.
        temp = self.predict_prob(x) - y
        return (np.concatenate((temp.sum(keepdims=True), x.T.dot(temp)))
                + self.alpha * self.theta) / x.shape[0]
def gradient_descent(self, x_b, y, theta, eta):
def gradient_descent(self, x, y):
grad = self.compute_grad(x, y)
self.theta -= self.learning_rate * grad
def gradient_descent_olr(self, x, y):
"""
:param x_b: training data
:param y: label
:param theta: model parameters
:param eta: learning rate
:return:
optimal learning rate
"""
gradient = self.d_loss_func(theta, x_b, y)
theta = theta - eta * gradient
return theta
def fit(self, train_data, train_label, theta, eta=0.01,):
assert train_data.shape[0] == train_label.shape[0], "The length of the training data set shall " \
"be consistent with the length of the label"
x_b = np.hstack([np.ones((train_data.shape[0], 1)), train_data])
self.theta = self.gradient_descent(x_b, train_label, theta, eta)
self.intercept = self.theta[0]
self.coef = self.theta[1:]
return self.theta
grad = self.compute_grad(x, y)
learning_rate = 1.0 / (self.alpha * (self.optimal_init + self.t))
self.t += 1
self.theta -= learning_rate * grad
def predict_prob(self, x_predict):
x_b = np.hstack([np.ones((len(x_predict), 1)), x_predict])
return self.sigmoid(x_b.dot(self.theta))
def fit(self, x, y):
self.gradient_descent_olr(x, y)
def predict(self, x_predict):
"""
classification
"""
prob = self.predict_prob(x_predict)
def predict_prob(self, x):
return self.sigmoid(x.dot(self.theta[1:]) + self.theta[0])
def predict(self, x):
prob = self.predict_prob(x)
return np.array(prob > 0.5, dtype='int')
def one_vs_rest(self, X, y, k):
......@@ -111,3 +107,4 @@ class LRModel:
# def load_dummies(self, union_cats_len, union_cats_idxs):
# self.onehot_encoder.cats_len = union_cats_len
# self.onehot_encoder.cats_idxs = union_cats_idxs
import primihub as ph
from primihub.FL.model.logistic_regression.homo_lr_dev import run_party
# DP-SGD (differentially private) training configuration for homo LR.
config = {
    'mode': 'DPSGD',
    'delta': 1e-3,  # DP delta parameter
    'noise_multiplier': 2.0,  # noise scale applied in DP-SGD
    'l2_norm_clip': 1.0,  # per-sample gradient clipping bound
    'secure_mode': True,
    'learning_rate': 'optimal',  # sklearn-style optimal lr schedule
    'alpha': 0.0001,  # regularization strength
    'batch_size': 50,
    'max_iter': 100,
    'category': 2,  # number of label classes
    'feature_names': None,  # None -> use all columns
}
@ph.context.function(role='arbiter',
                     protocol='lr',
                     datasets=['train_homo_lr'],
                     port='9010',
                     task_type="lr-train")
def run_arbiter_party():
    # Arbiter (coordinator) entry point; delegates to the shared runner.
    run_party('arbiter', config)
@ph.context.function(role='host',
                     protocol='lr',
                     datasets=['train_homo_lr_host'],
                     port='9020',
                     task_type="lr-train")
def run_host_party():
    # Host entry point; delegates to the shared runner with this config.
    run_party('host', config)
@ph.context.function(role='guest',
                     protocol='lr',
                     datasets=['train_homo_lr_guest'],
                     port='9030',
                     task_type="lr-train")
def run_guest_party():
    # Guest entry point; delegates to the shared runner with this config.
    run_party('guest', config)
import primihub as ph
from primihub.FL.model.logistic_regression.homo_lr_dev import run_party
# Paillier (homomorphic encryption) training configuration for homo LR.
config = {
    'mode': 'Paillier',
    'n_length': 1024,  # Paillier key length in bits
    'learning_rate': 'optimal',  # sklearn-style optimal lr schedule
    'alpha': 0.01,  # regularization strength
    'batch_size': 100,
    'max_iter': 50,
    'n_iter_no_change': 5,  # early-stop patience (iterations)
    'compare_threshold': 1e-6,  # early-stop improvement threshold
    'category': 2,  # number of label classes
    'feature_names': None,  # None -> use all columns
}
@ph.context.function(role='arbiter',
                     protocol='lr',
                     datasets=['train_homo_lr'],
                     port='9010',
                     task_type="lr-train")
def run_arbiter_party():
    # Arbiter (coordinator) entry point; delegates to the shared runner.
    run_party('arbiter', config)
@ph.context.function(role='host',
                     protocol='lr',
                     datasets=['train_homo_lr_host'],
                     port='9020',
                     task_type="lr-train")
def run_host_party():
    # Host entry point; delegates to the shared runner with this config.
    run_party('host', config)
@ph.context.function(role='guest',
                     protocol='lr',
                     datasets=['train_homo_lr_guest'],
                     port='9030',
                     task_type="lr-train")
def run_guest_party():
    # Guest entry point; delegates to the shared runner with this config.
    run_party('guest', config)
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -25,3 +25,5 @@ transformers
requests
line_profiler
matplotlib
dp-accounting==0.3.0
......@@ -37,6 +37,9 @@ class AlgorithmBase {
virtual int execute() = 0;
virtual int finishPartyComm() = 0;
virtual int saveModel() = 0;
  // Accessor for the dataset service held by this algorithm instance.
  // Returns a mutable reference to the owning shared_ptr so callers can
  // both use and (re)assign it.
  std::shared_ptr<DatasetService>& datasetService() {
    return dataset_service_;
  }
protected:
std::shared_ptr<DatasetService> dataset_service_;
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -78,6 +78,7 @@ class KeywordPIRClientTask : public TaskBase {
private:
std::string dataset_path_;
std::string dataset_id_;
std::string result_file_path_;
std::string server_address_;
bool recv_query_data_direct{false};
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册