/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/tensor_formatter.h"

namespace paddle {
namespace framework {
class InferShapeContext;
class LoDTensor;
class OpDesc;
class Scope;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

namespace paddle {
namespace operators {
using framework::GradVarName;

#define CLOG std::cout

const char kForward[] = "FORWARD";
const char kBackward[] = "BACKWARD";
const char kBoth[] = "BOTH";

// TODO(ChunweiYan) there should be some other printers for TensorArray
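// PrintOp copies its input "In" to its output "Out" unchanged (data and LoD)
// and prints the tensor contents through TensorFormatter. Printing can be
// limited to the first `first_n` runs and restricted to the forward or
// backward pass via the `print_phase` attribute.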
class PrintOp : public framework::OperatorBase {
 public:
  PrintOp(const std::string &type, const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    const auto in_var = scope.FindVar(Input("In"));
    auto out_var = scope.FindVar(Output("Out"));

    PADDLE_ENFORCE_NOT_NULL(
        in_var, platform::errors::NotFound("The input:%s not found in scope",
                                           Input("In")));
    PADDLE_ENFORCE_NOT_NULL(
        out_var, platform::errors::NotFound("The output:%s not found in scope",
                                            Output("Out")));

    auto &in_tensor = in_var->Get<framework::LoDTensor>();
    framework::LoDTensor *out_tensor =
        out_var->GetMutable<framework::LoDTensor>();

    PrintValue(place, Inputs("In").front(), in_tensor);
    framework::TensorCopy(in_tensor, place, out_tensor);
    out_tensor->set_lod(in_tensor.lod());
  }

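  // Formats and prints the tensor according to the op's attributes. Printing
  // is skipped when the current pass does not match `print_phase`, or once
  // more than `first_n` prints have been emitted (when first_n > 0).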
  void PrintValue(const platform::Place &place,
                  const std::string &printed_var_name,
                  const framework::LoDTensor &in_tensor) const {
    std::string print_phase = Attr<std::string>("print_phase");
    bool is_forward = Attr<bool>("is_forward");

    if ((is_forward && print_phase == kBackward) ||
        (!is_forward && print_phase == kForward)) {
      return;
    }

    int first_n = Attr<int>("first_n");
    if (first_n > 0 && ++times_ > first_n) return;

    TensorFormatter formatter;
    const std::string &name =
        Attr<bool>("print_tensor_name") ? printed_var_name : "";
    formatter.SetPrintTensorType(Attr<bool>("print_tensor_type"));
    formatter.SetPrintTensorShape(Attr<bool>("print_tensor_shape"));
    formatter.SetPrintTensorLod(Attr<bool>("print_tensor_lod"));
    formatter.SetPrintTensorLayout(Attr<bool>("print_tensor_layout"));
    formatter.SetSummarize(static_cast<int64_t>(Attr<int>("summarize")));
    formatter.Print(in_tensor, name, Attr<std::string>("message"));
  }

 private:
  mutable int times_{0};
};

class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("In", "Input tensor to be displayed.");
    AddOutput("Out", "The output tensor.");
    AddAttr<int>("first_n", "Only log `first_n` number of times.");
    AddAttr<std::string>("message", "A string message to print as a prefix.");
    AddAttr<int>("summarize", "Number of elements printed.");
    AddAttr<bool>("print_tensor_name", "Whether to print the tensor name.")
        .SetDefault(true);
    AddAttr<bool>("print_tensor_type", "Whether to print the tensor's dtype.")
        .SetDefault(true);
    AddAttr<bool>("print_tensor_shape", "Whether to print the tensor's shape.")
        .SetDefault(true);
    AddAttr<bool>("print_tensor_layout",
                  "Whether to print the tensor's layout.")
        .SetDefault(true);
    AddAttr<bool>("print_tensor_lod", "Whether to print the tensor's lod.")
        .SetDefault(true);
    AddAttr<std::string>("print_phase",
                         "(string, default 'FORWARD') Which phase to display "
                         "including 'FORWARD' "
                         "'BACKWARD' and 'BOTH'.")
        .SetDefault(std::string(kBoth))
        .InEnum({std::string(kForward), std::string(kBackward),
                 std::string(kBoth)});
    AddAttr<bool>("is_forward", "Whether is forward or not").SetDefault(true);
    AddComment(R"DOC(
Creates a print op that will print when a tensor is accessed.

Wraps the input tensor so that whenever it is accessed, the string
`message` is printed as a prefix, along with the current value of the
tensor.)DOC");
  }
};

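// The output shares both its shape and its LoD with the input.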
class PrintOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    VLOG(10) << "PrintOpInferShape";
    OP_INOUT_CHECK(ctx->HasInput("In"), "Input", "In", "Print");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Print");
    ctx->ShareDim("In", /*->*/ "Out");
    ctx->ShareLoD("In", /*->*/ "Out");
  }
};

class PrintOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SetOutputType("Out", ctx->GetInputType("In"));
  }
};

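// The "gradient" of print is another print op: it takes the gradient of Out
// as its input and forwards it to the gradient of In, so the value flowing
// backward is printed as well. `is_forward` is set to false so that
// `print_phase` can tell the two passes apart.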
template <typename T>
class PrintOpGradientMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op_desc_ptr) const override {
    op_desc_ptr->SetType("print");
    op_desc_ptr->SetInput("In", this->OutputGrad("Out"));
    op_desc_ptr->SetOutput("Out", this->InputGrad("In"));
    op_desc_ptr->SetAttrMap(this->Attrs());
    op_desc_ptr->SetAttr("is_forward", false);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

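// Register the operator together with its shape/var-type inference and a
// gradient maker for both the static-graph (OpDesc) and imperative (OpBase)
// paths.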
REGISTER_OPERATOR(print, ops::PrintOp, ops::PrintOpProtoAndCheckMaker,
                  ops::PrintOpGradientMaker<paddle::framework::OpDesc>,
                  ops::PrintOpGradientMaker<paddle::imperative::OpBase>,
                  ops::PrintOpInferShape, ops::PrintOpVarTypeInference);