From 9b851ba21687c6fb061cf06a3392f941cbd0aebf Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 27 Apr 2020 18:41:44 +0800 Subject: [PATCH] [dy2static] Add print transformer and unify print format (#24068) * add print transformer & unify print format, test=develop * remove using of dygraph_to_static_func, test=develop * remove python stdout capture, test=develop * fix compatibility problems for PY2, test=develop * fix detail error, test=develop * fix type analysis bug, test=develop * fix print tuple compatible error in PY2, test=develop * replace get_func to declarative, test=develop * fix detail bug, test=develop * fix some detail problems, test=develop * change visit_call in print transformer, test=develop --- paddle/fluid/framework/lod_tensor.cc | 7 +- paddle/fluid/framework/tensor_util.cc | 9 +- paddle/fluid/operators/print_op.cc | 92 ++++--- .../dygraph_to_static/ast_transformer.py | 4 + .../dygraph_to_static/print_transformer.py | 89 +++++++ .../dygraph_to_static/program_translator.py | 6 + .../dygraph_to_static/static_analysis.py | 3 +- .../fluid/dygraph/varbase_patch_methods.py | 6 +- python/paddle/fluid/framework.py | 5 +- python/paddle/fluid/layers/control_flow.py | 3 + .../unittests/dygraph_to_static/test_print.py | 233 ++++++++++++++++++ .../fluid/tests/unittests/test_print_op.py | 8 + 12 files changed, 418 insertions(+), 47 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 8fd989a2c8..2d1cba3b0f 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -50,9 +50,10 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) { } std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { - os << "\tlod: " << t.lod() << "\n"; - os << static_cast(t) << "\n"; - + if (t.lod().size() > 0) { + os << " - lod: " << t.lod() << "\n"; + } + os << static_cast(t); return os; } diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index b082865331..75d3597bab 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -639,7 +639,7 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) { auto inspect = tensor.data(); auto element_num = tensor.numel(); - os << "\tdata: ["; + os << " - data: ["; if (element_num > 0) { os << inspect[0]; for (int j = 1; j < element_num; ++j) { @@ -651,8 +651,9 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) { } std::ostream& operator<<(std::ostream& os, const Tensor& t) { - os << "\tdim: " << t.dims() << "\n"; - os << "\tlayout: " << DataLayoutToString(t.layout()) << "\n"; + os << " - place: " << t.place() << "\n"; + os << " - shape: [" << t.dims() << "]\n"; + os << " - layout: " << DataLayoutToString(t.layout()) << "\n"; Tensor tensor; tensor.Resize(t.dims()); @@ -669,7 +670,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& t) { #define PrintTensorCallback(cpp_type, proto_type) \ do { \ if (tensor.type() == proto_type) { \ - os << "\tdtype: " << proto_type << "\n"; \ + os << " - dtype: " << proto_type << "\n"; \ print_tensor(os, tensor); \ return os; \ } \ diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index dff2074fbe..238d8218a2 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include +#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/assign_op.h" @@ -43,8 +44,9 @@ class LogGuard { struct Formater { std::string message; std::string name; - std::vector dims; + std::string dims; std::type_index dtype{typeid(const char)}; + std::string layout; framework::LoD lod; int summarize; void *data{nullptr}; @@ -52,50 +54,62 @@ struct Formater { std::stringstream logs; void operator()(size_t size) { - PrintMessage(); - PrintPlaceInfo(); PrintName(); + PrintMessage(); + PrintLod(); + PrintPlace(); PrintDims(); + PrintLayout(); PrintDtype(); - PrintLod(); PrintData(size); LogGuard guard; CLOG << logs.str(); } private: - void PrintPlaceInfo() { logs << "The place is:" << place << std::endl; } - void PrintMessage() { logs << std::time(nullptr) << "\t" << message << "\t"; } + void PrintPlace() { logs << " - place: " << place << std::endl; } + void PrintMessage() { + if (!message.empty()) { + logs << " - message: " << message << std::endl; + } + } void PrintName() { if (!name.empty()) { - logs << "Tensor[" << name << "]" << std::endl; + logs << "Variable: " << name << std::endl; } } void PrintDims() { if (!dims.empty()) { - logs << "\tshape: ["; - for (auto i : dims) { - logs << i << ","; - } - logs << "]" << std::endl; + logs << " - shape: " << dims << std::endl; } } void PrintDtype() { if (!framework::IsType(dtype)) { - logs << "\tdtype: " << dtype.name() << std::endl; + logs << " - dtype: " << platform::demangle(dtype.name()) << std::endl; + } + } + void PrintLayout() { + if (!layout.empty()) { + logs << " - layout: " << layout << std::endl; } } void PrintLod() { if (!lod.empty()) { - logs << "\tLoD: ["; + logs << " - lod: {"; for (auto level : lod) { - logs << "[ "; + logs << "{"; + bool is_first = true; for (auto i : level) { - logs << i << ","; + if (is_first) { + logs << i; + is_first = false; + } else { + logs << ", " << i; + } } - logs << " ]"; + logs << "}"; } - logs << "]" << std::endl; + logs << "}" << std::endl; } } @@ -113,25 +127,31 @@ struct Formater { } else if (framework::IsType(dtype)) { Display(size); } else { - logs << "\tdata: unprintable type: " << dtype.name() << std::endl; + logs << " - data: unprintable type: " << dtype.name() << std::endl; } } template void Display(size_t size) { auto *d = reinterpret_cast(data); - logs << "\tdata: "; + logs << " - data: ["; if (summarize != -1) { summarize = std::min(size, (size_t)summarize); - for (int i = 0; i < summarize; i++) { - logs << d[i] << ","; + if (summarize > 0) { + logs << d[0]; + for (int i = 1; i < summarize; ++i) { + logs << " " << d[i]; + } } } else { - for (size_t i = 0; i < size; i++) { - logs << d[i] << ","; + if (size > 0) { + logs << d[0]; + for (size_t i = 1; i < size; ++i) { + logs << " " << d[i]; + } } } - logs << std::endl; + logs << "]" << std::endl; } }; @@ -201,13 +221,14 @@ class PrintOp : public framework::OperatorBase { formater.dtype = framework::ToTypeIndex(printed_tensor.type()); } if (Attr("print_tensor_shape")) { - auto &dims = printed_tensor.dims(); - formater.dims.resize(dims.size()); - for (int i = 0; i < dims.size(); ++i) formater.dims[i] = dims[i]; + formater.dims = printed_tensor.dims().to_str(); } if (Attr("print_tensor_lod")) { formater.lod = printed_tensor.lod(); } + if (Attr("print_tensor_layout")) { + formater.layout = framework::DataLayoutToString(printed_tensor.layout()); + } formater.summarize = Attr("summarize"); formater.data = reinterpret_cast(printed_tensor.data()); formater(printed_tensor.numel()); @@ -225,10 +246,17 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker { AddAttr("first_n", "Only log `first_n` number of times."); AddAttr("message", "A string message to print as a prefix."); AddAttr("summarize", "Number of elements printed."); - AddAttr("print_tensor_name", "Whether to print the tensor name."); - AddAttr("print_tensor_type", "Whether to print the tensor's dtype."); - AddAttr("print_tensor_shape", "Whether to print the tensor's shape."); - AddAttr("print_tensor_lod", "Whether to print the tensor's lod."); + AddAttr("print_tensor_name", "Whether to print the tensor name.") + .SetDefault(true); + AddAttr("print_tensor_type", "Whether to print the tensor's dtype.") + .SetDefault(true); + AddAttr("print_tensor_shape", "Whether to print the tensor's shape.") + .SetDefault(true); + AddAttr("print_tensor_layout", + "Whether to print the tensor's layout.") + .SetDefault(true); + AddAttr("print_tensor_lod", "Whether to print the tensor's lod.") + .SetDefault(true); AddAttr("print_phase", "(string, default 'FORWARD') Which phase to display " "including 'FORWARD' " diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 9c51972369..67f54b131c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransfor from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer +from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType @@ -88,6 +89,9 @@ class DygraphToStaticAst(gast.NodeTransformer): # Transform all if/else statement of Dygraph into Static Graph. IfElseTransformer(node_wrapper).transform() + # Transform all python print statement + PrintTransformer(node_wrapper).transform() + # Transform call recursively CallTransformer(node_wrapper).transform() diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py new file mode 100644 index 0000000000..1258cbd1a4 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py @@ -0,0 +1,89 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast +import astor + +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor + + +class PrintTransformer(gast.NodeTransformer): + """ + This class transforms python print function to fluid.layers.Print. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of PrintTransformer." + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + self.static_analysis_visitor = StaticAnalysisVisitor(self.root) + self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( + ) + + def transform(self): + self.visit(self.root) + + # NOTE: deal with print in PY3 + def visit_Call(self, node): + assert isinstance(node, gast.Call) + if isinstance(node.func, gast.Name) and node.func.id == 'print': + var = self._get_print_var(node) + return self._construct_print_node(var) + return node + + # NOTE: deal with print in PY2 + def visit_Print(self, node): + var = self._get_print_var(node) + print_call_node = self._construct_print_node(var) + return gast.Expr(value=print_call_node) + + def _get_print_var(self, node): + if isinstance(node, gast.Call): + var_list = node.args + elif isinstance(node, gast.Print): + var_list = node.values + if isinstance(var_list[0], gast.Tuple): + var_list = var_list[0].elts + # TODO: support print multiple Var + assert len(var_list) == 1, "Now only support print one Variable." + return var_list[0] + + def _construct_print_node(self, node): + if isinstance(node, gast.Name): + if self._is_tensor_node(node): + print_node = gast.Call( + func=gast.parse('fluid.layers.Print').body[0].value, + args=[node], + keywords=[]) + return print_node + else: + raise TypeError( + "print object type error, only support print Variable now.") + else: + # TODO: may not only print with format + raise NotImplementedError( + "cannot transform print with format temporarily.") + + def _is_tensor_node(self, node): + tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES} + wrapper_node = self.node_to_wrapper_map.get(node, None) + if wrapper_node is not None: + if wrapper_node.node_var_type & tensor_types: + return True + return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 4d0abd5269..a4d1131ab8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -181,6 +181,9 @@ class ProgramCache(object): # func just returns one reuslt fetch_list = [fetch_list] fetch_list = list(fetch_list) + # NOTE: avoid fetch_list is [None] + if len(fetch_list) == 1 and fetch_list[0] is None: + fetch_list = None self._outputs = fetch_list else: fetch_list = func(*args, **kwargs) @@ -188,6 +191,9 @@ class ProgramCache(object): # func just returns one reuslt fetch_list = [fetch_list] fetch_list = list(fetch_list) + # NOTE: avoid fetch_list is [None] + if len(fetch_list) == 1 and fetch_list[0] is None: + fetch_list = None return fetch_list diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index c4bbd278ac..374b0d4e06 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -66,7 +66,8 @@ class NodeVarType(object): supported_types = [ NodeVarType.BOOLEAN, NodeVarType.INT, NodeVarType.FLOAT, - NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR + NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR, + NodeVarType.PADDLE_RETURN_TYPES ] if in_type1 not in supported_types: diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 5da400391e..6b528479ff 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -198,11 +198,9 @@ def monkey_patch_varbase(): # TODO(panyx0718): add more dygraph debug info. tensor = self.value().get_tensor() if tensor._is_initialized(): - return 'name %s, dtype: %s shape: %s %s' % ( - self.name, self.dtype, self.shape, str(tensor)) + return 'Variable: %s\n%s' % (self.name, str(tensor)) else: - return 'name %s, shape: %s, not inited' % (self.name, - self.shape) + return 'Variable: %s, not initialized' % (self.name) def __nonzero__(self): numel = np.prod(self.shape) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5ff0bd9e60..487c4f83aa 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -5134,10 +5134,9 @@ class ParamBase(core.VarBase): bool) tensor = self.value().get_tensor() if tensor._is_initialized(): - return 'name %s, dtype: %s shape: %s %s' % (self.name, self.dtype, - self.shape, str(tensor)) + return 'Parameter: %s\n%s' % (self.name, str(tensor)) else: - return 'name %s, shape: %s, not inited' % (self.name, self.shape) + return 'Parameter: %s, not initialized' % (self.name) __repr__ = __str__ diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index c8b2cef735..dc31ec3b60 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -218,6 +218,7 @@ def Print(input, print_tensor_name=True, print_tensor_type=True, print_tensor_shape=True, + print_tensor_layout=True, print_tensor_lod=True, print_phase='both'): ''' @@ -238,6 +239,7 @@ def Print(input, print_tensor_name (bool, optional): Print the tensor name. Default: True. print_tensor_type (bool, optional): Print the tensor type. Defaultt: True. print_tensor_shape (bool, optional): Print the tensor shape. Default: True. + print_tensor_layout (bool, optional): Print the tensor layout. Default: True. print_tensor_lod (bool, optional): Print the tensor lod. Default: True. print_phase (str): Which phase to displace, including 'forward', 'backward' and 'both'. Default: 'both'. If set to 'backward', will @@ -291,6 +293,7 @@ def Print(input, 'print_tensor_name': print_tensor_name, 'print_tensor_type': print_tensor_type, 'print_tensor_shape': print_tensor_shape, + 'print_tensor_layout': print_tensor_layout, 'print_tensor_lod': print_tensor_lod, 'print_phase': print_phase.upper() }) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py new file mode 100644 index 0000000000..2dd1ec7e9e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py @@ -0,0 +1,233 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy +import unittest + +import paddle.fluid as fluid +from paddle.fluid.dygraph.jit import declarative + + +# 1. print VarBase +@declarative +def dyfunc_print_variable(x): + """ + PY2: + Print(dest=None, values=[Name(id='x_v', annotation=None, type_comment=None)], nl=True)], + PY3: + Expr( + value=Call(func=Name(id='print', annotation=None, type_comment=None), + args=[Name(id='x_v', annotation=None, type_comment=None)], + keywords=[])) + """ + # NOTE: transform to static code, var name will be changed + x_v = fluid.dygraph.to_variable(x) + print(x_v) + + +# 2. print ndarray +@declarative +def dyfunc_print_ndarray(x): + """ + PY2: + Print(dest=None, values=[Name(id='x', annotation=None, type_comment=None) + PY3: + Expr( + value=Call(func=Name(id='print', annotation=None, type_comment=None), + args=[Name(id='x', annotation=None, type_comment=None)], + keywords=[])) + """ + print(x) + + +# 3. print VarBase with format +@declarative +def dyfunc_print_with_format(x): + """ + PY2: + Print(dest=None, + values=[ + Call( + func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'), + args=[Name(id='x_v', annotation=None, type_comment=None)], + keywords=[])], + nl=True) + PY3: + Expr( + value=Call(func=Name(id='print', annotation=None, type_comment=None), + args=[ + Call( + func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'), + args=[Name(id='x_v', annotation=None, type_comment=None)], + keywords=[])], + keywords=[])) + """ + x_v = fluid.dygraph.to_variable(x) + print("PrintVariable: {}".format(x_v)) + + +# 4. print VarBase with format 2 +@declarative +def dyfunc_print_with_format2(x): + """ + PY2: + Print(dest=None, + values=[ + BinOp(left=Constant(value='PrintVariable: %s', kind=None), + op=Mod, + right=Name(id='x_v', annotation=None, type_comment=None))], + nl=True) + PY3: + Expr( + value=Call(func=Name(id='print', annotation=None, type_comment=None), + args=[ + BinOp(left=Constant(value='PrintVariable: %s', kind=None), + op=Mod, + right=Name(id='x_v', annotation=None, type_comment=None))], + keywords=[])) + """ + x_v = fluid.dygraph.to_variable(x) + print("PrintVariable: %s" % (x_v)) + + +# 5. print VarBase in control flow1 +@declarative +def dyfunc_print_with_ifelse(x): + x_v = fluid.dygraph.to_variable(x) + if len(x_v.shape) > 1: + print(x_v) + else: + print(x_v) + + +# 6. print mutiple VarBases +@declarative +def dyfunc_print_multi_vars(x): + """ + # NOTE: y_v type is error before cur PR in this case + Assign(targets=[Name(id='y_v', annotation=None, type_comment=None)], + value=BinOp(left=Name(id='x_v', annotation=None, type_comment=None), op=Mult, right=Constant(value=2, kind=None))) + """ + x_v = fluid.dygraph.to_variable(x) + y_v = x_v * 2 + print(x_v) + print(y_v) + + +# 7. print continue VarBase +@declarative +def dyfunc_print_continue_vars(x): + """ + PY3: + Expr( + value=Call(func=Name(id='print', annotation=None, type_comment=None), + args=[Name(id='x_v', annotation=None, type_comment=None), + Name(id='y_v', annotation=None, type_comment=None)], + keywords=[])) + PY2: + Print(dest=None, + values=[ + Tuple( + elts=[Name(id='x_v', annotation=None, type_comment=None), + Name(id='y_v', annotation=None, type_comment=None)])], + nl=True) + """ + x_v = fluid.dygraph.to_variable(x) + y_v = x_v * 2 + print(x_v, y_v) + + +class TestPrintBase(unittest.TestCase): + def setUp(self): + self.input = numpy.ones(5).astype("int32") + self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + self.set_test_func() + + def set_test_func(self): + raise NotImplementedError("Print test should implement set_test_func") + + def get_dygraph_output(self): + with fluid.dygraph.guard(): + self.dygraph_func(self.input) + + def get_static_output(self): + with fluid.program_guard(fluid.Program()): + # TODO: How to catch C++ stdout to python + self.dygraph_func(self.input) + + +class TestPrintVariable(TestPrintBase): + def set_test_func(self): + self.dygraph_func = dyfunc_print_variable + + def test_transformed_static_result(self): + self.get_dygraph_output() + self.get_static_output() + + +class TestPrintNdArray(TestPrintBase): + def set_test_func(self): + self.dygraph_func = dyfunc_print_ndarray + + def test_transform_static_error(self): + with self.assertRaises(TypeError): + self.get_dygraph_output() + self.get_static_output() + + +class TestPrintWithFormat(TestPrintBase): + def set_test_func(self): + self.dygraph_func = dyfunc_print_with_format + + def test_transform_static_error(self): + with self.assertRaises(NotImplementedError): + self.get_dygraph_output() + self.get_static_output() + + +class TestPrintWithFormat2(TestPrintBase): + def set_test_func(self): + self.dygraph_func = dyfunc_print_with_format2 + + def test_transform_static_error(self): + with self.assertRaises(NotImplementedError): + self.get_dygraph_output() + self.get_static_output() + + +class TestPrintWithIfElse(TestPrintVariable): + def set_test_func(self): + self.dygraph_func = dyfunc_print_with_ifelse + + +class TestPrintMultipleVar(TestPrintVariable): + def set_test_func(self): + self.dygraph_func = dyfunc_print_multi_vars + + +class TestPrintContinueVar(TestPrintBase): + def set_test_func(self): + self.dygraph_func = dyfunc_print_continue_vars + + def test_transform_static_error(self): + with self.assertRaises(AssertionError): + self.get_dygraph_output() + self.get_static_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_print_op.py b/python/paddle/fluid/tests/unittests/test_print_op.py index 21bf3f0b28..5029822e85 100644 --- a/python/paddle/fluid/tests/unittests/test_print_op.py +++ b/python/paddle/fluid/tests/unittests/test_print_op.py @@ -81,6 +81,14 @@ class TestPrintOpCPU(unittest.TestCase): fetch_list=[loss], return_numpy=False) + def test_no_summarize(self): + switch_main_program(Program()) + printed = self.build_network(True, summarize=-1, print_phase='forward') + exe = Executor(self.place) + outs = exe.run(feed={'x': self.x_tensor}, + fetch_list=[printed], + return_numpy=False) + class TestPrintOpError(unittest.TestCase): def test_errors(self): -- GitLab