未验证 提交 3d37039c 编写于 作者: O overlordmax 提交者: GitHub

fix bugs (#4549)

上级 c75dbb37
......@@ -50,7 +50,7 @@ train_path="data/census-income.data"
test_path="data/census-income.test"
train_data_path="train_data/"
test_data_path="test_data/"
pip install -r requirements.txt
wget -P data/ https://archive.ics.uci.edu/ml/machine-learning-databases/census-income-mld/census.tar.gz
tar -zxvf data/census.tar.gz -C data/
......
......@@ -30,9 +30,9 @@ def parse_args():
parser.add_argument("--epochs", type=int, default=400, help="epochs")
parser.add_argument("--batch_size", type=int, default=32, help="batch_size")
parser.add_argument('--use_gpu', type=int, default=0, help='whether using gpu')
parser.add_argument('--model_dir',type=str, default='./model_dir', help="model_dir")
parser.add_argument('--train_data_path',type=str, default='./train_data', help="train_data_path")
parser.add_argument('--test_data_path',type=str, default='./test_data', help="test_data_path")
parser.add_argument('--model_dir',type=str, default='model_dir', help="model_dir")
parser.add_argument('--train_data_path',type=str, default='train_data', help="train_data_path")
parser.add_argument('--test_data_path',type=str, default='test_data', help="test_data_path")
args = parser.parse_args()
return args
......@@ -43,6 +43,5 @@ def data_preparation_args():
parser.add_argument('--train_data_path',type=str, default='', help="train_data_path")
parser.add_argument('--test_data_path',type=str, default='', help="test_data_path")
parser.add_argument('--validation_data_path',type=str, default='', help="validation_data_path")
args = parser.parse_args()
return args
......@@ -5,7 +5,6 @@ train_path="data/census-income.data"
test_path="data/census-income.test"
train_data_path="train_data/"
test_data_path="test_data/"
pip install -r requirements.txt
wget -P data/ https://archive.ics.uci.edu/ml/machine-learning-databases/census-income-mld/census.tar.gz
......
......@@ -18,8 +18,7 @@ def fun2(x):
return 0
def data_preparation(train_path, test_path, train_data_path, test_data_path,
validation_data_path):
def data_preparation(train_path, test_path, train_data_path, test_data_path):
# The column names are from
# https://www2.1010data.com/documentationcenter/prod/Tutorials/MachineLearningExamples/CensusIncomeDataSet.html
column_names = [
......@@ -102,4 +101,4 @@ def data_preparation(train_path, test_path, train_data_path, test_data_path,
args = data_preparation_args()
data_preparation(args.train_path, args.test_path, args.train_data_path,
args.test_data_path, args.validation_data_path)
args.test_data_path)
......@@ -69,8 +69,11 @@ def MMOE(feature_size=499,expert_num=8, gate_num=2, expert_size=16, tower_size=8
label_income_1 = fluid.layers.slice(label_income, axes=[1], starts=[1], ends=[2])
label_marital_1 = fluid.layers.slice(label_marital, axes=[1], starts=[1], ends=[2])
auc_income, batch_auc_1, auc_states_1 = fluid.layers.auc(input=output_layers[0], label=fluid.layers.cast(x=label_income_1, dtype='int64'))
auc_marital, batch_auc_2, auc_states_2 = fluid.layers.auc(input=output_layers[1], label=fluid.layers.cast(x=label_marital_1, dtype='int64'))
pred_income = fluid.layers.clip(output_layers[0], min=1e-10, max=1.0 - 1e-10)
pred_marital = fluid.layers.clip(output_layers[1], min=1e-10, max=1.0 - 1e-10)
auc_income, batch_auc_1, auc_states_1 = fluid.layers.auc(input=pred_income, label=fluid.layers.cast(x=label_income_1, dtype='int64'))
auc_marital, batch_auc_2, auc_states_2 = fluid.layers.auc(input=pred_marital, label=fluid.layers.cast(x=label_marital_1, dtype='int64'))
avg_cost_income = fluid.layers.mean(x=cost_income)
avg_cost_marital = fluid.layers.mean(x=cost_marital)
......@@ -116,7 +119,6 @@ test_loader = fluid.io.DataLoader.from_generator(feed_list=data_list, capacity=b
test_loader.set_sample_list_generator(test_reader, places=place)
auc_income_list = []
auc_marital_list = []
mmoe_res_file = open('mmoe_res.txt', 'w',encoding='utf-8')
for epoch in range(epochs):
for var in auc_states_1: # reset auc states
set_zero(var.name,place=place)
......
......@@ -50,7 +50,7 @@ train_path="data/census-income.data"
test_path="data/census-income.test"
train_data_path="train_data/"
test_data_path="test_data/"
pip install -r requirements.txt
wget -P data/ https://archive.ics.uci.edu/ml/machine-learning-databases/census-income-mld/census.tar.gz
tar -zxvf data/census.tar.gz -C data/
......
......@@ -40,15 +40,7 @@ def data_preparation_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--train_path", type=str, default='', help="train_path")
parser.add_argument("--test_path", type=str, default='', help="test_path")
parser.add_argument(
'--train_data_path', type=str, default='train_data', help="train_data_path")
parser.add_argument(
'--test_data_path', type=str, default='test_data', help="test_data_path")
parser.add_argument(
'--validation_data_path',
type=str,
default='',
help="validation_data_path")
parser.add_argument('--train_data_path', type=str, default='train_data', help="train_data_path")
parser.add_argument('--test_data_path', type=str, default='test_data', help="test_data_path")
args = parser.parse_args()
return args
......@@ -18,8 +18,7 @@ def fun2(x):
return 0
def data_preparation(train_path, test_path, train_data_path, test_data_path,
validation_data_path):
def data_preparation(train_path, test_path, train_data_path, test_data_path):
# The column names are from
# https://www2.1010data.com/documentationcenter/prod/Tutorials/MachineLearningExamples/CensusIncomeDataSet.html
column_names = [
......@@ -101,5 +100,4 @@ def data_preparation(train_path, test_path, train_data_path, test_data_path,
args = data_preparation_args()
data_preparation(args.train_path, args.test_path, args.train_data_path,
args.test_data_path, args.validation_data_path)
data_preparation(args.train_path, args.test_path, args.train_data_path,args.test_data_path)
......@@ -56,8 +56,11 @@ def share_bottom(feature_size=499,bottom_size=117,tower_nums=2,tower_size=8):
label_income_1 = fluid.layers.slice(label_income, axes=[1], starts=[1], ends=[2])
label_marital_1 = fluid.layers.slice(label_marital, axes=[1], starts=[1], ends=[2])
auc_income, batch_auc_1, auc_states_1 = fluid.layers.auc(input=output_layers[0], label=fluid.layers.cast(x=label_income_1, dtype='int64'))
auc_marital, batch_auc_2, auc_states_2 = fluid.layers.auc(input=output_layers[1], label=fluid.layers.cast(x=label_marital_1, dtype='int64'))
pred_income = fluid.layers.clip(output_layers[0], min=1e-10, max=1.0 - 1e-10)
pred_marital = fluid.layers.clip(output_layers[1], min=1e-10, max=1.0 - 1e-10)
auc_income, batch_auc_1, auc_states_1 = fluid.layers.auc(input=pred_income, label=fluid.layers.cast(x=label_income_1, dtype='int64'))
auc_marital, batch_auc_2, auc_states_2 = fluid.layers.auc(input=pred_marital, label=fluid.layers.cast(x=label_marital_1, dtype='int64'))
avg_cost_income = fluid.layers.mean(x=cost_income)
avg_cost_marital = fluid.layers.mean(x=cost_marital)
......
python share_bottom.py --use_gpu 1 \
CUDA_VISIBLE_DEVICES=0 python share_bottom.py --use_gpu 1 \
--epochs 100 \
--train_data_path 'train_data' \
--test_data_path 'test_data' \
--train_data_path '.train_data' \
--test_data_path '.test_data' \
--model_dir 'model_dir' \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册