/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {

class GatherOp : public framework::OperatorWithKernel {
Z
zchen0211 已提交
26 27 28
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

29
  void InferShape(framework::InferShapeContext* ctx) const override {
30 31 32 33 34 35 36 37 38
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of GatherOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
                      platform::errors::InvalidArgument(
                          "Input(Index) of GatherOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) of GatherOp should not be null."));
39

Z
zchen0211 已提交
40
    auto index_dims = ctx->GetInputDim("Index");
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55

    if (index_dims.size() == 2) {
      PADDLE_ENFORCE_EQ(
          index_dims[1], 1,
          platform::errors::InvalidArgument(
              "The last dim of index should be 1 when it is 2D, but we get %d",
              index_dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          index_dims.size(), 1,
          platform::errors::InvalidArgument(
              "The index should be 1D, when it is not 2D, but we get %d",
              index_dims.size()));
    }

56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
    auto axis = ctx->Attrs().Get<int>("axis");
    auto input_dim = ctx->GetInputDim("X");
    if (ctx->HasInput("Axis") || axis == 0) {
      // if HasInput("Axis"), we can not obtain correct shape of output
      int batch_size = index_dims[0];
      framework::DDim output_dims(input_dim);
      output_dims[0] = batch_size;
      ctx->SetOutputDim("Out", output_dims);
      ctx->ShareLoD("X", /*->*/ "Out");
    } else {
      int index_size = index_dims[0];
      std::vector<int> out_dim_vec;
      for (int i = 0; i < axis; i++) {
        out_dim_vec.push_back(input_dim[i]);
      }
      out_dim_vec.push_back(index_size);
      for (int i = axis + 1; i < input_dim.size(); i++) {
        out_dim_vec.push_back(input_dim[i]);
      }
      auto output_dims = framework::make_ddim(out_dim_vec);
      ctx->SetOutputDim("Out", output_dims);
      ctx->ShareLoD("X", /*->*/ "Out");
    }
Z
zchen0211 已提交
79
  }
Y
Yu Yang 已提交
80

81
 protected:
82
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
83
      const framework::ExecutionContext& ctx) const override {
84 85 86
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
Y
Yu Yang 已提交
87
  }
88 89 90
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
91 92 93 94 95
    if (var_name == "Axis") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
96
  }
Z
zchen0211 已提交
97 98 99 100 101 102
};

class GatherGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

103
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
Qiao Longfei 已提交
104
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
Y
Yibing Liu 已提交
105
    ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
Z
zchen0211 已提交
106
  }
Y
Yu Yang 已提交
107

108
 protected:
109
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
110
      const framework::ExecutionContext& ctx) const override {
111 112 113
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
Y
Yu Yang 已提交
114
  }
115 116 117 118 119 120 121 122 123
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "Axis") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
Z
zchen0211 已提交
124 125 126 127
};

// Proto/attribute declaration for the "gather" operator.
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The source input of gather op");
    AddInput("Index", "The index input of gather op");
    AddInput("Axis",
             "The Tensor which contains the axis that we do gather operation.")
        .AsDispensable();
    AddOutput("Out", "The output of gather op");
    // NOTE: description fixed to "default: True" — it previously said
    // "default: False" while the actual default below is true.
    AddAttr<bool>(
        "overwrite",
        "(bool, default: True) "
        "In backward process, calc the grad when has same index,"
        "If true, update the grad using the overwrite mode in same index,"
        "If false, using the accumulate mode in same index.")
        .SetDefault(true);
    // Static axis; ignored when the dispensable "Axis" input is provided.
    AddAttr<int>(
        "axis",
        "(int, default: 0) The axis on which we do gather operation.")
        .SetDefault(0);
    AddComment(R"DOC(
Gather Operator.

$Out = X[Index]$

Out is obtained by gathering entries of the outer-most dimension
of X indexed by Index and concatenate them together.

Example:

X = [[1, 2],
     [3, 4],
     [5, 6]]

Index = [[1, 2]]

Then:

Out = [[3, 4],
       [5, 6]]

)DOC");
  }
};
// Builds the gather_grad op description from the forward gather op: forwards
// X, Index and Axis, wires dOut in and dX out, and copies all attributes.
template <typename T>
class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("gather_grad");
    op->SetInput("Index", this->Input("Index"));
    op->SetInput("Axis", this->Input("Axis"));

    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

// Marks X's data buffer as unneeded by gather_grad (its InferShape only uses
// X's shape/LoD), letting the framework free or skip X's contents in backward.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yang Yang 已提交
195
REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
H
hong 已提交
196 197
                  ops::GatherGradOpMaker<paddle::framework::OpDesc>,
                  ops::GatherGradOpMaker<paddle::imperative::OpBase>);
S
sneaxiy 已提交
198
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
199
                  ops::GatherGradNoNeedBufferVarInferer);
200
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
201
                       ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
202
                       ops::GatherOpKernel<uint8_t>,
203
                       ops::GatherOpKernel<int64_t>);
204
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
205
                       ops::GatherGradientOpKernel<double>,
206
                       ops::GatherGradientOpKernel<int>,
207
                       ops::GatherGradientOpKernel<uint8_t>,
208
                       ops::GatherGradientOpKernel<int64_t>);
209
REGISTER_OP_VERSION(gather)
W
wangchaochaohu 已提交
210 211 212
    .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
                   paddle::framework::compatible::OpVersionDesc().NewInput(
                       "Axis", "Specify the axis of gather operation."));