提交 71824d93 编写于 作者: M malin10

update

上级 91a870d4
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import abc import abc
import paddle.fluid as fluid
import numpy as np
class Metric(object): class Metric(object):
...@@ -24,14 +26,23 @@ class Metric(object): ...@@ -24,14 +26,23 @@ class Metric(object):
""" """ """ """
pass pass
def clear(self, scope, params): def clear(self, scope=None, **kwargs):
""" """
clear current value clear current value
Args: Args:
scope: value container scope: value container
params: extend varilable for clear params: extend varilable for clear
""" """
pass if scope is None:
scope = fluid.global_scope()
place = fluid.CPUPlace()
for (varname, dtype) in self._need_clear_list:
if scope.find_var(varname) is None:
continue
var = scope.var(varname).get_tensor()
data_array = np.zeros(var._get_dims()).astype(dtype)
var.set(data_array, place)
def calculate(self, scope, params): def calculate(self, scope, params):
""" """
......
...@@ -99,6 +99,8 @@ class Precision(Metric): ...@@ -99,6 +99,8 @@ class Precision(Metric):
self.accuracy = local_pos_num / local_ins_num self.accuracy = local_pos_num / local_ins_num
self._need_clear_list = [("local_ins_num", "float32"),
("local_pos_num", "float32")]
self.metrics = dict() self.metrics = dict()
metric_varname = "P@%d" % kwargs.get("k") metric_varname = "P@%d" % kwargs.get("k")
self.metrics[metric_varname] = self.accuracy self.metrics[metric_varname] = self.accuracy
......
...@@ -221,6 +221,7 @@ class RunnerBase(object): ...@@ -221,6 +221,7 @@ class RunnerBase(object):
program = context["model"][model_name]["main_program"].clone() program = context["model"][model_name]["main_program"].clone()
_exe_strategy, _build_strategy = self._get_strategy(model_dict, _exe_strategy, _build_strategy = self._get_strategy(model_dict,
context) context)
program = fluid.compiler.CompiledProgram(program).with_data_parallel( program = fluid.compiler.CompiledProgram(program).with_data_parallel(
loss_name=model_class.get_avg_cost().name, loss_name=model_class.get_avg_cost().name,
build_strategy=_build_strategy, build_strategy=_build_strategy,
...@@ -497,6 +498,10 @@ class SingleInferRunner(RunnerBase): ...@@ -497,6 +498,10 @@ class SingleInferRunner(RunnerBase):
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
fluid.io.load_persistables( fluid.io.load_persistables(
context["exe"], model_path, main_program=train_prog) context["exe"], model_path, main_program=train_prog)
clear_metrics = context["model"][model_dict["name"]][
"model"].get_clear_metrics()
for var in clear_metrics:
var.clear()
def _dir_check(self, context): def _dir_check(self, context):
dirname = envs.get_global_env( dirname = envs.get_global_env(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册