未验证 提交 1b70f701 编写于 作者: X Xuefeng Xu 提交者: GitHub

add SimpleImputer for strategy mean, most_frequent, constant (#603)

上级 fb53f290
{
"party_info": {
"task_manager": "127.0.0.1:50050"
},
"component_params": {
"roles": {
"server": "Alice",
"client": [
"Bob",
"Charlie"
]
},
"common_params": {
"model": "FL_Preprocess",
"process": "fit_transform",
"FL_type": "H",
"task_name": "HFL_simpleimputer_numeric_fit_transform",
"task": "classification",
"selected_column": [
"MinTemp",
"MaxTemp",
"Rainfall",
"Evaporation",
"Sunshine",
"WindGustSpeed",
"WindSpeed9am",
"WindSpeed3pm",
"Humidity9am",
"Humidity3pm",
"Pressure9am",
"Pressure3pm",
"Cloud9am",
"Cloud3pm",
"Temp9am",
"Temp3pm",
"RISK_MM"
],
"id": "id",
"label": "y",
"preprocess_column": null,
"preprocess_module": {
"SimpleImputer_numeric": {
"column": null,
"missing_values": "np.nan",
"strategy": "mean",
"fill_value": null,
"copy": true,
"add_indicator": false,
"keep_empty_features": false
}
}
},
"role_params": {
"Bob": {
"data_set": "preprocess_hfl_train_client1",
"preprocess_dataset_path": "data/result/Bob_train_dataset.csv",
"preprocess_module_path": "data/result/Bob_preprocess_module.pkl"
},
"Charlie": {
"data_set": "preprocess_hfl_train_client2",
"preprocess_dataset_path": "data/result/Charlie_train_dataset.csv",
"preprocess_module_path": "data/result/Charlie_preprocess_module.pkl"
},
"Alice": {
"data_set": "fl_fake_data"
}
}
}
}
\ No newline at end of file
{
"party_info": {
"task_manager": "127.0.0.1:50050"
},
"component_params": {
"roles": {
"server": "Alice",
"client": [
"Bob",
"Charlie"
]
},
"common_params": {
"model": "FL_Preprocess",
"process": "fit_transform",
"FL_type": "H",
"task_name": "HFL_simpleimputer_string_fit_transform",
"task": "classification",
"selected_column": [
"WindGustDir",
"WindDir9am",
"WindDir3pm",
"RainToday"
],
"id": "id",
"label": "y",
"preprocess_column": null,
"preprocess_module": {
"SimpleImputer_string": {
"column": null,
"missing_values": "np.nan",
"strategy": "most_frequent",
"fill_value": null,
"copy": true,
"add_indicator": false,
"keep_empty_features": false
}
}
},
"role_params": {
"Bob": {
"data_set": "preprocess_hfl_train_client1",
"preprocess_dataset_path": "data/result/Bob_train_dataset.csv",
"preprocess_module_path": "data/result/Bob_preprocess_module.pkl"
},
"Charlie": {
"data_set": "preprocess_hfl_train_client2",
"preprocess_dataset_path": "data/result/Charlie_train_dataset.csv",
"preprocess_module_path": "data/result/Charlie_preprocess_module.pkl"
},
"Alice": {
"data_set": "fl_fake_data"
}
}
}
}
\ No newline at end of file
import numbers
import numpy as np import numpy as np
from sklearn.impute import SimpleImputer as SKL_SimpleImputer from sklearn.impute import SimpleImputer as SKL_SimpleImputer
from sklearn.impute._base import _BaseImputer
from .base import PreprocessBase from .base import PreprocessBase
from .util import get_dense_mask, unique
class SimpleImputer(PreprocessBase): class SimpleImputer(PreprocessBase, _BaseImputer):
def __init__(self, def __init__(self,
missing_values=np.nan, missing_values=np.nan,
...@@ -22,7 +25,160 @@ class SimpleImputer(PreprocessBase): ...@@ -22,7 +25,160 @@ class SimpleImputer(PreprocessBase):
copy=copy, copy=copy,
add_indicator=add_indicator, add_indicator=add_indicator,
keep_empty_features=keep_empty_features) keep_empty_features=keep_empty_features)
if FL_type == 'H':
self.missing_values = missing_values
self.strategy = strategy
self.copy = copy
self.add_indicator = add_indicator
def Hfit(self, x): def Hfit(self, X):
pass if self.role == 'client':
X = self.module._validate_input(X, in_fit=True)
\ No newline at end of file
# default fill_value is 0 for numerical input and "missing_value"
# otherwise
if self.module.fill_value is None:
if X.dtype.kind in ("i", "u", "f"):
fill_value = 0
else:
fill_value = "missing_value"
else:
fill_value = self.module.fill_value
# fill_value should be numerical in case of numerical input
if (
self.module.strategy == "constant"
and X.dtype.kind in ("i", "u", "f")
and not isinstance(fill_value, numbers.Real)
):
raise ValueError(
"'fill_value'={0} is invalid. Expected a "
"numerical value when imputing numerical "
"data".format(fill_value)
)
elif self.role == 'server':
fill_value = self.module.fill_value
self.module.statistics_ = \
self._dense_fit(
X,
self.module.strategy,
self.module.missing_values,
fill_value
)
return self
def _dense_fit(self, X, strategy, missing_values, fill_value):
"""Fit the transformer on dense data."""
if self.role == 'client':
missing_mask = get_dense_mask(X, missing_values)
masked_X = np.ma.masked_array(X, mask=missing_mask)
super()._fit_indicator(missing_mask)
self.module.indicator_ = self.indicator_
# Mean
if strategy == "mean":
if self.role == 'client':
sum_masked = np.ma.sum(masked_X, axis=0)
self.channel.send('sum_masked', sum_masked)
n_samples = X.shape[0] - np.sum(missing_mask, axis=0)
# for backward-compatibility, reduce n_samples to an integer
# if the number of samples is the same for each feature (i.e. no
# missing values)
if np.ptp(n_samples) == 0:
n_samples = n_samples[0]
self.channel.send('n_samples', n_samples)
mean = self.channel.recv('mean')
elif self.role == 'server':
sum_masked = self.channel.recv_all('sum_masked')
sum_masked = np.ma.sum(sum_masked, axis=0)
n_samples = self.channel.recv_all('n_samples')
# n_samples could be np.int or np.ndarray
n_sum = 0
for n in n_samples:
n_sum += n
if isinstance(n_sum, np.ndarray) and np.ptp(n_sum) == 0:
n_sum = n_sum[0]
mean_masked = sum_masked / n_sum
# Avoid the warning "Warning: converting a masked element to nan."
mean = np.ma.getdata(mean_masked)
mean[np.ma.getmask(mean_masked)] = 0 if self.module.keep_empty_features else np.nan
self.channel.send_all('mean', mean)
return mean
# Median
elif strategy == "median":
median_masked = np.ma.median(masked_X, axis=0)
# Avoid the warning "Warning: converting a masked element to nan."
median = np.ma.getdata(median_masked)
median[np.ma.getmaskarray(median_masked)] = (
0 if self.module.keep_empty_features else np.nan
)
return median
# Most frequent
elif strategy == "most_frequent":
if self.role == 'client':
frequency_counts = []
# To be able access the elements by columns
X = X.transpose()
mask = missing_mask.transpose()
for row, row_mask in zip(X[:], mask[:]):
row_mask = np.logical_not(row_mask).astype(bool)
row = row[row_mask]
if len(row) == 0 and self.module.keep_empty_features:
frequency_counts.append(([missing_values], len(row_mask)))
else:
frequency_counts.append(unique(row, return_counts=True))
self.channel.send('frequency_counts', frequency_counts)
most_frequent = self.channel.recv('most_frequent')
elif self.role == 'server':
frequency_counts = self.channel.recv_all('frequency_counts')
n_features = len(frequency_counts[0])
most_frequent = []
for feature_idx in range(n_features):
feature_counts = {}
for client_fc in frequency_counts:
feature, counts = client_fc[feature_idx]
for i, key in enumerate(feature):
if key in feature_counts:
feature_counts[key] += counts[i]
else:
feature_counts[key] = counts[i]
most_frequent_for_idx = max(feature_counts, key=feature_counts.get)
if most_frequent_for_idx == missing_values:
if self.module.keep_empty_features:
most_frequent_for_idx = 0
else:
most_frequent_for_idx = np.nan
most_frequent.append(most_frequent_for_idx)
most_frequent = np.array(most_frequent)
self.channel.send_all('most_frequent', most_frequent)
return most_frequent
# Constant
elif strategy == "constant":
if self.role == 'client':
# for constant strategy, self.statistcs_ is used to store
# fill_value in each column
return np.full(X.shape[1], fill_value, dtype=X.dtype)
elif self.role == 'server':
return fill_value
\ No newline at end of file
...@@ -308,8 +308,15 @@ def select_module(module_name, params, FL_type, role, channel=None): ...@@ -308,8 +308,15 @@ def select_module(module_name, params, FL_type, role, channel=None):
channel=channel channel=channel
) )
elif "SimpleImputer" in module_name: elif "SimpleImputer" in module_name:
missing_values = params.get('missing_values')
if isinstance(missing_values, str):
missing_values = missing_values.lower()
if missing_values == 'np.nan':
missing_values = np.nan
elif missing_values == 'pd.na':
missing_values = pd.NA
module = SimpleImputer( module = SimpleImputer(
missing_values=params.get('missing_values', np.nan), missing_values=missing_values,
strategy=params.get('strategy', 'mean'), strategy=params.get('strategy', 'mean'),
fill_value=params.get('fill_value'), fill_value=params.get('fill_value'),
copy=params.get('copy', True), copy=params.get('copy', True),
......
...@@ -273,4 +273,31 @@ def num_samples(x): ...@@ -273,4 +273,31 @@ def num_samples(x):
return len(x) return len(x)
except TypeError as type_error: except TypeError as type_error:
raise TypeError(message) from type_error raise TypeError(message) from type_error
\ No newline at end of file
def get_dense_mask(X, value_to_mask):
with suppress(ImportError, AttributeError):
# We also suppress `AttributeError` because older versions of pandas do
# not have `NA`.
import pandas
if value_to_mask is pandas.NA:
return pandas.isna(X)
if is_scalar_nan(value_to_mask):
if X.dtype.kind == "f":
Xt = np.isnan(X)
elif X.dtype.kind in ("i", "u"):
# can't have NaNs in integer array.
Xt = np.zeros(X.shape, dtype=bool)
else:
# np.isnan does not work on object dtypes.
Xt = _object_dtype_isnan(X)
else:
Xt = X == value_to_mask
return Xt
def _object_dtype_isnan(X):
return X != X
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册