diff --git a/fleetrec/examples/user_define_trainer.py b/fleetrec/examples/user_define_trainer.py index 551179e2a9d3b1a14ee08cf4ba0c755aab683e30..549d42b84ce209688fb597026786c79f50dad1f4 100644 --- a/fleetrec/examples/user_define_trainer.py +++ b/fleetrec/examples/user_define_trainer.py @@ -28,7 +28,6 @@ class UserDefineTrainer(TranspileTrainer): self.regist_context_processor('train_pass', self.train) def init(self, context): - self.model.input() self.model.net() self.model.metrics() self.model.avg_loss() diff --git a/fleetrec/models/base.py b/fleetrec/models/base.py index 0a92dee06030b7958efc0bc9ae8992f432849282..23269674828b234b8ef8867d4eb2cdcaec2f310d 100644 --- a/fleetrec/models/base.py +++ b/fleetrec/models/base.py @@ -63,11 +63,11 @@ def create(config): model = None if config['mode'] == 'fluid': model = YamlModel(config) - model.net() + model.train_net() return model -class Model(object): +class ModelBase(object): """R """ __metaclass__ = abc.ABCMeta @@ -80,6 +80,9 @@ class Model(object): self._data_var = [] self._fetch_interval = 20 + def get_input(self): + return self._data_var + def get_cost_op(self): """R """ @@ -94,20 +97,23 @@ class Model(object): return self._fetch_interval @abc.abstractmethod - def net(self): + def train_net(self): """R """ pass + def infer_net(self): + pass + -class YamlModel(Model): +class YamlModel(ModelBase): """R """ def __init__(self, config): """R """ - Model.__init__(self, config) + ModelBase.__init__(self, config) self._config = config self._name = config['name'] f = open(config['layer_file'], 'r') @@ -116,7 +122,7 @@ class YamlModel(Model): self._build_param = {'layer': {}, 'inner_layer': {}, 'layer_extend': {}, 'model': {}} self._inference_meta = {'dependency': {}, 'params': {}} - def net(self): + def train_net(self): """R build a fluid model with config Return: diff --git a/fleetrec/models/ctr_dnn/model.py b/fleetrec/models/ctr_dnn/model.py index 5c0f6e9bcd455ea816a459be1b70bab7c1655538..df96eb6fb25e7c91f8352175921a52fcac5c560d 100644 --- a/fleetrec/models/ctr_dnn/model.py +++ b/fleetrec/models/ctr_dnn/model.py @@ -16,12 +16,12 @@ import math import paddle.fluid as fluid from fleetrec.utils import envs -from fleetrec.models.base import Model +from fleetrec.models.base import ModelBase -class TrainModel(Model): +class Model(ModelBase): def __init__(self, config): - Model.__init__(self, config) + ModelBase.__init__(self, config) self.namespace = "train.model" def input(self): @@ -52,8 +52,10 @@ class TrainModel(Model): self.dense_input = dense_input() self.label_input = label_input() - def inputs(self): - return [self.dense_input] + self.sparse_inputs + [self.label_input] + for input in self.sparse_inputs: + self._data_var.append(input) + self._data_var.append(self.dense_input) + self._data_var.append(self.label_input) def net(self): def embedding_layer(input): @@ -112,15 +114,17 @@ class TrainModel(Model): self._metrics["AUC"] = auc self._metrics["BATCH_AUC"] = batch_auc + def train_net(self): + self.input() + self.net() + self.avg_loss() + self.metrics() + def optimizer(self): learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self.namespace) optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True) return optimizer - -class EvaluateModel(object): - def input(self): - pass - - def net(self): - pass + def infer_net(self): + self.input() + self.net() diff --git a/fleetrec/trainer/cluster_trainer.py b/fleetrec/trainer/cluster_trainer.py index 6b3e6471841bb526d2cc95c1a2b328833d82570e..f9b41a192414d9fc313a8e544a9b84856b153e34 100644 --- a/fleetrec/trainer/cluster_trainer.py +++ b/fleetrec/trainer/cluster_trainer.py @@ -66,15 +66,11 @@ class ClusterTrainer(TranspileTrainer): return strategy def init(self, context): - self.model.input() - self.model.net() - self.model.metrics() - self.model.avg_loss() + self.model.train_net() optimizer = self.model.optimizer() - strategy = self.build_strategy() optimizer = fleet.distributed_optimizer(optimizer, strategy) - optimizer.minimize(self.model._cost) + optimizer.minimize(self.model.get_cost_op()) if fleet.is_server(): context['status'] = 'server_pass' diff --git a/fleetrec/trainer/single_trainer.py b/fleetrec/trainer/single_trainer.py index 8cb0be63ba1ea08b459a24779d4e45735c57ebae..d606fd38d793ce1d2fa319d862c9fd6c8ec2def7 100644 --- a/fleetrec/trainer/single_trainer.py +++ b/fleetrec/trainer/single_trainer.py @@ -37,12 +37,9 @@ class SingleTrainer(TranspileTrainer): self.regist_context_processor('terminal_pass', self.terminal) def init(self, context): - self.model.input() - self.model.net() - self.model.metrics() - self.model.avg_loss() + self.model.train_net() optimizer = self.model.optimizer() - optimizer.minimize(self.model._cost) + optimizer.minimize((self.model.get_cost_op())) self.fetch_vars = [] self.fetch_alias = []