From a091d1a31cf3268e8d582f85d0640c42ed1c7150 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 15 Jan 2018 12:04:42 +0800 Subject: [PATCH] Enhance print_op. --- paddle/operators/print_op.cc | 133 ++++++++++++++---- python/paddle/v2/fluid/layers/control_flow.py | 34 +++-- python/paddle/v2/fluid/tests/test_print_op.py | 54 +++++-- 3 files changed, 169 insertions(+), 52 deletions(-) diff --git a/paddle/operators/print_op.cc b/paddle/operators/print_op.cc index 89e41d806c..8b233d64c9 100644 --- a/paddle/operators/print_op.cc +++ b/paddle/operators/print_op.cc @@ -16,12 +16,17 @@ #include #include "paddle/framework/op_registry.h" +#include "paddle/framework/variable.h" namespace paddle { namespace operators { #define CLOG std::cout +const std::string kForward = "FORWARD"; +const std::string kBackward = "BACKWARD"; +const std::string kBoth = "BOTH"; + struct Formater { std::string message; std::string name; @@ -122,40 +127,77 @@ class TensorPrintOp : public framework::OperatorBase { TensorPrintOp(const TensorPrintOp& o) : framework::OperatorBase( static_cast(o)) { - PADDLE_THROW("Not implemented"); + PADDLE_THROW("Not implemented."); } void Run(const framework::Scope& scope, const platform::Place& place) const override { - // Only run the `first_n` times. + const framework::Variable* in_var_ptr = nullptr; + std::string phase = kForward; + std::string printed_var_name = ""; + + auto& inputs = Inputs(); + if (inputs.find("In") != inputs.end() && !Inputs("In").empty()) { + in_var_ptr = scope.FindVar(Input("In")); + printed_var_name = Inputs("In").front(); + } else if (inputs.find("In@GRAD") != inputs.end() && + !Inputs("In@GRAD").empty()) { + in_var_ptr = scope.FindVar(Input("In@GRAD")); + printed_var_name = Inputs("In@GRAD").front(); + phase = kBackward; + } else { + PADDLE_THROW("Unknown phase, should be forward or backward."); + } + + PADDLE_ENFORCE_NOT_NULL(in_var_ptr); + + auto& in_tensor = in_var_ptr->Get(); + auto* out_var_ptr = scope.FindVar(Output("Out")); + auto& out_tensor = *out_var_ptr->GetMutable(); + + // Just copy data from input tensor to output tensor + // output tensor share same memory with input tensor + out_tensor.ShareDataWith(in_tensor); + out_tensor.set_lod(in_tensor.lod()); + + std::string print_phase = Attr("print_phase"); + if (print_phase != phase && print_phase != kBoth) { + return; + } + int first_n = Attr("first_n"); if (first_n > 0 && ++times_ > first_n) return; - PADDLE_ENFORCE(!Inputs("input").empty(), "input should be set"); - auto* input_var = scope.FindVar(Input("input")); - PADDLE_ENFORCE_NOT_NULL(input_var); - auto& tensor = input_var->Get(); + framework::LoDTensor printed_tensor; + printed_tensor.set_lod(in_tensor.lod()); + printed_tensor.Resize(in_tensor.dims()); - // TODO(ChunweiYan) support GPU - PADDLE_ENFORCE(platform::is_cpu_place(tensor.place())); + if (platform::is_cpu_place(in_tensor.place())) { + printed_tensor.ShareDataWith(in_tensor); + } else { + // copy data to cpu to print + platform::CPUPlace place; + framework::Copy(in_tensor, place, &printed_tensor); + } Formater formater; if (Attr("print_tensor_name")) { - formater.name = Inputs("input").front(); + formater.name = printed_var_name; } if (Attr("print_tensor_type")) { - formater.dtype = tensor.type(); + formater.dtype = printed_tensor.type(); } if (Attr("print_tensor_shape")) { - formater.dims.assign(tensor.dims()[0], - tensor.dims()[tensor.dims().size() - 1]); + auto& dims = printed_tensor.dims(); + formater.dims.resize(dims.size()); + for (int i = 0; i < dims.size(); ++i) formater.dims[i] = dims[i]; } if (Attr("print_tensor_lod")) { - formater.lod = tensor.lod(); + formater.lod = printed_tensor.lod(); } formater.summarize = Attr("summarize"); - formater.data = (void*)tensor.data(); - formater(tensor.numel()); + formater.data = (void*)printed_tensor.data(); + formater(printed_tensor.numel()); } private: @@ -166,27 +208,46 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker { public: PrintOpProtoAndCheckMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "the tensor that will be displayed."); + AddInput("In", "Input tensor to be displayed."); AddAttr("first_n", "Only log `first_n` number of times."); AddAttr("message", "A string message to print as a prefix."); - AddAttr("summarize", "Print this number of elements in the tensor."); + AddAttr("summarize", "Number of elements printed."); AddAttr("print_tensor_name", "Whether to print the tensor name."); AddAttr("print_tensor_type", "Whether to print the tensor's dtype."); AddAttr("print_tensor_shape", "Whether to print the tensor's shape."); AddAttr("print_tensor_lod", "Whether to print the tensor's lod."); + AddAttr( + "print_phase", + "(string, default 'BOTH') Which phase to display including 'FORWARD' " + "'BACKWARD' and 'BOTH'.") + .SetDefault(kBoth) + .InEnum({kForward, kBackward, kBoth}); + AddOutput("Out", "Output tensor with same data as input tensor."); AddComment(R"DOC( - Creates a print op that will print when a tensor is accessed. +Creates a print op that will print when a tensor is accessed. - Wraps the tensor passed in so that whenever that a tensor is accessed, - the message `message` is printed, along with the current value of the - tensor `t`.)DOC"); +Wraps the tensor passed in so that whenever that a tensor is accessed, +the message `message` is printed, along with the current value of the +tensor `t`.)DOC"); } }; -class InferShape : public framework::InferShapeBase { +class InferShapeForward : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* context) const override { - PADDLE_ENFORCE(context->HasInput("input"), "input should be set"); + PADDLE_ENFORCE(context->HasInput("In"), "Input(In) should not be null."); + context->ShareLoD("In", /*->*/ "Out"); + context->SetOutputDim("Out", context->GetInputDim("In")); + } +}; + +class InferShapeBackward : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* context) const override { + PADDLE_ENFORCE(context->HasInput("In@GRAD"), + "Input(In@GRAD) should not be null."); + context->ShareLoD("In@GRAD", /*->*/ "Out"); + context->SetOutputDim("Out", context->GetInputDim("In@GRAD")); } }; @@ -196,11 +257,27 @@ class InferVarType : public framework::VarTypeInference { framework::BlockDesc* block) const override {} }; +class PrintOpProtoAndCheckGradOpMaker + : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + std::unique_ptr Apply() const override { + auto* op_desc_ptr = new framework::OpDesc(); + op_desc_ptr->SetType("print_grad"); + op_desc_ptr->SetInput("In@GRAD", OutputGrad("Out")); + op_desc_ptr->SetOutput("Out", InputGrad("In")); + op_desc_ptr->SetAttrMap(Attrs()); + return std::unique_ptr(op_desc_ptr); + } +}; + } // namespace operators } // namespace paddle -REGISTER_OPERATOR(print, paddle::operators::TensorPrintOp, - paddle::operators::PrintOpProtoAndCheckMaker, - paddle::operators::InferShape, - paddle::operators::InferVarType, - paddle::framework::EmptyGradOpMaker); +namespace ops = paddle::operators; + +REGISTER_OPERATOR(print, ops::TensorPrintOp, ops::PrintOpProtoAndCheckMaker, + ops::PrintOpProtoAndCheckGradOpMaker, ops::InferShapeForward, + ops::InferVarType); +REGISTER_OPERATOR(print_grad, ops::TensorPrintOp, ops::InferShapeBackward); diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index bef9602bb7..ee97e5f4e6 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -117,7 +117,8 @@ def Print(input, print_tensor_name=True, print_tensor_type=True, print_tensor_shape=True, - print_tensor_lod=True): + print_tensor_lod=True, + print_phase='both'): ''' **Print operator** @@ -128,18 +129,21 @@ def Print(input, tensor `t`. Args: - input(Variable): A Tensor to print. - summarize(int): Print this number of elements in the tensor, will print all - if left negative. - message(str): A string message to print as a prefix. - first_n(int): Only log `first_n` number of times. - print_tensor_name(bool): Print the tensor name. - print_tensor_type(bool): Print the tensor type. - print_tensor_shape(bool): Print the tensor shape. - print_tensor_lod(bool): Print the tensor lod. + input (Variable): A Tensor to print. + summarize (int): Print this number of elements in the tensor, will print + all if left is negative. + message (str): A string message to print as a prefix. + first_n (int): Only log `first_n` number of times. + print_tensor_name (bool): Print the tensor name. + print_tensor_type (bool): Print the tensor type. + print_tensor_shape (bool): Print the tensor shape. + print_tensor_lod (bool): Print the tensor lod. + print_phase (bool): Which phase to displace, including 'forward', + 'backward' and 'both'. If set to 'backward' or 'both', will + print the gradients of input tensor. Returns: - None + Variable: Output tensor, same data with input tensor. Examples: .. code-block:: python @@ -149,10 +153,10 @@ def Print(input, message="The content of some_layer: ") ''' helper = LayerHelper('print', **locals()) - out = helper.create_tmp_variable(dtype='int32') + out = helper.create_tmp_variable(dtype=helper.input_dtype()) helper.append_op( type='print', - inputs={'input': input}, + inputs={'In': input}, attrs={ 'first_n': first_n, 'summarize': summarize, @@ -161,7 +165,9 @@ def Print(input, 'print_tensor_type': print_tensor_type, 'print_tensor_shape': print_tensor_shape, 'print_tensor_lod': print_tensor_lod, - }) + 'print_phase': print_phase.upper() + }, + outputs={'Out': out}) return out diff --git a/python/paddle/v2/fluid/tests/test_print_op.py b/python/paddle/v2/fluid/tests/test_print_op.py index 86a701a020..1550d0af5e 100644 --- a/python/paddle/v2/fluid/tests/test_print_op.py +++ b/python/paddle/v2/fluid/tests/test_print_op.py @@ -1,20 +1,54 @@ import unittest -import numpy as np -from paddle.v2.fluid.executor import Executor import paddle.v2.fluid.core as core -import paddle.v2.fluid.layers as pd +from paddle.v2.fluid.executor import Executor +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.backward import append_backward +from paddle.v2.fluid.framework import switch_main_program +from paddle.v2.fluid.framework import Program +import numpy as np + + +class TestPrintOpCPU(unittest.TestCase): + def setUp(self): + self.place = core.CPUPlace() + self.x_tensor = core.LoDTensor() + tensor_np = np.random.random(size=(2, 3)).astype('float32') + self.x_tensor.set(tensor_np, self.place) + self.x_tensor.set_lod([[0, 1, 1]]) + def build_network(self, only_forward, **kargs): + x = layers.data('x', shape=[3], dtype='float32', lod_level=1) + x.stop_gradient = False + printed = layers.Print(input=x, **kargs) + if only_forward: return printed + loss = layers.mean(x=printed) + append_backward(loss=loss) + return loss -class TestSumOp(unittest.TestCase): - def test_tensor(self): - i = pd.zeros(shape=[2, 10], dtype='float32') + def test_forward(self): + switch_main_program(Program()) + printed = self.build_network(True, print_phase='forward') + exe = Executor(self.place) + outs = exe.run(feed={'x': self.x_tensor}, + fetch_list=[printed], + return_numpy=False) - pd.Print(i, message="I am a message", summarize=10) + def test_backward(self): + switch_main_program(Program()) + loss = self.build_network(False, print_phase='backward') + exe = Executor(self.place) + outs = exe.run(feed={'x': self.x_tensor}, + fetch_list=[loss], + return_numpy=False) - cpu = core.CPUPlace() - exe = Executor(cpu) - exe.run() +class TestPrintOpGPU(TestPrintOpCPU): + def setUp(self): + self.place = core.CUDAPlace(0) + self.x_tensor = core.LoDTensor() + tensor_np = np.random.random(size=(2, 3)).astype('float32') + self.x_tensor.set(tensor_np, self.place) + self.x_tensor.set_lod([[0, 1, 1]]) if __name__ == '__main__': -- GitLab