// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/transfer_layout_op.h"

#include <string>

// Forward declarations of framework/imperative types that this file names.
// InferShapeContext is the parameter type of TransferLayoutOp::InferShape;
// OpDesc, OpBase and EmptyGradOpMaker are named in the REGISTER_OPERATOR call
// at the bottom of this file.
namespace paddle {
namespace framework {
class OpDesc;
class InferShapeContext;
// Class template instantiated with OpDesc and OpBase when the operator is
// registered below.
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

namespace paddle {
namespace operators {

// Operator class for `transfer_layout`.
//
// Provides shape inference (Out takes X's dims and LoD) and kernel selection
// based on the place of the input tensor. The target layout is passed as the
// integer attribute `dst_layout`.
class TransferLayoutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that Input(X) and Output(Out) exist, that `dst_layout` lies in
  // [int(DataLayout::kAnyLayout), int(DataLayout::kMKLDNN)], and then copies
  // X's dims and LoD to Out.
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "TransferLayout");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TransferLayout");

    auto dst_layout = ctx->Attrs().Get<int>("dst_layout");
    // The attribute is a plain int, so it is range-checked against the
    // integer values of the first and last accepted DataLayout enumerators.
    auto low_bound = static_cast<int>(framework::DataLayout::kAnyLayout);
    auto upper_bound = static_cast<int>(framework::DataLayout::kMKLDNN);
    PADDLE_ENFORCE_GE(
        dst_layout, low_bound,
        platform::errors::PreconditionNotMet(
            "Required dst_layout >= %d, but received dst_layout = %d",
            low_bound, dst_layout));
    PADDLE_ENFORCE_LE(
        dst_layout, upper_bound,
        platform::errors::PreconditionNotMet(
            "Required dst_layout <= %d, but received dst_layout = %d",
            upper_bound, dst_layout));

    // TODO(Aurelius84): Out's ddim is different with X because they have
    // different layout
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // Chooses the kernel from the place of the tensor held by Input(X).
  // The data type is always reported as FP32 because it does not influence
  // kernel selection for this op. Raises PreconditionNotMet if the input
  // variable is missing or its tensor is not initialized.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    // kernel's device type is decided by input tensor place
    auto *in = ctx.InputVar("X");
    // Guard the dereference below: fail with a readable error instead of
    // dereferencing a null variable pointer.
    PADDLE_ENFORCE_EQ(in != nullptr, true,
                      platform::errors::PreconditionNotMet(
                          "The variable of Input(X) is not found."));
    auto *in_tensor = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in);
    PADDLE_ENFORCE_EQ(in_tensor->IsInitialized(), true,
                      platform::errors::PreconditionNotMet(
                          "The tensor of Input(X) is not initialized."));
    // dtype is not important
    return framework::OpKernelType(framework::proto::VarType::FP32,
                                   in_tensor->place());
  }

  // Builds the per-variable kernel type from the tensor's own place together
  // with the expected kernel's data type and data layout. `var_name` is not
  // used.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(),
                                   expected_kernel_type.data_layout_);
  }
};

// Variable-type inference hook for `transfer_layout`.
//
// Forwards to InferVarTypeContext::SyncTypeAndDataType with "X" as the first
// argument and "Out" as the second (presumably making Out's variable type and
// data type follow X's -- confirm against the framework definition).
class TransferLayoutInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *infer_ctx) const override {
    infer_ctx->SyncTypeAndDataType("X", "Out");
  }
};

// Kernel functor for `transfer_layout` (registered at the bottom of this
// file). It gathers the op's input/output variables, the device context and
// the `dst_layout` attribute, and hands them to TransferLayoutFunctor, which
// is declared outside this file.
class TransferLayoutKernel {
 public:
  void operator()(const framework::ExecutionContext &ctx) const {
    // Requested destination layout, encoded as an int attribute.
    auto target_layout = ctx.Attr<int>("dst_layout");
    auto &device_ctx = ctx.device_context();

    auto *input_var = ctx.InputVar("X");
    auto *output_var = ctx.OutputVar("Out");

    // Construct the functor as a temporary and invoke it immediately.
    TransferLayoutFunctor(input_var, output_var, device_ctx, target_layout)();
  }
};

// Proto maker for `transfer_layout`: declares one input (X), one output (Out)
// and the integer attribute `dst_layout` that selects the target layout.
class TransferLayoutOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(LoDTensor) The input Tensor");
    AddOutput("Out", "(LoDTensor) The Output Tensor with desired layout");
    // The description string lists the integer encoding of each layout.
    AddAttr<int>("dst_layout",
                 "kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3");
    AddComment(R"DOC(
    TransferLayout Operator)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

// Short aliases for the registration macros below.
namespace ops = paddle::operators;
// NOTE(review): `plat` is not referenced in the visible part of this file.
namespace plat = paddle::platform;
// Registers the `transfer_layout` operator with its op class, proto maker and
// var-type inference. EmptyGradOpMaker is supplied for both OpDesc and
// imperative::OpBase.
REGISTER_OPERATOR(
    transfer_layout, ops::TransferLayoutOp, ops::TransferLayoutOpProtoMaker,
    ops::TransferLayoutInferVarType,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpBase>);

// dtype is not important
// The kernel functor is registered for `float` only;
// TransferLayoutOp::GetExpectedKernelType always reports FP32.
REGISTER_OP_CPU_KERNEL_FUNCTOR(transfer_layout, float,
                               ops::TransferLayoutKernel);