From 71824d938e99359850631784c9835fe44063bde0 Mon Sep 17 00:00:00 2001 From: malin10 Date: Tue, 21 Jul 2020 21:17:17 +0800 Subject: [PATCH] update --- core/metric.py | 15 +++++++++++++-- core/metrics/precision.py | 2 ++ core/trainers/framework/runner.py | 5 +++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/core/metric.py b/core/metric.py index 287ec2f8..d9968fa4 100755 --- a/core/metric.py +++ b/core/metric.py @@ -13,6 +13,8 @@ # limitations under the License. import abc +import paddle.fluid as fluid +import numpy as np class Metric(object): @@ -24,14 +26,23 @@ class Metric(object): """ """ pass - def clear(self, scope, params): + def clear(self, scope=None, **kwargs): """ clear current value Args: scope: value container params: extend varilable for clear """ - pass + if scope is None: + scope = fluid.global_scope() + + place = fluid.CPUPlace() + for (varname, dtype) in self._need_clear_list: + if scope.find_var(varname) is None: + continue + var = scope.var(varname).get_tensor() + data_array = np.zeros(var._get_dims()).astype(dtype) + var.set(data_array, place) def calculate(self, scope, params): """ diff --git a/core/metrics/precision.py b/core/metrics/precision.py index d748e17d..4b9b4bd3 100755 --- a/core/metrics/precision.py +++ b/core/metrics/precision.py @@ -99,6 +99,8 @@ class Precision(Metric): self.accuracy = local_pos_num / local_ins_num + self._need_clear_list = [("local_ins_num", "float32"), + ("local_pos_num", "float32")] self.metrics = dict() metric_varname = "P@%d" % kwargs.get("k") self.metrics[metric_varname] = self.accuracy diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index f316327a..d5fced11 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -221,6 +221,7 @@ class RunnerBase(object): program = context["model"][model_name]["main_program"].clone() _exe_strategy, _build_strategy = self._get_strategy(model_dict, context) + program = fluid.compiler.CompiledProgram(program).with_data_parallel( loss_name=model_class.get_avg_cost().name, build_strategy=_build_strategy, @@ -497,6 +498,10 @@ class SingleInferRunner(RunnerBase): with fluid.program_guard(train_prog, startup_prog): fluid.io.load_persistables( context["exe"], model_path, main_program=train_prog) + clear_metrics = context["model"][model_dict["name"]][ + "model"].get_clear_metrics() + for var in clear_metrics: + var.clear() def _dir_check(self, context): dirname = envs.get_global_env( -- GitLab