/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/tensor_formatter.h"

// Forward declarations of framework types that are referenced later in this
// file (by the InferShape functor, RunImpl, and the gradient makers passed to
// REGISTER_OPERATOR).
namespace paddle {
namespace framework {
class InferShapeContext;
class Tensor;
class OpDesc;
class Scope;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

namespace paddle {
namespace operators {
Y
Yu Yang 已提交
33
using framework::GradVarName;
Y
Yan Chunwei 已提交
34 35 36

#define CLOG std::cout

37 38 39
const char kForward[] = "FORWARD";
const char kBackward[] = "BACKWARD";
const char kBoth[] = "BOTH";
Y
yangyaming 已提交
40

Y
Yan Chunwei 已提交
41
// TODO(ChunweiYan) there should be some other printers for TensorArray
42
class PrintOp : public framework::OperatorBase {
Y
Yan Chunwei 已提交
43
 public:
44 45 46
  PrintOp(const std::string &type, const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
Y
Yan Chunwei 已提交
47 48
      : OperatorBase(type, inputs, outputs, attrs) {}

49
 private:
Y
Yu Yang 已提交
50 51
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
52 53
    const auto in_var = scope.FindVar(Input("In"));
    auto out_var = scope.FindVar(Output("Out"));
54 55 56 57 58 59 60 61

    PADDLE_ENFORCE_NOT_NULL(
        in_var, platform::errors::NotFound("The input:%s not found in scope",
                                           Input("In")));
    PADDLE_ENFORCE_NOT_NULL(
        out_var, platform::errors::NotFound("The output:%s not found in scope",
                                            Output("Out")));

62 63 64 65 66 67 68 69
    auto &in_tensor = in_var->Get<framework::LoDTensor>();
    framework::LoDTensor *out_tensor =
        out_var->GetMutable<framework::LoDTensor>();

    PrintValue(place, Inputs("In").front(), in_tensor);
    framework::TensorCopy(in_tensor, place, out_tensor);
    out_tensor->set_lod(in_tensor.lod());
  }
Y
yangyaming 已提交
70

71 72 73
  void PrintValue(const platform::Place &place,
                  const std::string &printed_var_name,
                  const framework::LoDTensor &in_tensor) const {
Y
yangyaming 已提交
74
    std::string print_phase = Attr<std::string>("print_phase");
Y
Yu Yang 已提交
75 76 77 78
    bool is_forward = Attr<bool>("is_forward");

    if ((is_forward && print_phase == kBackward) ||
        (!is_forward && print_phase == kForward)) {
Y
yangyaming 已提交
79 80 81
      return;
    }

Y
Yan Chunwei 已提交
82 83 84
    int first_n = Attr<int>("first_n");
    if (first_n > 0 && ++times_ > first_n) return;

H
Huihuang Zheng 已提交
85 86 87 88 89 90 91 92
    TensorFormatter formatter;
    const std::string &name =
        Attr<bool>("print_tensor_name") ? printed_var_name : "";
    formatter.SetPrintTensorType(Attr<bool>("print_tensor_type"));
    formatter.SetPrintTensorShape(Attr<bool>("print_tensor_shape"));
    formatter.SetPrintTensorLod(Attr<bool>("print_tensor_lod"));
    formatter.SetPrintTensorLayout(Attr<bool>("print_tensor_layout"));
    formatter.SetSummarize(static_cast<int64_t>(Attr<int>("summarize")));
93
    formatter.Print(in_tensor, name, Attr<std::string>("message"));
Y
Yan Chunwei 已提交
94 95 96 97 98 99 100 101
  }

