未验证 提交 8e965485 编写于 作者: X Xuefeng Xu 提交者: GitHub

add HFL preprocessing pipeline (#604)

上级 1b70f701
{
"party_info": {
"task_manager": "127.0.0.1:50050"
},
"component_params": {
"roles": {
"server": "Alice",
"client": [
"Bob",
"Charlie"
]
},
"common_params": {
"model": "FL_Preprocess",
"process": "fit_transform",
"FL_type": "H",
"task_name": "HFL_pipeline_fit_transform",
"task": "classification",
"selected_column": null,
"id": "id",
"label": "y",
"preprocess_column": null,
"preprocess_module": {
"LabelEncoder": {
"column": "y"
},
"SimpleImputer_string": {
"column": null,
"missing_values": "np.nan",
"strategy": "most_frequent"
},
"SimpleImputer_numeric": {
"column": null,
"missing_values": "np.nan",
"strategy": "mean"
},
"OrdinalEncoder": {
"column": null,
"handle_unknown": "use_encoded_value",
"unknown_value": -1
},
"StandardScaler": {
"column": null
}
}
},
"role_params": {
"Bob": {
"data_set": "preprocess_hfl_train_client1",
"preprocess_dataset_path": "data/result/Bob_train_dataset.csv",
"preprocess_module_path": "data/result/Bob_preprocess_module.pkl"
},
"Charlie": {
"data_set": "preprocess_hfl_train_client2",
"preprocess_dataset_path": "data/result/Charlie_train_dataset.csv",
"preprocess_module_path": "data/result/Charlie_preprocess_module.pkl"
},
"Alice": {
"data_set": "fl_fake_data"
}
}
}
}
\ No newline at end of file
......@@ -11,7 +11,7 @@
"model": "FL_Preprocess",
"process": "fit_transform",
"FL_type": "V",
"task_name": "VFL_preprocess_fit_transform",
"task_name": "VFL_pipeline_fit_transform",
"task": "classification"
},
"role_params": {
......@@ -29,10 +29,12 @@
},
"SimpleImputer_string": {
"column": null,
"missing_values": "np.nan",
"strategy": "most_frequent"
},
"SimpleImputer_numeric": {
"column": null,
"missing_values": "np.nan",
"strategy": "mean"
},
"OrdinalEncoder": {
......@@ -55,10 +57,12 @@
"preprocess_module": {
"SimpleImputer_string": {
"column": null,
"missing_values": "np.nan",
"strategy": "most_frequent"
},
"SimpleImputer_numeric": {
"column": null,
"missing_values": "np.nan",
"strategy": "mean"
},
"OrdinalEncoder": {
......
......@@ -8,6 +8,7 @@ from primihub.FL.preprocessing import *
import pickle
import numpy as np
import pandas as pd
from itertools import chain
class Pipeline(BaseModel):
......@@ -109,23 +110,31 @@ class Pipeline(BaseModel):
raise RuntimeError(error_msg)
column = params.get('column')
if column is None and preprocess_column is not None:
if column is None:
column = preprocess_column
if column is None and role != 'server':
if 'SimpleImputer' in module_name:
nan_column = preprocess_column[data[preprocess_column].isna().any()]
if 'string' in module_name:
column = data[nan_column].select_dtypes(exclude=num_type).columns
elif 'numeric' in module_name:
column = data[nan_column].select_dtypes(include=num_type).columns
else:
column = nan_column
elif 'Encoder' in module_name:
column = data[preprocess_column].select_dtypes(exclude=num_type).columns
elif 'Scaler' in module_name:
column = data[preprocess_column].select_dtypes(include=num_type).columns
else:
column = preprocess_column
if role != 'server':
if 'SimpleImputer' in module_name:
nan_column = column[data[column].isna().any()]
if 'string' in module_name:
column = data[nan_column].select_dtypes(exclude=num_type).columns
elif 'numeric' in module_name:
column = data[nan_column].select_dtypes(include=num_type).columns
else:
column = nan_column
elif module_name in ['OrdinalEncoder', 'OneHotEncoder']:
column = data[column].select_dtypes(exclude=num_type).columns
elif 'Scaler' in module_name:
column = data[column].select_dtypes(include=num_type).columns
if role == 'client':
channel.send('column', column)
column = channel.recv('column')
if role == 'server':
client_column = channel.recv_all('column')
column = list(set(chain.from_iterable(client_column)))
channel.send_all('column', column)
if column is not None:
if isinstance(column, pd.Index):
column = column.tolist()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册