transfer_layout_op.cc 4.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/transfer_layout_op.h"

#include <string>

19 20
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
21
#include "paddle/fluid/framework/op_version_registry.h"
22 23
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
24

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
// Forward declarations of framework / imperative types. OpDesc, OpBase and
// EmptyGradOpMaker are named by the REGISTER_OPERATOR call at the end of this
// file; InferShapeContext is declared here but not referenced directly.
namespace paddle {
namespace imperative {
class OpBase;
}  // namespace imperative

namespace framework {
template <typename T>
class EmptyGradOpMaker;
class InferShapeContext;
class OpDesc;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace operators {

// Operator definition for transfer_layout. This class only customises kernel
// selection; shape inference is registered separately through
// TransferLayoutInferShapeFunctor (phi::TransferLayoutInferMeta).
class TransferLayoutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Builds the expected kernel type from Input(X).
  // - Place: the place of X when X is initialized, otherwise CPUPlace.
  // - Data type: always FP32, because only the layout is being transferred.
  // Raises PreconditionNotMet if X is uninitialized and its layout is not
  // ONEDNN.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    // kernel's device type is decided by input tensor place
    auto *in = ctx.InputVar("X");
    auto *in_tensor = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in);

    // NOTE(zhiqiu): hot fix, allow an empty tensor in the ONEDNN (kMKLDNN)
    // layout to run this op; tensors in any other layout must already be
    // initialized.
    if (in_tensor->layout() != DataLayout::ONEDNN) {
      PADDLE_ENFORCE_EQ(in_tensor->IsInitialized(),
                        true,
                        platform::errors::PreconditionNotMet(
                            "The tensor of Input(X) is not initialized."));
    }
    // An empty ONEDNN tensor has no place of its own, so fall back to CPU.
    auto place =
        in_tensor->IsInitialized() ? in_tensor->place() : platform::CPUPlace();

    // dtype is not important for a pure layout transfer.
    return framework::OpKernelType(framework::proto::VarType::FP32, place);
  }

  // Reports the expected kernel type unchanged for every input variable.
  // X's kernel type therefore always matches the expected one, presumably so
  // the framework applies no data transform to X before this op runs.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const phi::DenseTensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    return expected_kernel_type;
  }
};

// Variable-type inference for transfer_layout. Output(Out) takes the
// variable type and data type of Input(X); this op changes only the layout.
class TransferLayoutInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    const std::string input_slot{"X"};
    const std::string output_slot{"Out"};
    ctx->SyncTypeAndDataType(input_slot, output_slot);
  }
};

// Kernel functor for transfer_layout. It reads Input(X), Output(Out), the
// device context and the src_layout / dst_layout attributes, then runs
// TransferLayoutFunctor to perform the actual transfer.
class TransferLayoutKernel {
 public:
  void operator()(const framework::ExecutionContext &ctx) const {
    auto *x = ctx.InputVar("X");
    auto *out = ctx.OutputVar("Out");
    auto &dev_ctx = ctx.device_context();
    // src_layout defaults to -1, described in the op proto as "unspecified,
    // use the tensor's layout".
    auto src_layout = ctx.Attr<int>("src_layout");
    auto dst_layout = ctx.Attr<int>("dst_layout");
    // The variable name of X is forwarded to TransferLayoutFunctor as well.
    auto input_name = ctx.InputName("X");
    TransferLayoutFunctor(
        x, out, dev_ctx, src_layout, dst_layout, input_name)();
  }
};

// Declares the interface of the transfer_layout op: input X, output Out and
// the integer attributes src_layout / dst_layout.
class TransferLayoutOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(phi::DenseTensor) The input Tensor");
    AddOutput("Out",
              "(phi::DenseTensor) The Output Tensor with desired layout");
    // NOTE(zhiqiu): in most cases src_layout is not needed, because the op can
    // use the layout of input X. However, in some mkldnn kernels the src
    // layout computed by GetKernelTypeForVar differs from the layout of
    // tensor X, so the caller may pass it explicitly.
    // NOTE(review): confirm the numeric values listed in these descriptions
    // still match the DataLayout enum in use.
    AddAttr<int>("src_layout",
                 "kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3, default "
                 "-1 means unspecified and use the tensor's layout.")
        .SetDefault(-1);
    // No SetDefault here; presumably dst_layout must always be supplied.
    AddAttr<int>("dst_layout",
                 "kAnyLayout = 0, kNHWC = 1, kNCHW = 2, kMKLDNN = 3");
    AddComment(R"DOC(
    TransferLayout Operator)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
120 121 122
DECLARE_INFER_SHAPE_FUNCTOR(transfer_layout,
                            TransferLayoutInferShapeFunctor,
                            PD_INFER_META(phi::TransferLayoutInferMeta));
123
REGISTER_OPERATOR(
124 125 126
    transfer_layout,
    ops::TransferLayoutOp,
    ops::TransferLayoutOpProtoMaker,
127 128
    ops::TransferLayoutInferVarType,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
129 130
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    TransferLayoutInferShapeFunctor);
131

132
REGISTER_OP_VERSION(transfer_layout)
133 134
    .AddCheckpoint(R"ROC(refine transfer_layout, add src_layout attribute)ROC",
                   paddle::framework::compatible::OpVersionDesc().NewAttr(
135 136
                       "src_layout",
                       "(int, the layout of the input tensor",
137
                       -1));