未验证 提交 9b851ba2 编写于 作者: C Chen Weihang 提交者: GitHub

[dy2static] Add print transformer and unify print format (#24068)

* add print transformer & unify print format, test=develop

* remove using of dygraph_to_static_func, test=develop

* remove python stdout capture, test=develop

* fix compatibility problems for PY2, test=develop

* fix detail error, test=develop

* fix type analysis bug, test=develop

* fix print tuple compatible error in PY2, test=develop

* replace get_func to declarative, test=develop

* fix detail bug, test=develop

* fix some detail problems, test=develop

* change visit_call in print transformer, test=develop
上级 3e962aec
...@@ -50,9 +50,10 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) { ...@@ -50,9 +50,10 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
} }
std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
os << "\tlod: " << t.lod() << "\n"; if (t.lod().size() > 0) {
os << static_cast<Tensor>(t) << "\n"; os << " - lod: " << t.lod() << "\n";
}
os << static_cast<Tensor>(t);
return os; return os;
} }
......
...@@ -639,7 +639,7 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) { ...@@ -639,7 +639,7 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<T>(); auto inspect = tensor.data<T>();
auto element_num = tensor.numel(); auto element_num = tensor.numel();
os << "\tdata: ["; os << " - data: [";
if (element_num > 0) { if (element_num > 0) {
os << inspect[0]; os << inspect[0];
for (int j = 1; j < element_num; ++j) { for (int j = 1; j < element_num; ++j) {
...@@ -651,8 +651,9 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) { ...@@ -651,8 +651,9 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
} }
std::ostream& operator<<(std::ostream& os, const Tensor& t) { std::ostream& operator<<(std::ostream& os, const Tensor& t) {
os << "\tdim: " << t.dims() << "\n"; os << " - place: " << t.place() << "\n";
os << "\tlayout: " << DataLayoutToString(t.layout()) << "\n"; os << " - shape: [" << t.dims() << "]\n";
os << " - layout: " << DataLayoutToString(t.layout()) << "\n";
Tensor tensor; Tensor tensor;
tensor.Resize(t.dims()); tensor.Resize(t.dims());
...@@ -669,7 +670,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& t) { ...@@ -669,7 +670,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& t) {
#define PrintTensorCallback(cpp_type, proto_type) \ #define PrintTensorCallback(cpp_type, proto_type) \
do { \ do { \
if (tensor.type() == proto_type) { \ if (tensor.type() == proto_type) { \
os << "\tdtype: " << proto_type << "\n"; \ os << " - dtype: " << proto_type << "\n"; \
print_tensor<cpp_type>(os, tensor); \ print_tensor<cpp_type>(os, tensor); \
return os; \ return os; \
} \ } \
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/assign_op.h" #include "paddle/fluid/operators/assign_op.h"
...@@ -43,8 +44,9 @@ class LogGuard { ...@@ -43,8 +44,9 @@ class LogGuard {
struct Formater { struct Formater {
std::string message; std::string message;
std::string name; std::string name;
std::vector<int> dims; std::string dims;
std::type_index dtype{typeid(const char)}; std::type_index dtype{typeid(const char)};
std::string layout;
framework::LoD lod; framework::LoD lod;
int summarize; int summarize;
void *data{nullptr}; void *data{nullptr};
...@@ -52,50 +54,62 @@ struct Formater { ...@@ -52,50 +54,62 @@ struct Formater {
std::stringstream logs; std::stringstream logs;
void operator()(size_t size) { void operator()(size_t size) {
PrintMessage();
PrintPlaceInfo();
PrintName(); PrintName();
PrintMessage();
PrintLod();
PrintPlace();
PrintDims(); PrintDims();
PrintLayout();
PrintDtype(); PrintDtype();
PrintLod();
PrintData(size); PrintData(size);
LogGuard guard; LogGuard guard;
CLOG << logs.str(); CLOG << logs.str();
} }
private: private:
void PrintPlaceInfo() { logs << "The place is:" << place << std::endl; } void PrintPlace() { logs << " - place: " << place << std::endl; }
void PrintMessage() { logs << std::time(nullptr) << "\t" << message << "\t"; } void PrintMessage() {
if (!message.empty()) {
logs << " - message: " << message << std::endl;
}
}
void PrintName() { void PrintName() {
if (!name.empty()) { if (!name.empty()) {
logs << "Tensor[" << name << "]" << std::endl; logs << "Variable: " << name << std::endl;
} }
} }
void PrintDims() { void PrintDims() {
if (!dims.empty()) { if (!dims.empty()) {
logs << "\tshape: ["; logs << " - shape: " << dims << std::endl;
for (auto i : dims) {
logs << i << ",";
}
logs << "]" << std::endl;
} }
} }
void PrintDtype() { void PrintDtype() {
if (!framework::IsType<const char>(dtype)) { if (!framework::IsType<const char>(dtype)) {
logs << "\tdtype: " << dtype.name() << std::endl; logs << " - dtype: " << platform::demangle(dtype.name()) << std::endl;
}
}
void PrintLayout() {
if (!layout.empty()) {
logs << " - layout: " << layout << std::endl;
} }
} }
void PrintLod() { void PrintLod() {
if (!lod.empty()) { if (!lod.empty()) {
logs << "\tLoD: ["; logs << " - lod: {";
for (auto level : lod) { for (auto level : lod) {
logs << "[ "; logs << "{";
bool is_first = true;
for (auto i : level) { for (auto i : level) {
logs << i << ","; if (is_first) {
logs << i;
is_first = false;
} else {
logs << ", " << i;
}
} }
logs << " ]"; logs << "}";
} }
logs << "]" << std::endl; logs << "}" << std::endl;
} }
} }
...@@ -113,25 +127,31 @@ struct Formater { ...@@ -113,25 +127,31 @@ struct Formater {
} else if (framework::IsType<const bool>(dtype)) { } else if (framework::IsType<const bool>(dtype)) {
Display<bool>(size); Display<bool>(size);
} else { } else {
logs << "\tdata: unprintable type: " << dtype.name() << std::endl; logs << " - data: unprintable type: " << dtype.name() << std::endl;
} }
} }
template <typename T> template <typename T>
void Display(size_t size) { void Display(size_t size) {
auto *d = reinterpret_cast<T *>(data); auto *d = reinterpret_cast<T *>(data);
logs << "\tdata: "; logs << " - data: [";
if (summarize != -1) { if (summarize != -1) {
summarize = std::min(size, (size_t)summarize); summarize = std::min(size, (size_t)summarize);
for (int i = 0; i < summarize; i++) { if (summarize > 0) {
logs << d[i] << ","; logs << d[0];
for (int i = 1; i < summarize; ++i) {
logs << " " << d[i];
}
} }
} else { } else {
for (size_t i = 0; i < size; i++) { if (size > 0) {
logs << d[i] << ","; logs << d[0];
for (size_t i = 1; i < size; ++i) {
logs << " " << d[i];
}
} }
} }
logs << std::endl; logs << "]" << std::endl;
} }
}; };
...@@ -201,13 +221,14 @@ class PrintOp : public framework::OperatorBase { ...@@ -201,13 +221,14 @@ class PrintOp : public framework::OperatorBase {
formater.dtype = framework::ToTypeIndex(printed_tensor.type()); formater.dtype = framework::ToTypeIndex(printed_tensor.type());
} }
if (Attr<bool>("print_tensor_shape")) { if (Attr<bool>("print_tensor_shape")) {
auto &dims = printed_tensor.dims(); formater.dims = printed_tensor.dims().to_str();
formater.dims.resize(dims.size());
for (int i = 0; i < dims.size(); ++i) formater.dims[i] = dims[i];
} }
if (Attr<bool>("print_tensor_lod")) { if (Attr<bool>("print_tensor_lod")) {
formater.lod = printed_tensor.lod(); formater.lod = printed_tensor.lod();
} }
if (Attr<bool>("print_tensor_layout")) {
formater.layout = framework::DataLayoutToString(printed_tensor.layout());
}
formater.summarize = Attr<int>("summarize"); formater.summarize = Attr<int>("summarize");
formater.data = reinterpret_cast<void *>(printed_tensor.data<void>()); formater.data = reinterpret_cast<void *>(printed_tensor.data<void>());
formater(printed_tensor.numel()); formater(printed_tensor.numel());
...@@ -225,10 +246,17 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker { ...@@ -225,10 +246,17 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("first_n", "Only log `first_n` number of times."); AddAttr<int>("first_n", "Only log `first_n` number of times.");
AddAttr<std::string>("message", "A string message to print as a prefix."); AddAttr<std::string>("message", "A string message to print as a prefix.");
AddAttr<int>("summarize", "Number of elements printed."); AddAttr<int>("summarize", "Number of elements printed.");
AddAttr<bool>("print_tensor_name", "Whether to print the tensor name."); AddAttr<bool>("print_tensor_name", "Whether to print the tensor name.")
AddAttr<bool>("print_tensor_type", "Whether to print the tensor's dtype."); .SetDefault(true);
AddAttr<bool>("print_tensor_shape", "Whether to print the tensor's shape."); AddAttr<bool>("print_tensor_type", "Whether to print the tensor's dtype.")
AddAttr<bool>("print_tensor_lod", "Whether to print the tensor's lod."); .SetDefault(true);
AddAttr<bool>("print_tensor_shape", "Whether to print the tensor's shape.")
.SetDefault(true);
AddAttr<bool>("print_tensor_layout",
"Whether to print the tensor's layout.")
.SetDefault(true);
AddAttr<bool>("print_tensor_lod", "Whether to print the tensor's lod.")
.SetDefault(true);
AddAttr<std::string>("print_phase", AddAttr<std::string>("print_phase",
"(string, default 'FORWARD') Which phase to display " "(string, default 'FORWARD') Which phase to display "
"including 'FORWARD' " "including 'FORWARD' "
......
...@@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransfor ...@@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransfor
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
...@@ -88,6 +89,9 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -88,6 +89,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Transform all if/else statement of Dygraph into Static Graph. # Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform() IfElseTransformer(node_wrapper).transform()
# Transform all python print statement
PrintTransformer(node_wrapper).transform()
# Transform call recursively # Transform call recursively
CallTransformer(node_wrapper).transform() CallTransformer(node_wrapper).transform()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
import astor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
class PrintTransformer(gast.NodeTransformer):
"""
This class transforms python print function to fluid.layers.Print.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
def transform(self):
self.visit(self.root)
# NOTE: deal with print in PY3
def visit_Call(self, node):
assert isinstance(node, gast.Call)
if isinstance(node.func, gast.Name) and node.func.id == 'print':
var = self._get_print_var(node)
return self._construct_print_node(var)
return node
# NOTE: deal with print in PY2
def visit_Print(self, node):
var = self._get_print_var(node)
print_call_node = self._construct_print_node(var)
return gast.Expr(value=print_call_node)
def _get_print_var(self, node):
if isinstance(node, gast.Call):
var_list = node.args
elif isinstance(node, gast.Print):
var_list = node.values
if isinstance(var_list[0], gast.Tuple):
var_list = var_list[0].elts
# TODO: support print multiple Var
assert len(var_list) == 1, "Now only support print one Variable."
return var_list[0]
def _construct_print_node(self, node):
if isinstance(node, gast.Name):
if self._is_tensor_node(node):
print_node = gast.Call(
func=gast.parse('fluid.layers.Print').body[0].value,
args=[node],
keywords=[])
return print_node
else:
raise TypeError(
"print object type error, only support print Variable now.")
else:
# TODO: may not only print with format
raise NotImplementedError(
"cannot transform print with format temporarily.")
def _is_tensor_node(self, node):
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
wrapper_node = self.node_to_wrapper_map.get(node, None)
if wrapper_node is not None:
if wrapper_node.node_var_type & tensor_types:
return True
return False
...@@ -181,6 +181,9 @@ class ProgramCache(object): ...@@ -181,6 +181,9 @@ class ProgramCache(object):
# func just returns one reuslt # func just returns one reuslt
fetch_list = [fetch_list] fetch_list = [fetch_list]
fetch_list = list(fetch_list) fetch_list = list(fetch_list)
# NOTE: avoid fetch_list is [None]
if len(fetch_list) == 1 and fetch_list[0] is None:
fetch_list = None
self._outputs = fetch_list self._outputs = fetch_list
else: else:
fetch_list = func(*args, **kwargs) fetch_list = func(*args, **kwargs)
...@@ -188,6 +191,9 @@ class ProgramCache(object): ...@@ -188,6 +191,9 @@ class ProgramCache(object):
# func just returns one reuslt # func just returns one reuslt
fetch_list = [fetch_list] fetch_list = [fetch_list]
fetch_list = list(fetch_list) fetch_list = list(fetch_list)
# NOTE: avoid fetch_list is [None]
if len(fetch_list) == 1 and fetch_list[0] is None:
fetch_list = None
return fetch_list return fetch_list
......
...@@ -66,7 +66,8 @@ class NodeVarType(object): ...@@ -66,7 +66,8 @@ class NodeVarType(object):
supported_types = [ supported_types = [
NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT, NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT,
NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR,
NodeVarType.PADDLE_RETURN_TYPES
] ]
if in_type1 not in supported_types: if in_type1 not in supported_types:
......
...@@ -198,11 +198,9 @@ def monkey_patch_varbase(): ...@@ -198,11 +198,9 @@ def monkey_patch_varbase():
# TODO(panyx0718): add more dygraph debug info. # TODO(panyx0718): add more dygraph debug info.
tensor = self.value().get_tensor() tensor = self.value().get_tensor()
if tensor._is_initialized(): if tensor._is_initialized():
return 'name %s, dtype: %s shape: %s %s' % ( return 'Variable: %s\n%s' % (self.name, str(tensor))
self.name, self.dtype, self.shape, str(tensor))
else: else:
return 'name %s, shape: %s, not inited' % (self.name, return 'Variable: %s, not initialized' % (self.name)
self.shape)
def __nonzero__(self): def __nonzero__(self):
numel = np.prod(self.shape) numel = np.prod(self.shape)
......
...@@ -5134,10 +5134,9 @@ class ParamBase(core.VarBase): ...@@ -5134,10 +5134,9 @@ class ParamBase(core.VarBase):
bool) bool)
tensor = self.value().get_tensor() tensor = self.value().get_tensor()
if tensor._is_initialized(): if tensor._is_initialized():
return 'name %s, dtype: %s shape: %s %s' % (self.name, self.dtype, return 'Parameter: %s\n%s' % (self.name, str(tensor))
self.shape, str(tensor))
else: else:
return 'name %s, shape: %s, not inited' % (self.name, self.shape) return 'Parameter: %s, not initialized' % (self.name)
__repr__ = __str__ __repr__ = __str__
......
...@@ -218,6 +218,7 @@ def Print(input, ...@@ -218,6 +218,7 @@ def Print(input,
print_tensor_name=True, print_tensor_name=True,
print_tensor_type=True, print_tensor_type=True,
print_tensor_shape=True, print_tensor_shape=True,
print_tensor_layout=True,
print_tensor_lod=True, print_tensor_lod=True,
print_phase='both'): print_phase='both'):
''' '''
...@@ -238,6 +239,7 @@ def Print(input, ...@@ -238,6 +239,7 @@ def Print(input,
print_tensor_name (bool, optional): Print the tensor name. Default: True. print_tensor_name (bool, optional): Print the tensor name. Default: True.
print_tensor_type (bool, optional): Print the tensor type. Defaultt: True. print_tensor_type (bool, optional): Print the tensor type. Defaultt: True.
print_tensor_shape (bool, optional): Print the tensor shape. Default: True. print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
print_tensor_lod (bool, optional): Print the tensor lod. Default: True. print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
print_phase (str): Which phase to displace, including 'forward', print_phase (str): Which phase to displace, including 'forward',
'backward' and 'both'. Default: 'both'. If set to 'backward', will 'backward' and 'both'. Default: 'both'. If set to 'backward', will
...@@ -291,6 +293,7 @@ def Print(input, ...@@ -291,6 +293,7 @@ def Print(input,
'print_tensor_name': print_tensor_name, 'print_tensor_name': print_tensor_name,
'print_tensor_type': print_tensor_type, 'print_tensor_type': print_tensor_type,
'print_tensor_shape': print_tensor_shape, 'print_tensor_shape': print_tensor_shape,
'print_tensor_layout': print_tensor_layout,
'print_tensor_lod': print_tensor_lod, 'print_tensor_lod': print_tensor_lod,
'print_phase': print_phase.upper() 'print_phase': print_phase.upper()
}) })
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative
# 1. print VarBase
@declarative
def dyfunc_print_variable(x):
"""
PY2:
Print(dest=None, values=[Name(id='x_v', annotation=None, type_comment=None)], nl=True)],
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[]))
"""
# NOTE: transform to static code, var name will be changed
x_v = fluid.dygraph.to_variable(x)
print(x_v)
# 2. print ndarray
@declarative
def dyfunc_print_ndarray(x):
"""
PY2:
Print(dest=None, values=[Name(id='x', annotation=None, type_comment=None)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x', annotation=None, type_comment=None)],
keywords=[]))
"""
print(x)
# 3. print VarBase with format
@declarative
def dyfunc_print_with_format(x):
"""
PY2:
Print(dest=None,
values=[
Call(
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[])],
nl=True)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[
Call(
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[])],
keywords=[]))
"""
x_v = fluid.dygraph.to_variable(x)
print("PrintVariable: {}".format(x_v))
# 4. print VarBase with format 2
@declarative
def dyfunc_print_with_format2(x):
"""
PY2:
Print(dest=None,
values=[
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
op=Mod,
right=Name(id='x_v', annotation=None, type_comment=None))],
nl=True)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
op=Mod,
right=Name(id='x_v', annotation=None, type_comment=None))],
keywords=[]))
"""
x_v = fluid.dygraph.to_variable(x)
print("PrintVariable: %s" % (x_v))
# 5. print VarBase in control flow1
@declarative
def dyfunc_print_with_ifelse(x):
x_v = fluid.dygraph.to_variable(x)
if len(x_v.shape) > 1:
print(x_v)
else:
print(x_v)
# 6. print mutiple VarBases
@declarative
def dyfunc_print_multi_vars(x):
"""
# NOTE: y_v type is error before cur PR in this case
Assign(targets=[Name(id='y_v', annotation=None, type_comment=None)],
value=BinOp(left=Name(id='x_v', annotation=None, type_comment=None), op=Mult, right=Constant(value=2, kind=None)))
"""
x_v = fluid.dygraph.to_variable(x)
y_v = x_v * 2
print(x_v)
print(y_v)
# 7. print continue VarBase
@declarative
def dyfunc_print_continue_vars(x):
"""
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x_v', annotation=None, type_comment=None),
Name(id='y_v', annotation=None, type_comment=None)],
keywords=[]))
PY2:
Print(dest=None,
values=[
Tuple(
elts=[Name(id='x_v', annotation=None, type_comment=None),
Name(id='y_v', annotation=None, type_comment=None)])],
nl=True)
"""
x_v = fluid.dygraph.to_variable(x)
y_v = x_v * 2
print(x_v, y_v)
class TestPrintBase(unittest.TestCase):
def setUp(self):
self.input = numpy.ones(5).astype("int32")
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.set_test_func()
def set_test_func(self):
raise NotImplementedError("Print test should implement set_test_func")
def get_dygraph_output(self):
with fluid.dygraph.guard():
self.dygraph_func(self.input)
def get_static_output(self):
with fluid.program_guard(fluid.Program()):
# TODO: How to catch C++ stdout to python
self.dygraph_func(self.input)
class TestPrintVariable(TestPrintBase):
def set_test_func(self):
self.dygraph_func = dyfunc_print_variable
def test_transformed_static_result(self):
self.get_dygraph_output()
self.get_static_output()
class TestPrintNdArray(TestPrintBase):
def set_test_func(self):
self.dygraph_func = dyfunc_print_ndarray
def test_transform_static_error(self):
with self.assertRaises(TypeError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithFormat(TestPrintBase):
def set_test_func(self):
self.dygraph_func = dyfunc_print_with_format
def test_transform_static_error(self):
with self.assertRaises(NotImplementedError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithFormat2(TestPrintBase):
def set_test_func(self):
self.dygraph_func = dyfunc_print_with_format2
def test_transform_static_error(self):
with self.assertRaises(NotImplementedError):
self.get_dygraph_output()
self.get_static_output()
class TestPrintWithIfElse(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_with_ifelse
class TestPrintMultipleVar(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_multi_vars
class TestPrintContinueVar(TestPrintBase):
def set_test_func(self):
self.dygraph_func = dyfunc_print_continue_vars
def test_transform_static_error(self):
with self.assertRaises(AssertionError):
self.get_dygraph_output()
self.get_static_output()
if __name__ == '__main__':
unittest.main()
...@@ -81,6 +81,14 @@ class TestPrintOpCPU(unittest.TestCase): ...@@ -81,6 +81,14 @@ class TestPrintOpCPU(unittest.TestCase):
fetch_list=[loss], fetch_list=[loss],
return_numpy=False) return_numpy=False)
def test_no_summarize(self):
switch_main_program(Program())
printed = self.build_network(True, summarize=-1, print_phase='forward')
exe = Executor(self.place)
outs = exe.run(feed={'x': self.x_tensor},
fetch_list=[printed],
return_numpy=False)
class TestPrintOpError(unittest.TestCase): class TestPrintOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册