提交 71824d93 编写于 作者: M malin10

update

上级 91a870d4
......@@ -13,6 +13,8 @@
# limitations under the License.
import abc
import paddle.fluid as fluid
import numpy as np
class Metric(object):
......@@ -24,14 +26,23 @@ class Metric(object):
""" """
pass
def clear(self, scope, params):
def clear(self, scope=None, **kwargs):
"""
clear current value
Args:
scope: value container
params: extend varilable for clear
"""
pass
if scope is None:
scope = fluid.global_scope()
place = fluid.CPUPlace()
for (varname, dtype) in self._need_clear_list:
if scope.find_var(varname) is None:
continue
var = scope.var(varname).get_tensor()
data_array = np.zeros(var._get_dims()).astype(dtype)
var.set(data_array, place)
def calculate(self, scope, params):
"""
......
......@@ -99,6 +99,8 @@ class Precision(Metric):
self.accuracy = local_pos_num / local_ins_num
self._need_clear_list = [("local_ins_num", "float32"),
("local_pos_num", "float32")]
self.metrics = dict()
metric_varname = "P@%d" % kwargs.get("k")
self.metrics[metric_varname] = self.accuracy
......
......@@ -221,6 +221,7 @@ class RunnerBase(object):
program = context["model"][model_name]["main_program"].clone()
_exe_strategy, _build_strategy = self._get_strategy(model_dict,
context)
program = fluid.compiler.CompiledProgram(program).with_data_parallel(
loss_name=model_class.get_avg_cost().name,
build_strategy=_build_strategy,
......@@ -497,6 +498,10 @@ class SingleInferRunner(RunnerBase):
with fluid.program_guard(train_prog, startup_prog):
fluid.io.load_persistables(
context["exe"], model_path, main_program=train_prog)
clear_metrics = context["model"][model_dict["name"]][
"model"].get_clear_metrics()
for var in clear_metrics:
var.clear()
def _dir_check(self, context):
dirname = envs.get_global_env(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册