提交 801dfd34 编写于 作者: T tangwei

rename get_cost_op to get_avg_cost

上级 ba023dfb
......@@ -47,7 +47,7 @@ class Model(object):
def get_infer_results(self):
return self._infer_results
def get_cost_op(self):
def get_avg_cost(self):
"""R
"""
return self._cost
......
......@@ -82,7 +82,7 @@ class ClusterTrainer(TranspileTrainer):
strategy = self.build_strategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(self.model.get_cost_op())
optimizer.minimize(self.model.get_avg_cost())
if fleet.is_server():
context['status'] = 'server_pass'
......@@ -114,7 +114,7 @@ class ClusterTrainer(TranspileTrainer):
program = fluid.compiler.CompiledProgram(
fleet.main_program).with_data_parallel(
loss_name=self.model.get_cost_op().name,
loss_name=self.model.get_avg_cost().name,
build_strategy=self.strategy.get_build_strategy(),
exec_strategy=self.strategy.get_execute_strategy())
......
......@@ -88,7 +88,7 @@ class CtrPaddleTrainer(Trainer):
optimizer = self.model.optimizer()
optimizer = fleet.distributed_optimizer(optimizer, strategy={"use_cvm": False})
optimizer.minimize(self.model.get_cost_op())
optimizer.minimize(self.model.get_avg_cost())
if fleet.is_server():
context['status'] = 'server_pass'
......
......@@ -129,7 +129,7 @@ class CtrPaddleTrainer(Trainer):
model = self._exector_context[executor['name']]['model']
self._metrics.update(model.get_metrics())
runnnable_scope.append(scope)
runnnable_cost_op.append(model.get_cost_op())
runnnable_cost_op.append(model.get_avg_cost())
for var in model._data_var:
if var.name in data_var_name_dict:
continue
......@@ -146,7 +146,7 @@ class CtrPaddleTrainer(Trainer):
model = self._exector_context[executor['name']]['model']
program = model._build_param['model']['train_program']
if not executor['is_update_sparse']:
program._fleet_opt["program_configs"][str(id(model.get_cost_op().block.program))]["push_sparse"] = []
program._fleet_opt["program_configs"][str(id(model.get_avg_cost().block.program))]["push_sparse"] = []
if 'train_thread_num' not in executor:
executor['train_thread_num'] = self.global_config['train_thread_num']
with fluid.scope_guard(scope):
......
......@@ -78,7 +78,7 @@ class ClusterTrainer(TranspileTrainer):
optimizer = self.model.optimizer()
strategy = self.build_strategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(self.model.get_cost_op())
optimizer.minimize(self.model.get_avg_cost())
if fleet.is_server():
context['status'] = 'server_pass'
......
......@@ -47,7 +47,7 @@ class SingleTrainer(TranspileTrainer):
def init(self, context):
self.model.train_net()
optimizer = self.model.optimizer()
optimizer.minimize((self.model.get_cost_op()))
optimizer.minimize((self.model.get_avg_cost()))
self.fetch_vars = []
self.fetch_alias = []
......@@ -74,7 +74,7 @@ class SingleTrainer(TranspileTrainer):
program = fluid.compiler.CompiledProgram(
fluid.default_main_program()).with_data_parallel(
loss_name=self.model.get_cost_op().name)
loss_name=self.model.get_avg_cost().name)
metrics_varnames = []
metrics_format = []
......
......@@ -153,7 +153,7 @@ class Model(object):
def infer_net(self):
pass
def get_cost_op(self):
def get_avg_cost(self):
return self._cost
```
......
......@@ -59,7 +59,7 @@ class Model(ModelBase):
self.cost = avg_cost
self._metrics["acc"] = acc
def get_cost_op(self):
def get_avg_cost(self):
return self.cost
def get_metrics(self):
......
......@@ -89,7 +89,7 @@ class Model(ModelBase):
self.metrics["correct"] = correct
self.metrics["cos_pos"] = cos_pos
def get_cost_op(self):
def get_avg_cost(self):
return self.cost
def get_metrics(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册