提交 eb04c187 编写于 作者: T tangwei

add infer net define

上级 ac45ce07
...@@ -28,7 +28,6 @@ class UserDefineTrainer(TranspileTrainer): ...@@ -28,7 +28,6 @@ class UserDefineTrainer(TranspileTrainer):
self.regist_context_processor('train_pass', self.train) self.regist_context_processor('train_pass', self.train)
def init(self, context): def init(self, context):
self.model.input()
self.model.net() self.model.net()
self.model.metrics() self.model.metrics()
self.model.avg_loss() self.model.avg_loss()
......
...@@ -63,11 +63,11 @@ def create(config): ...@@ -63,11 +63,11 @@ def create(config):
model = None model = None
if config['mode'] == 'fluid': if config['mode'] == 'fluid':
model = YamlModel(config) model = YamlModel(config)
model.net() model.train_net()
return model return model
class Model(object): class ModelBase(object):
"""R """R
""" """
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
...@@ -80,6 +80,9 @@ class Model(object): ...@@ -80,6 +80,9 @@ class Model(object):
self._data_var = [] self._data_var = []
self._fetch_interval = 20 self._fetch_interval = 20
def get_input(self):
return self._data_var
def get_cost_op(self): def get_cost_op(self):
"""R """R
""" """
...@@ -94,20 +97,23 @@ class Model(object): ...@@ -94,20 +97,23 @@ class Model(object):
return self._fetch_interval return self._fetch_interval
@abc.abstractmethod @abc.abstractmethod
def net(self): def train_net(self):
"""R """R
""" """
pass pass
def infer_net(self):
pass
class YamlModel(Model): class YamlModel(ModelBase):
"""R """R
""" """
def __init__(self, config): def __init__(self, config):
"""R """R
""" """
Model.__init__(self, config) ModelBase.__init__(self, config)
self._config = config self._config = config
self._name = config['name'] self._name = config['name']
f = open(config['layer_file'], 'r') f = open(config['layer_file'], 'r')
...@@ -116,7 +122,7 @@ class YamlModel(Model): ...@@ -116,7 +122,7 @@ class YamlModel(Model):
self._build_param = {'layer': {}, 'inner_layer': {}, 'layer_extend': {}, 'model': {}} self._build_param = {'layer': {}, 'inner_layer': {}, 'layer_extend': {}, 'model': {}}
self._inference_meta = {'dependency': {}, 'params': {}} self._inference_meta = {'dependency': {}, 'params': {}}
def net(self): def train_net(self):
"""R """R
build a fluid model with config build a fluid model with config
Return: Return:
......
...@@ -16,12 +16,12 @@ import math ...@@ -16,12 +16,12 @@ import math
import paddle.fluid as fluid import paddle.fluid as fluid
from fleetrec.utils import envs from fleetrec.utils import envs
from fleetrec.models.base import Model from fleetrec.models.base import ModelBase
class TrainModel(Model): class Model(ModelBase):
def __init__(self, config): def __init__(self, config):
Model.__init__(self, config) ModelBase.__init__(self, config)
self.namespace = "train.model" self.namespace = "train.model"
def input(self): def input(self):
...@@ -52,8 +52,10 @@ class TrainModel(Model): ...@@ -52,8 +52,10 @@ class TrainModel(Model):
self.dense_input = dense_input() self.dense_input = dense_input()
self.label_input = label_input() self.label_input = label_input()
def inputs(self): for input in self.sparse_inputs:
return [self.dense_input] + self.sparse_inputs + [self.label_input] self._data_var.append(input)
self._data_var.append(self.dense_input)
self._data_var.append(self.label_input)
def net(self): def net(self):
def embedding_layer(input): def embedding_layer(input):
...@@ -112,15 +114,17 @@ class TrainModel(Model): ...@@ -112,15 +114,17 @@ class TrainModel(Model):
self._metrics["AUC"] = auc self._metrics["AUC"] = auc
self._metrics["BATCH_AUC"] = batch_auc self._metrics["BATCH_AUC"] = batch_auc
def train_net(self):
self.input()
self.net()
self.avg_loss()
self.metrics()
def optimizer(self): def optimizer(self):
learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self.namespace) learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self.namespace)
optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True) optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True)
return optimizer return optimizer
def infer_net(self):
class EvaluateModel(object): self.input()
def input(self): self.net()
pass
def net(self):
pass
...@@ -66,15 +66,11 @@ class ClusterTrainer(TranspileTrainer): ...@@ -66,15 +66,11 @@ class ClusterTrainer(TranspileTrainer):
return strategy return strategy
def init(self, context): def init(self, context):
self.model.input() self.model.train_net()
self.model.net()
self.model.metrics()
self.model.avg_loss()
optimizer = self.model.optimizer() optimizer = self.model.optimizer()
strategy = self.build_strategy() strategy = self.build_strategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(self.model._cost) optimizer.minimize(self.model.get_cost_op())
if fleet.is_server(): if fleet.is_server():
context['status'] = 'server_pass' context['status'] = 'server_pass'
......
...@@ -37,12 +37,9 @@ class SingleTrainer(TranspileTrainer): ...@@ -37,12 +37,9 @@ class SingleTrainer(TranspileTrainer):
self.regist_context_processor('terminal_pass', self.terminal) self.regist_context_processor('terminal_pass', self.terminal)
def init(self, context): def init(self, context):
self.model.input() self.model.train_net()
self.model.net()
self.model.metrics()
self.model.avg_loss()
optimizer = self.model.optimizer() optimizer = self.model.optimizer()
optimizer.minimize(self.model._cost) optimizer.minimize((self.model.get_cost_op()))
self.fetch_vars = [] self.fetch_vars = []
self.fetch_alias = [] self.fetch_alias = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册