// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/transfer_layout_op.h"

#include <string>

#include "paddle/fluid/framework/op_version_registry.h"

// Forward declarations to avoid pulling in heavy framework headers; the
// full definitions are only needed by the registration macros below.
namespace paddle {
namespace framework {
class OpDesc;
class InferShapeContext;
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

namespace paddle {
namespace operators {

// Operator that converts a tensor from one data layout (e.g. kNCHW) to
// another (e.g. kNHWC / kMKLDNN). Shape inference and kernel-type
// selection live here; the conversion itself is in TransferLayoutKernel.
class TransferLayoutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "TransferLayout");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TransferLayout");

    // dst_layout must be a valid framework::DataLayout enumerator,
    // i.e. inside [kAnyLayout, kMKLDNN].
    auto dst_layout = ctx->Attrs().Get<int>("dst_layout");
    auto low_bound = static_cast<int>(framework::DataLayout::kAnyLayout);
    auto upper_bound = static_cast<int>(framework::DataLayout::kMKLDNN);
    PADDLE_ENFORCE_GE(
        dst_layout, low_bound,
        platform::errors::PreconditionNotMet(
            "Required dst_layout >= %d, but received dst_layout = %d",
            low_bound, dst_layout));
    PADDLE_ENFORCE_LE(
        dst_layout, upper_bound,
        platform::errors::PreconditionNotMet(
            "Required dst_layout <= %d, but received dst_layout = %d",
            upper_bound, dst_layout));

    // TODO(Aurelius84): Out's ddim is different with X because they have
    // different layout
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    // kernel's device type is decided by input tensor place
    auto *in = ctx.InputVar("X");
    auto *in_tensor = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in);
    // NOTE(zhiqiu): hot fix, allow empty tensor of kMKLDNN layout to run this
    // op
    if (in_tensor->layout() != DataLayout::kMKLDNN) {
      PADDLE_ENFORCE_EQ(in_tensor->IsInitialized(), true,
                        platform::errors::PreconditionNotMet(
                            "The tensor of Input(X) is not initialized."));
    }
    // Fall back to CPU for the uninitialized (empty kMKLDNN) case.
    auto place =
        in_tensor->IsInitialized() ? in_tensor->place() : platform::CPUPlace();

    // dtype is not important: the kernel is registered for float only and
    // handles all element types internally.
    return framework::OpKernelType(framework::proto::VarType::FP32, place);
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    // Keep the expected kernel type unchanged so the framework does not
    // insert its own layout transform; this op performs the conversion
    // explicitly in its kernel.
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   expected_kernel_type.place_,
                                   expected_kernel_type.data_layout_);
  }
};

// Propagates the variable type and data type from input X to output Out:
// a layout transfer never changes either, only the physical arrangement.
class TransferLayoutInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SyncTypeAndDataType("X", "Out");
  }
};

// Kernel functor: reads the src/dst layout attributes and delegates the
// actual tensor layout conversion to TransferLayoutFunctor.
class TransferLayoutKernel {
 public:
  void operator()(const framework::ExecutionContext &ctx) const {
    auto *x = ctx.InputVar("X");
    auto *out = ctx.OutputVar("Out");
    auto &dev_ctx = ctx.device_context();
    // src_layout may be -1 (unspecified), in which case the functor uses
    // the layout recorded on the input tensor.
    auto src_layout = ctx.Attr<int>("src_layout");
    auto dst_layout = ctx.Attr<int>("dst_layout");
    auto input_name = ctx.InputName("X");
    TransferLayoutFunctor(x, out, dev_ctx, src_layout, dst_layout,
                          input_name)();
  }
};

// Declares the op's inputs, outputs and attributes for the op registry.
class TransferLayoutOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(LoDTensor) The input Tensor");
    AddOutput("Out", "(LoDTensor) The Output Tensor with desired layout");
    // NOTE(zhiqiu): in most case, the src_layout is not needed, the op can
    // use the layout of input X. However, in some mkldnn kernel, the src
    // layout computed by GetKernelTypeForVar is different with the layout of
    // tensor X.
    AddAttr<int>("src_layout",
                 "kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3, default "
                 "-1 means unspecified and use the tensor's layout.")
        .SetDefault(-1);
    AddAttr<int>("dst_layout",
                 "kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3");
    AddComment(R"DOC(
    TransferLayout Operator)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
// No gradient: layout transfer is inserted by the framework, not by users,
// so both static-graph and imperative grad makers are empty.
REGISTER_OPERATOR(
    transfer_layout, ops::TransferLayoutOp, ops::TransferLayoutOpProtoMaker,
    ops::TransferLayoutInferVarType,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

// dtype is not important
REGISTER_OP_CPU_KERNEL_FUNCTOR(transfer_layout, float,
                               ops::TransferLayoutKernel);

// Version checkpoint: src_layout was added after the op's first release;
// older programs load with the default of -1 (use the tensor's own layout).
REGISTER_OP_VERSION(transfer_layout)
    .AddCheckpoint(R"ROC(refine transfer_layout, add src_layout attribute)ROC",
                   paddle::framework::compatible::OpVersionDesc().NewAttr(
                       "src_layout", "(int) the layout of the input tensor",
                       -1));