 private:
  mutable int times_{0};
};

class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
102
  void Make() override {
Y
yangyaming 已提交
103
    AddInput("In", "Input tensor to be displayed.");
104
    AddOutput("Out", "The output tensor.");
Y
Yan Chunwei 已提交
105 106
    AddAttr<int>("first_n", "Only log `first_n` number of times.");
    AddAttr<std::string>("message", "A string message to print as a prefix.");
Y
yangyaming 已提交
107
    AddAttr<int>("summarize", "Number of elements printed.");
108 109 110 111 112 113 114 115 116 117 118
    AddAttr<bool>("print_tensor_name", "Whether to print the tensor name.")
        .SetDefault(true);
    AddAttr<bool>("print_tensor_type", "Whether to print the tensor's dtype.")
        .SetDefault(true);
    AddAttr<bool>("print_tensor_shape", "Whether to print the tensor's shape.")
        .SetDefault(true);
    AddAttr<bool>("print_tensor_layout",
                  "Whether to print the tensor's layout.")
        .SetDefault(true);
    AddAttr<bool>("print_tensor_lod", "Whether to print the tensor's lod.")
        .SetDefault(true);
Y
Yu Yang 已提交
119 120 121 122
    AddAttr<std::string>("print_phase",
                         "(string, default 'FORWARD') Which phase to display "
                         "including 'FORWARD' "
                         "'BACKWARD' and 'BOTH'.")
123 124 125
        .SetDefault(std::string(kBoth))
        .InEnum({std::string(kForward), std::string(kBackward),
                 std::string(kBoth)});
Y
Yu Yang 已提交
126
    AddAttr<bool>("is_forward", "Whether is forward or not").SetDefault(true);
Y
Yan Chunwei 已提交
127
    AddComment(R"DOC(
Y
yangyaming 已提交
128
Creates a print op that will print when a tensor is accessed.
Y
Yan Chunwei 已提交
129

Y
yangyaming 已提交
130 131 132
Wraps the tensor passed in so that whenever that a tensor is accessed,
the message `message` is printed, along with the current value of the
tensor `t`.)DOC");
Y
Yan Chunwei 已提交
133 134 135
  }
};

136 137 138 139
class PrintOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    VLOG(10) << "PrintOpInferShape";
140 141
    OP_INOUT_CHECK(ctx->HasInput("In"), "Input", "In", "Print");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Print");
142 143 144 145 146 147
    ctx->ShareDim("In", /*->*/ "Out");
    ctx->ShareLoD("In", /*->*/ "Out");
  }
};

class PrintOpVarTypeInference : public framework::VarTypeInference {
Y
Yan Chunwei 已提交
148
 public:
149
  void operator()(framework::InferVarTypeContext *ctx) const override {
150
    ctx->SetOutputType("Out", ctx->GetInputType("In"));
Y
Yan Chunwei 已提交
151 152 153
  }
};

H
hong 已提交
154 155
template <typename T>
class PrintOpGradientMaker : public framework::SingleGradOpMaker<T> {
Y
yangyaming 已提交
156
 public:
H
hong 已提交
157
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
Y
yangyaming 已提交
158

159
  void Apply(GradOpPtr<T> op_desc_ptr) const override {
Y
Yu Yang 已提交
160
    op_desc_ptr->SetType("print");
H
hong 已提交
161 162 163
    op_desc_ptr->SetInput("In", this->OutputGrad("Out"));
    op_desc_ptr->SetOutput("Out", this->InputGrad("In"));
    op_desc_ptr->SetAttrMap(this->Attrs());
Y
Yu Yang 已提交
164
    op_desc_ptr->SetAttr("is_forward", false);
Y
yangyaming 已提交
165 166 167
  }
};

Y
Yan Chunwei 已提交
168 169 170
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

173
REGISTER_OPERATOR(print, ops::PrintOp, ops::PrintOpProtoAndCheckMaker,
H
hong 已提交
174 175 176
                  ops::PrintOpGradientMaker<paddle::framework::OpDesc>,
                  ops::PrintOpGradientMaker<paddle::imperative::OpBase>,
                  ops::PrintOpInferShape, ops::PrintOpVarTypeInference);
177 178 179 180 181 182 183 184

REGISTER_OP_VERSION(print)
    .AddCheckpoint(
        R"ROC(Upgrade print add a new attribute [print_tensor_layout] to "
             "contorl whether to print tensor's layout.)ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "print_tensor_layout", "Whether to print the tensor's layout.",
            true));