/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/rank_attention_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

class RankAttentionOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of RankAttentionOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("RankOffset"), true,
        platform::errors::InvalidArgument(
            "Input(RankOffset) of RankAttentionOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("RankParam"), true,
        platform::errors::InvalidArgument(
            "Input(RankParam) of RankAttentionOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("InsRank"), true,
        platform::errors::InvalidArgument(
            "Output(InsRank) of RankAttentionOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("InputHelp"), true,
        platform::errors::InvalidArgument(
            "Output(InputHelp) of RankAttentionOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::InvalidArgument(
            "Output(Out) of RankAttentionOp should not be null."));
    auto max_rank = ctx->Attrs().Get<int>("MaxRank");

    auto x_dims = ctx->GetInputDim("X");
    auto ins_num = x_dims[0];
    auto param_dims = ctx->GetInputDim("RankParam");
    auto para_col = param_dims[1];
    auto rank_offset_dims = ctx->GetInputDim("RankOffset");
    auto x_fea_dim = x_dims[1];
    auto block_matrix_row = max_rank * x_fea_dim;

    PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
                      platform::errors::InvalidArgument(
                          "Input(RankOffset) has wrong columns, "
                          "except columns to be %d, but got %d",
                          max_rank, (rank_offset_dims[1] - 1) / 2));

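    // Out holds one para_col-wide attention row per instance. InputHelp
    // caches the flattened [max_rank * x_fea_dim] block per instance for the
    // backward pass, and InsRank records a single value per instance.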
    ctx->SetOutputDim("Out", {ins_num, para_col});
    ctx->SetOutputDim("InputHelp", {ins_num, block_matrix_row});
    ctx->SetOutputDim("InsRank", {ins_num, 1});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
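  // The forward kernel's data type follows that of the input X.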
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
  }
};

class RankAttentionGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::InvalidArgument("Input(X) should not be null"));
    PADDLE_ENFORCE_EQ(ctx->HasInput("RankParam"), true,
                      platform::errors::InvalidArgument(
                          "Input(RankParam) should not be null"));
    PADDLE_ENFORCE_EQ(ctx->HasInput("RankOffset"), true,
                      platform::errors::InvalidArgument(
                          "Input(RankOffset) should not be null"));
    PADDLE_ENFORCE_EQ(ctx->HasInput("InputHelp"), true,
                      platform::errors::InvalidArgument(
                          "Input(InputHelp) should not be null"));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("InsRank"), true,
        platform::errors::InvalidArgument("Input(InsRank) should not be null"));

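    // Only RankParam gets a gradient here; gradients for X and RankOffset
    // are not produced by this Op.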
    ctx->SetOutputDim(framework::GradVarName("RankParam"),
                      ctx->GetInputDim("RankParam"));
  }

 protected:
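  // The backward kernel's data type follows that of Out@GRAD.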
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

class RankAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input tensor of rank_attention_Op operator.");
    AddInput("RankOffset",
             "(Tensor) Input tensor of rank_attention_Op operator.");
    AddInput("RankParam",
             "(Tensor) Input tensor of rank_attention_Op operator.");
    AddOutput("InputHelp", "Output tensor of rank_attention_Op operator.")
        .AsDispensable();
    AddOutput("Out", "Output tensor of rank_attention_Op operator.");
    AddOutput("InsRank", "Output tensor of rank_attention_Op operator.")
        .AsDispensable();
    AddAttr<int>("MaxRank", "(int, default 3) max rank of rank_attention_Op")
        .SetDefault(3);
    AddAttr<int>("MaxSize", "(int, default 0) max rank of rank_attention_Op")
        .SetDefault(0);
    AddComment(R"DOC(
RankAttention Operator.
This Op can calculate rank attention between input and rank_param, 
and rank_param gives the organization of data. Notice: It currently supports GPU device.
This Op exists in contrib, which means that it is not shown to the public.
)DOC");
  }
};

template <typename T>
class RankAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("rank_attention_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("RankOffset", this->Input("RankOffset"));
    op->SetInput("RankParam", this->Input("RankParam"));
    op->SetInput("InputHelp", this->Output("InputHelp"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetInput("InsRank", this->Output("InsRank"));

    op->SetOutput(framework::GradVarName("RankParam"),
                  this->InputGrad("RankParam"));
    op->SetAttrMap(this->Attrs());
  }
};
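// The backward pass needs only the metadata (e.g. shapes) of X, RankOffset
// and RankParam, not their data buffers, so mark them as no-need-buffer to
// let that memory be freed early.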
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
    RankAttentionGradOpNoNeedBufferVarsInference, "X", "RankOffset",
    "RankParam");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(rank_attention, ops::RankAttentionOp,
                  ops::RankAttentionOpMaker,
                  ops::RankAttentionGradOpMaker<paddle::framework::OpDesc>,
                  ops::RankAttentionGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(rank_attention_grad, ops::RankAttentionGradOp,
                  ops::RankAttentionGradOpNoNeedBufferVarsInference);

REGISTER_OP_CPU_KERNEL(
    rank_attention,
    ops::RankAttentionKernel<paddle::platform::CPUDeviceContext, float>,
    ops::RankAttentionKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_VERSION(rank_attention)
    .AddCheckpoint(
        R"ROC(
        Upgrade rank_attention, add 1 output [InputHelp] and 1 attribute
        [MaxSize].
      )ROC",
        paddle::framework::compatible::OpVersionDesc()
            .NewOutput("InputHelp",
                       "Output tensor of rank_attention_Op operator "
                       "in order to assist calculation in the reverse process.")
            .NewAttr(
                "MaxSize",
                "Forward calculation to set the pre-applied video memory size",
                0));