未验证 提交 88600d7d 编写于 作者: C Chengmo 提交者: GitHub

fix lod tensor return (#101)

Co-authored-by: Ntangwei <tangwei12@baidu.com>
上级 5a521800
...@@ -16,10 +16,9 @@ from __future__ import print_function ...@@ -16,10 +16,9 @@ from __future__ import print_function
import os import os
import time import time
import warnings import numpy as np
import datetime
import paddle.fluid as fluid import paddle.fluid as fluid
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
__all__ = [ __all__ = [
...@@ -27,6 +26,42 @@ __all__ = [ ...@@ -27,6 +26,42 @@ __all__ = [
] ]
def as_numpy(tensor):
"""
Convert a Tensor to a numpy.ndarray, its only support Tensor without LoD information.
For higher dimensional sequence data, please use LoDTensor directly.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy
new_scope = fluid.Scope()
with fluid.scope_guard(new_scope):
fluid.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), fluid.CPUPlace())
tensor = new_scope.find_var("data").get_tensor()
fluid.executor.as_numpy(tensor) # or numpy.array(new_scope.find_var("data").get_tensor())
Args:
tensor(Variable): a instance of Tensor
Returns:
numpy.ndarray
"""
if isinstance(tensor, fluid.core.LoDTensorArray):
return [as_numpy(t) for t in tensor]
if isinstance(tensor, list):
return [as_numpy(t) for t in tensor]
assert isinstance(tensor, fluid.core.LoDTensor)
lod = tensor.lod()
# (todo) need print lod or return it for user
if tensor._is_initialized():
return np.array(tensor)
else:
return None
class RunnerBase(object): class RunnerBase(object):
"""R """R
""" """
...@@ -92,9 +127,6 @@ class RunnerBase(object): ...@@ -92,9 +127,6 @@ class RunnerBase(object):
model_class = context["model"][model_dict["name"]]["model"] model_class = context["model"][model_dict["name"]]["model"]
program = self._get_dataloader_program(model_dict, context) program = self._get_dataloader_program(model_dict, context)
reader_name = model_dict["dataset_name"]
fetch_vars = []
fetch_alias = []
fetch_period = int( fetch_period = int(
envs.get_global_env("runner." + context["runner_name"] + envs.get_global_env("runner." + context["runner_name"] +
".print_interval", 20)) ".print_interval", 20))
...@@ -103,9 +135,6 @@ class RunnerBase(object): ...@@ -103,9 +135,6 @@ class RunnerBase(object):
else: else:
metrics = model_class.get_metrics() metrics = model_class.get_metrics()
if metrics:
fetch_vars = metrics.values()
fetch_alias = metrics.keys()
metrics_varnames = [] metrics_varnames = []
metrics_format = [] metrics_format = []
metrics_format.append("{}: {{}}".format("batch")) metrics_format.append("{}: {{}}".format("batch"))
...@@ -121,9 +150,16 @@ class RunnerBase(object): ...@@ -121,9 +150,16 @@ class RunnerBase(object):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
try: try:
while True: while True:
metrics_rets = context["exe"].run( metrics_tensors = context["exe"].run(
program=program, fetch_list=metrics_varnames) program=program,
fetch_list=metrics_varnames,
return_numpy=False)
metrics = [batch_id] metrics = [batch_id]
metrics_rets = [
as_numpy(metrics_tensor)
for metrics_tensor in metrics_tensors
]
metrics.extend(metrics_rets) metrics.extend(metrics_rets)
if batch_id % fetch_period == 0 and batch_id != 0: if batch_id % fetch_period == 0 and batch_id != 0:
...@@ -248,7 +284,7 @@ class RunnerBase(object): ...@@ -248,7 +284,7 @@ class RunnerBase(object):
fetch_varnames = envs.get_global_env( fetch_varnames = envs.get_global_env(
name + "save_inference_fetch_varnames", []) name + "save_inference_fetch_varnames", [])
if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \ if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \
len(feed_varnames) == 0 or len(fetch_varnames) == 0: len(feed_varnames) == 0 or len(fetch_varnames) == 0:
return return
fetch_vars = [ fetch_vars = [
fluid.default_main_program().global_block().vars[varname] fluid.default_main_program().global_block().vars[varname]
......
...@@ -19,12 +19,7 @@ from __future__ import print_function ...@@ -19,12 +19,7 @@ from __future__ import print_function
import os import os
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
from paddlerec.core.trainer import Trainer, EngineMode, FleetMode, Device from paddlerec.core.trainer import Trainer, EngineMode, FleetMode
from paddlerec.core.trainers.framework.dataset import *
from paddlerec.core.trainers.framework.runner import *
from paddlerec.core.trainers.framework.instance import *
from paddlerec.core.trainers.framework.network import *
from paddlerec.core.trainers.framework.startup import *
class GeneralTrainer(Trainer): class GeneralTrainer(Trainer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册