/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {

class GatherOp : public framework::OperatorWithKernel {
Z
zchen0211 已提交
25 26 27
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

28
  void InferShape(framework::InferShapeContext* ctx) const override {
29 30 31 32 33 34 35 36 37
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of GatherOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
                      platform::errors::InvalidArgument(
                          "Input(Index) of GatherOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) of GatherOp should not be null."));
38

Z
zchen0211 已提交
39
    auto index_dims = ctx->GetInputDim("Index");
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54

    if (index_dims.size() == 2) {
      PADDLE_ENFORCE_EQ(
          index_dims[1], 1,
          platform::errors::InvalidArgument(
              "The last dim of index should be 1 when it is 2D, but we get %d",
              index_dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          index_dims.size(), 1,
          platform::errors::InvalidArgument(
              "The index should be 1D, when it is not 2D, but we get %d",
              index_dims.size()));
    }

Q
Qiao Longfei 已提交
55 56
    int batch_size = ctx->GetInputDim("Index")[0];
    framework::DDim output_dims(ctx->GetInputDim("X"));
Z
zchen0211 已提交
57
    output_dims[0] = batch_size;
Q
Qiao Longfei 已提交
58
    ctx->SetOutputDim("Out", output_dims);
S
ShenLiang 已提交
59
    ctx->ShareLoD("X", /*->*/ "Out");
Z
zchen0211 已提交
60
  }
Y
Yu Yang 已提交
61

62
 protected:
63
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
64
      const framework::ExecutionContext& ctx) const override {
65 66 67
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
Y
Yu Yang 已提交
68
  }
Z
zchen0211 已提交
69 70 71 72 73 74
};

class GatherGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

75
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
Qiao Longfei 已提交
76
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
Y
Yibing Liu 已提交
77
    ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
Z
zchen0211 已提交
78
  }
Y
Yu Yang 已提交
79

80
 protected:
81
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
82
      const framework::ExecutionContext& ctx) const override {
83 84 85
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
Y
Yu Yang 已提交
86
  }
Z
zchen0211 已提交
87 88 89 90
};

// Declares the proto of the `gather` op: inputs X, Index and the optional
// Axis, output Out, and the boolean attribute `overwrite`.
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The source input of gather op");
    AddInput("Index", "The index input of gather op");
    AddInput("Axis",
             "The Tensor which contains the axis that we do gather operation.")
        .AsDispensable();
    AddOutput("Out", "The output of gather op");
    // The documented default must match SetDefault(true) below.
    AddAttr<bool>(
        "overwrite",
        "(bool, default: True) "
        "In backward process, calc the grad when has same index. "
        "If true, update the grad using the overwrite mode in same index. "
        "If false, using the accumulate mode in same index.")
        .SetDefault(true);
    // The example uses a 1-D Index: GatherOp::InferShape only accepts a 1-D
    // index or a 2-D index whose last dimension is 1.
    AddComment(R"DOC(
Gather Operator.

$Out = X[Index]$

Out is obtained by gathering entries of the outer-most dimension
of X indexed by Index and concatenate them together.

Example:

X = [[1, 2],
     [3, 4],
     [5, 6]]

Index = [1, 2]

Then:

Out = [[3, 4],
       [5, 6]]

)DOC");
  }
};
S
sneaxiy 已提交
129

H
hong 已提交
130 131
// Builds the `gather_grad` op description from a forward `gather` op.
// T is the op representation the maker is instantiated for (the file
// registers it with framework::OpDesc and imperative::OpBase).
template <typename T>
class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Fills in the type, inputs, output and attributes of the grad op.
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("gather_grad");

    // Forward inputs that the grad op receives under the same names.
    for (const char* forwarded : {"Index", "Axis", "X"}) {
      op->SetInput(forwarded, this->Input(forwarded));
    }

    // Gradient of Out flows in; gradient of X flows out.
    const auto out_grad_name = framework::GradVarName("Out");
    op->SetInput(out_grad_name, this->OutputGrad("Out"));

    const auto x_grad_name = framework::GradVarName("X");
    op->SetOutput(x_grad_name, this->InputGrad("X"));

    // The grad op carries the same attributes as the forward op.
    op->SetAttrMap(this->Attrs());
  }
};

148
// Declares GatherGradNoNeedBufferVarInferer, which names "X" for the
// gather_grad op; it is passed to REGISTER_OPERATOR(gather_grad, ...) further
// down in this file.
// NOTE(review): judging by the macro name, this marks X's data buffer as not
// needed by the grad op (GatherGradOp::InferShape only reads X's dims) --
// confirm against the macro's definition.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Forward op: definition, proto maker, and the grad-op makers for both op
// representations.
REGISTER_OPERATOR(
    gather,
    ops::GatherOp,
    ops::GatherOpMaker,
    ops::GatherGradOpMaker<paddle::framework::OpDesc>,
    ops::GatherGradOpMaker<paddle::imperative::OpBase>);

// Backward op: definition plus the inferer declared for input X.
REGISTER_OPERATOR(
    gather_grad,
    ops::GatherGradOp,
    ops::GatherGradNoNeedBufferVarInferer);

// CPU kernels of the forward op, one per supported element type.
REGISTER_OP_CPU_KERNEL(
    gather,
    ops::GatherOpKernel<float>,
    ops::GatherOpKernel<double>,
    ops::GatherOpKernel<int>,
    ops::GatherOpKernel<uint8_t>,
    ops::GatherOpKernel<int64_t>);

// CPU kernels of the backward op, for the same element types.
REGISTER_OP_CPU_KERNEL(
    gather_grad,
    ops::GatherGradientOpKernel<float>,
    ops::GatherGradientOpKernel<double>,
    ops::GatherGradientOpKernel<int>,
    ops::GatherGradientOpKernel<uint8_t>,
    ops::GatherGradientOpKernel<int64_t>);
168 169 170 171
// Version checkpoint for `gather`. The change being recorded is the
// dispensable input "Axis" added in GatherOpMaker; the op has no attribute
// named "axis", so the checkpoint is described with NewInput rather than
// NewAttr.
REGISTER_OP_VERSION(gather)
    .AddCheckpoint(R"ROC(upgrade gather, add a new input [Axis])ROC",
                   paddle::framework::compatible::OpVersionDesc().NewInput(
                       "Axis", "Specify the axis of gather operation."));