diff --git a/core/metric.py b/core/metric.py index 287ec2f8a2007739eaa3303933a345ea10b063fb..d9968fa40167b6ca728b0c1046fca5e70ef427a7 100755 --- a/core/metric.py +++ b/core/metric.py @@ -13,6 +13,8 @@ # limitations under the License. import abc +import paddle.fluid as fluid +import numpy as np class Metric(object): @@ -24,14 +26,23 @@ class Metric(object): """ """ pass - def clear(self, scope, params): + def clear(self, scope=None, **kwargs): """ clear current value Args: scope: value container params: extend varilable for clear """ - pass + if scope is None: + scope = fluid.global_scope() + + place = fluid.CPUPlace() + for (varname, dtype) in self._need_clear_list: + if scope.find_var(varname) is None: + continue + var = scope.var(varname).get_tensor() + data_array = np.zeros(var._get_dims()).astype(dtype) + var.set(data_array, place) def calculate(self, scope, params): """ diff --git a/core/metrics/precision.py b/core/metrics/precision.py index d748e17d3ee6c34b282ee3558d35ce6d4f66c952..4b9b4bd3101854f70308455cabc67bb64249b5dc 100755 --- a/core/metrics/precision.py +++ b/core/metrics/precision.py @@ -99,6 +99,8 @@ class Precision(Metric): self.accuracy = local_pos_num / local_ins_num + self._need_clear_list = [("local_ins_num", "float32"), + ("local_pos_num", "float32")] self.metrics = dict() metric_varname = "P@%d" % kwargs.get("k") self.metrics[metric_varname] = self.accuracy diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index f316327ac2447d1a7ac40f511d487685d33e542b..d5fced11ffd546b36ee7db3e596f061bf8a58328 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -221,6 +221,7 @@ class RunnerBase(object): program = context["model"][model_name]["main_program"].clone() _exe_strategy, _build_strategy = self._get_strategy(model_dict, context) + program = fluid.compiler.CompiledProgram(program).with_data_parallel( loss_name=model_class.get_avg_cost().name, build_strategy=_build_strategy, @@ -497,6 +498,10 @@ class SingleInferRunner(RunnerBase): with fluid.program_guard(train_prog, startup_prog): fluid.io.load_persistables( context["exe"], model_path, main_program=train_prog) + clear_metrics = context["model"][model_dict["name"]][ + "model"].get_clear_metrics() + for var in clear_metrics: + var.clear() def _dir_check(self, context): dirname = envs.get_global_env(