提交 1e953617 编写于 作者: T tangwei

add ctr-dnn example

上级 910d0cd1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
train:
batch_size: 32
threads: 12
epochs: 10
trainer: "SingleTraining"
reader:
mode: "dataset"
pipe_command: "python reader.py dataset"
train_data_path: "raw_data"
model:
models: "eleps.models.ctr_dnn.model.py"
hyper_parameters:
sparse_inputs_slots: 27,
sparse_feature_number: 1000001,
sparse_feature_dim: 8,
dense_input_dim: 13,
fc_sizes: [1024, 512, 32],
learning_rate: 0.001
save:
increment:
dirname: "models_for_increment"
epoch_interval: 2
save_last: True
inference:
dirname: "models_for_inference"
epoch_interval: 4
feed_varnames: ["C1", "C2", "C3"]
fetch_varnames: "predict"
save_last: True
evaluate:
batch_size: 32
train_thread_num: 12
reader: "reader.py"
{ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
"sparse_inputs_slots": 27, #
"sparse_feature_number": 1000001, # Licensed under the Apache License, Version 2.0 (the "License");
"sparse_feature_dim": 8, # you may not use this file except in compliance with the License.
"dense_input_dim": 13, # You may obtain a copy of the License at
"fc_sizes": [400, 400, 40], #
"learning_rate": 0.001 # http://www.apache.org/licenses/LICENSE-2.0
} #
\ No newline at end of file # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
sparse_inputs_slots: 27,
sparse_feature_number: 1000001,
sparse_feature_dim: 8,
dense_input_dim: 13,
fc_sizes: [400, 400, 40],
learning_rate: 0.001
...@@ -57,6 +57,12 @@ class Train(object): ...@@ -57,6 +57,12 @@ class Train(object):
self.dense_input, self.dense_input_varname = dense_input() self.dense_input, self.dense_input_varname = dense_input()
self.label_input, self.label_input_varname = label_input() self.label_input, self.label_input_varname = label_input()
def input_vars(self):
return self.sparse_inputs + [self.dense_input] + [self.label_input]
def input_varnames(self):
return [input.name for input in self.input_vars()]
def net(self): def net(self):
def embedding_layer(input): def embedding_layer(input):
sparse_feature_number = envs.get_global_env("sparse_feature_number") sparse_feature_number = envs.get_global_env("sparse_feature_number")
...@@ -101,22 +107,28 @@ class Train(object): ...@@ -101,22 +107,28 @@ class Train(object):
self.predict = predict self.predict = predict
def loss(self, predict): def avg_loss(self, predict):
cost = fluid.layers.cross_entropy(input=predict, label=self.label_input) cost = fluid.layers.cross_entropy(input=predict, label=self.label_input)
avg_cost = fluid.layers.reduce_sum(cost) avg_cost = fluid.layers.reduce_sum(cost)
self.loss = avg_cost self.loss = avg_cost
return avg_cost
def metric(self): def metrics(self):
auc, batch_auc, _ = fluid.layers.auc(input=self.predict, auc, batch_auc, _ = fluid.layers.auc(input=self.predict,
label=self.label_input, label=self.label_input,
num_thresholds=2 ** 12, num_thresholds=2 ** 12,
slide_steps=20) slide_steps=20)
self.metrics = (auc, batch_auc)
def optimizer(self): def optimizer(self):
learning_rate = envs.get_global_env("learning_rate") learning_rate = envs.get_global_env("learning_rate")
optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True) optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True)
return optimizer return optimizer
def optimize(self):
optimizer = self.optimizer()
optimizer.minimize(self.loss)
class Evaluate(object): class Evaluate(object):
def input(self): def input(self):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import time
import numpy as np
import logging
import paddle.fluid as fluid
from network import CTR
from argument import params_args
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def get_dataset(inputs, params):
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var(inputs)
dataset.set_pipe_command("python dataset_generator.py")
dataset.set_batch_size(params.batch_size)
dataset.set_thread(int(params.cpu_num))
file_list = [
str(params.train_files_path) + "/%s" % x
for x in os.listdir(params.train_files_path)
]
dataset.set_filelist(file_list)
logger.info("file list: {}".format(file_list))
return dataset
def train(params):
ctr_model = CTR()
inputs = ctr_model.input_data(params)
avg_cost, auc_var, batch_auc_var = ctr_model.net(inputs, params)
optimizer = fluid.optimizer.Adam(params.learning_rate)
optimizer.minimize(avg_cost)
fluid.default_main_program()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
dataset = get_dataset(inputs, params)
logger.info("Training Begin")
for epoch in range(params.epochs):
start_time = time.time()
exe.train_from_dataset(program=fluid.default_main_program(),
dataset=dataset,
fetch_list=[auc_var],
fetch_info=["Epoch {} auc ".format(epoch)],
print_period=100,
debug=False)
end_time = time.time()
logger.info("epoch %d finished, use time=%d\n" %
((epoch), end_time - start_time))
if params.test:
model_path = (str(params.model_path) + "/" + "epoch_" + str(epoch))
fluid.io.save_persistables(executor=exe, dirname=model_path)
logger.info("Train Success!")
if __name__ == "__main__":
params = params_args()
train(params)
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training use fluid with one node only.
"""
from __future__ import print_function
import os
import time
import numpy as np
import logging
import paddle.fluid as fluid
from .trainer import Trainer
from ..utils import envs
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def need_save(epoch_id, epoch_interval, is_last=False):
if is_last:
return True
return epoch_id % epoch_interval == 0
class SingleTrainer(Trainer):
def __init__(self, config=None, yaml_file=None):
Trainer.__init__(self, config, yaml_file)
self.exe = fluid.Executor(fluid.CPUPlace())
self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init)
self.regist_context_processor('train_pass', self.train)
self.regist_context_processor('infer_pass', self.infer)
self.regist_context_processor('terminal_pass', self.terminal)
def instance(self, context):
model_package = __import__(envs.get_global_env("train.model.models"))
train_model = getattr(model_package, 'Train')
self.model = train_model()
context['status'] = 'init_pass'
def init(self, context):
self.model.input()
self.model.net()
self.model.loss()
self.metrics = self.model.metrics()
self.model.optimize()
# run startup program at once
self.exe.run(fluid.default_startup_program())
context['status'] = 'train_pass'
def train(self, context):
print("Need to be implement")
context['is_exit'] = True
def infer(self, context):
print("Need to be implement")
context['is_exit'] = True
def terminal(self, context):
context['is_exit'] = True
class SingleTrainerWithDataloader(SingleTrainer):
pass
class SingleTrainerWithDataset(SingleTrainer):
def _get_dataset(self, inputs, threads, batch_size, pipe_command, train_files_path):
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var(inputs)
dataset.set_pipe_command(pipe_command)
dataset.set_batch_size(batch_size)
dataset.set_thread(threads)
file_list = [
os.path.join(train_files_path, x)
for x in os.listdir(train_files_path)
]
dataset.set_filelist(file_list)
return dataset
def save(self, epoch_id):
def save_inference_model():
is_save_inference = envs.get_global_env("save.inference", False)
if not is_save_inference:
return
save_interval = envs.get_global_env("save.inference.epoch_interval", 1)
if not need_save(epoch_id, save_interval, False):
return
feed_varnames = envs.get_global_env("save.inference.feed_varnames", None)
fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None)
fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames]
dirname = envs.get_global_env("save.inference.dirname", None)
assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id))
fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
def save_persistables():
is_save_increment = envs.get_global_env("save.increment", False)
if not is_save_increment:
return
save_interval = envs.get_global_env("save.increment.epoch_interval", 1)
if not need_save(epoch_id, save_interval, False):
return
dirname = envs.get_global_env("save.inference.dirname", None)
assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id))
fluid.io.save_persistables(self.exe, dirname)
is_save = envs.get_global_env("save", False)
if not is_save:
return
save_persistables()
save_inference_model()
def train(self, context):
inputs = self.model.input_vars()
threads = envs.get_global_env("threads")
batch_size = envs.get_global_env("batch_size")
pipe_command = envs.get_global_env("pipe_command")
train_data_path = envs.get_global_env("train_data_path")
dataset = self._get_dataset(inputs, threads, batch_size, pipe_command, train_data_path)
epochs = envs.get_global_env("epochs")
for i in range(epochs):
self.exe.train_from_dataset(program=fluid.default_main_program(),
dataset=dataset,
fetch_list=[self.metrics],
fetch_info=["epoch {} auc ".format(i)],
print_period=100)
context['status'] = 'infer_pass'
def infer(self, context):
context['status'] = 'terminal_pass'
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
import abc import abc
import time import time
import yaml
from .. utils import envs
class Trainer(object): class Trainer(object):
...@@ -21,9 +24,20 @@ class Trainer(object): ...@@ -21,9 +24,20 @@ class Trainer(object):
""" """
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, config): def __init__(self, config=None, yaml_file=None):
"""R
""" if not config and not yaml_file:
raise ValueError("config and yaml file have at least one not empty")
if config and yaml_file:
print("config and yaml file are all assigned, will use yaml file: {}".format(yaml_file))
if yaml_file:
with open(yaml_file, "r") as rb:
config = yaml.load(rb.read())
envs.set_global_envs(config)
self._status_processor = {} self._status_processor = {}
self._context = {'status': 'uninit', 'is_exit': False} self._context = {'status': 'uninit', 'is_exit': False}
......
...@@ -24,17 +24,17 @@ def decode_value(v): ...@@ -24,17 +24,17 @@ def decode_value(v):
return v return v
def set_global_envs(yaml, envs): def set_global_envs(yaml):
for k, v in yaml.items(): for k, v in yaml.items():
envs[k] = encode_value(v) os.environ[k] = encode_value(v)
def get_global_env(env_name): def get_global_env(env_name, default_value=None):
""" """
get os environment value get os environment value
""" """
if env_name not in os.environ: if env_name not in os.environ:
raise ValueError("can not find config of {}".format(env_name)) return default_value
v = os.environ[env_name] v = os.environ[env_name]
return decode_value(v) return decode_value(v)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册