提交 67444120 编写于 作者: Y yaoxuefeng

add wide&deep

上级 c8f35128
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
train:
trainer:
# for cluster training
strategy: "async"
epochs: 10
workspace: "fleetrec.models.rank.wide_deep"
reader:
batch_size: 2
class: "{workspace}/reader.py"
train_data_path: "{workspace}/data/train_data"
model:
models: "{workspace}/model.py"
hyper_parameters:
hidden1_units: 75
hidden2_units: 50
hidden3_units: 25
learning_rate: 0.0001
reg: 0.001
act: "relu"
optimizer: SGD
save:
increment:
dirname: "increment"
epoch_interval: 2
save_last: True
inference:
dirname: "inference"
epoch_interval: 4
save_last: True
mkdir train_data
mkdir test_data
mkdir data
train_path="/home/yaoxuefeng/repos/models/models/PaddleRec/ctr/wide_deep/data/adult.data"
test_path="/home/yaoxuefeng/repos/models/models/PaddleRec/ctr/wide_deep/data/adult.test"
train_data_path="/home/yaoxuefeng/repos/models/models/PaddleRec/ctr/wide_deep/train_data/train_data.csv"
test_data_path="/home/yaoxuefeng/repos/models/models/PaddleRec/ctr/wide_deep/test_data/test_data.csv"
#pip install -r requirements.txt
#wget -P data/ https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data
#wget -P data/ https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test
python data_preparation.py --train_path ${train_path} \
--test_path ${test_path} \
--train_data_path ${train_data_path}\
--test_data_path ${test_data_path}
import paddle.fluid as fluid
import math
from fleetrec.core.utils import envs
from fleetrec.core.model import Model as ModelBase
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def wide_part(self, data):
out = fluid.layers.fc(input=data,
size=1,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=1.0 / math.sqrt(data.shape[1])),
regularizer=fluid.regularizer.L2DecayRegularizer(regularization_coeff=1e-4)),
act=None,
name='wide')
return out
def fc(self, data, hidden_units, active, tag):
output = fluid.layers.fc(input=data,
size=hidden_units,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=1.0 / math.sqrt(data.shape[1]))),
act=active,
name=tag)
return output
def deep_part(self, data, hidden1_units, hidden2_units, hidden3_units):
l1 = self.fc(data, hidden1_units, 'relu', 'l1')
l2 = self.fc(l1, hidden2_units, 'relu', 'l2')
l3 = self.fc(l2, hidden3_units, 'relu', 'l3')
return l3
def train_net(self):
wide_input = fluid.data(name='wide_input', shape=[None, 8], dtype='float32')
deep_input = fluid.data(name='deep_input', shape=[None, 58], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='float32')
self._data_var.append(wide_input)
self._data_var.append(deep_input)
self._data_var.append(label)
hidden1_units = envs.get_global_env("hyper_parameters.hidden1_units", 75, self._namespace)
hidden2_units = envs.get_global_env("hyper_parameters.hidden2_units", 50, self._namespace)
hidden3_units = envs.get_global_env("hyper_parameters.hidden3_units", 25, self._namespace)
wide_output = self.wide_part(wide_input)
deep_output = self.deep_part(deep_input, hidden1_units, hidden2_units, hidden3_units)
wide_model = fluid.layers.fc(input=wide_output,
size=1,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=1.0)),
act=None,
name='w_wide')
deep_model = fluid.layers.fc(input=deep_output,
size=1,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=1.0)),
act=None,
name='w_deep')
prediction = fluid.layers.elementwise_add(wide_model, deep_model)
pred = fluid.layers.sigmoid(fluid.layers.clip(prediction, min=-15.0, max=15.0), name="prediction")
num_seqs = fluid.layers.create_tensor(dtype='int64')
acc = fluid.layers.accuracy(input=pred, label=fluid.layers.cast(x=label, dtype='int64'), total=num_seqs)
auc_var, batch_auc, auc_states = fluid.layers.auc(input=pred, label=fluid.layers.cast(x=label, dtype='int64'))
self._metrics["AUC"] = auc_var
self._metrics["BATCH_AUC"] = batch_auc
self._metrics["ACC"] = acc
cost = fluid.layers.sigmoid_cross_entropy_with_logits(x=prediction, label=label)
avg_cost = fluid.layers.mean(cost)
self._cost = avg_cost
def optimizer(self):
learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace)
optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True)
return optimizer
def infer_net(self, parameter_list):
self.deepfm_net()
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
try:
import cPickle as pickle
except ImportError:
import pickle
class TrainReader(Reader):
def init(self):
pass
def _process_line(self, line):
line = line.strip().split(',')
features = list(map(float, line))
wide_feat = features[0:8]
deep_feat = features[8:58+8]
label = features[-1]
return wide_feat, deep_feat, [label]
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def data_iter():
wide_feat, deep_deat, label = self._process_line(line)
yield [('wide_input', wide_feat), ('deep_input', deep_deat), ('label', label)]
return data_iter
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册