未验证 提交 88600d7d 编写于 作者: C Chengmo 提交者: GitHub

fix lod tensor return (#101)

Co-authored-by: Ntangwei <tangwei12@baidu.com>
上级 5a521800
......@@ -16,10 +16,9 @@ from __future__ import print_function
import os
import time
import warnings
import datetime
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.utils import envs
__all__ = [
......@@ -27,6 +26,42 @@ __all__ = [
]
def as_numpy(tensor):
"""
Convert a Tensor to a numpy.ndarray, its only support Tensor without LoD information.
For higher dimensional sequence data, please use LoDTensor directly.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy
new_scope = fluid.Scope()
with fluid.scope_guard(new_scope):
fluid.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), fluid.CPUPlace())
tensor = new_scope.find_var("data").get_tensor()
fluid.executor.as_numpy(tensor) # or numpy.array(new_scope.find_var("data").get_tensor())
Args:
tensor(Variable): a instance of Tensor
Returns:
numpy.ndarray
"""
if isinstance(tensor, fluid.core.LoDTensorArray):
return [as_numpy(t) for t in tensor]
if isinstance(tensor, list):
return [as_numpy(t) for t in tensor]
assert isinstance(tensor, fluid.core.LoDTensor)
lod = tensor.lod()
# (todo) need print lod or return it for user
if tensor._is_initialized():
return np.array(tensor)
else:
return None
class RunnerBase(object):
"""R
"""
......@@ -92,9 +127,6 @@ class RunnerBase(object):
model_class = context["model"][model_dict["name"]]["model"]
program = self._get_dataloader_program(model_dict, context)
reader_name = model_dict["dataset_name"]
fetch_vars = []
fetch_alias = []
fetch_period = int(
envs.get_global_env("runner." + context["runner_name"] +
".print_interval", 20))
......@@ -103,9 +135,6 @@ class RunnerBase(object):
else:
metrics = model_class.get_metrics()
if metrics:
fetch_vars = metrics.values()
fetch_alias = metrics.keys()
metrics_varnames = []
metrics_format = []
metrics_format.append("{}: {{}}".format("batch"))
......@@ -121,9 +150,16 @@ class RunnerBase(object):
with fluid.scope_guard(scope):
try:
while True:
metrics_rets = context["exe"].run(
program=program, fetch_list=metrics_varnames)
metrics_tensors = context["exe"].run(
program=program,
fetch_list=metrics_varnames,
return_numpy=False)
metrics = [batch_id]
metrics_rets = [
as_numpy(metrics_tensor)
for metrics_tensor in metrics_tensors
]
metrics.extend(metrics_rets)
if batch_id % fetch_period == 0 and batch_id != 0:
......
......@@ -19,12 +19,7 @@ from __future__ import print_function
import os
from paddlerec.core.utils import envs
from paddlerec.core.trainer import Trainer, EngineMode, FleetMode, Device
from paddlerec.core.trainers.framework.dataset import *
from paddlerec.core.trainers.framework.runner import *
from paddlerec.core.trainers.framework.instance import *
from paddlerec.core.trainers.framework.network import *
from paddlerec.core.trainers.framework.startup import *
from paddlerec.core.trainer import Trainer, EngineMode, FleetMode
class GeneralTrainer(Trainer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册