/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gather_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/pten/core/ddim.h"

namespace paddle {
namespace operators {

class GatherOp : public framework::OperatorWithKernel {
Z
zchen0211 已提交
26 27 28
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

29
  void InferShape(framework::InferShapeContext* ctx) const override {
30 31 32 33 34 35 36 37 38
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of GatherOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
                      platform::errors::InvalidArgument(
                          "Input(Index) of GatherOp should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) of GatherOp should not be null."));
39

Z
zchen0211 已提交
40
    auto index_dims = ctx->GetInputDim("Index");
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55

    if (index_dims.size() == 2) {
      PADDLE_ENFORCE_EQ(
          index_dims[1], 1,
          platform::errors::InvalidArgument(
              "The last dim of index should be 1 when it is 2D, but we get %d",
              index_dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          index_dims.size(), 1,
          platform::errors::InvalidArgument(
              "The index should be 1D, when it is not 2D, but we get %d",
              index_dims.size()));
    }

56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
    auto axis = ctx->Attrs().Get<int>("axis");
    auto input_dim = ctx->GetInputDim("X");
    if (ctx->HasInput("Axis") || axis == 0) {
      // if HasInput("Axis"), we can not obtain correct shape of output
      int batch_size = index_dims[0];
      framework::DDim output_dims(input_dim);
      output_dims[0] = batch_size;
      ctx->SetOutputDim("Out", output_dims);
      ctx->ShareLoD("X", /*->*/ "Out");
    } else {
      int index_size = index_dims[0];
      std::vector<int> out_dim_vec;
      for (int i = 0; i < axis; i++) {
        out_dim_vec.push_back(input_dim[i]);
      }
      out_dim_vec.push_back(index_size);
      for (int i = axis + 1; i < input_dim.size(); i++) {
        out_dim_vec.push_back(input_dim[i]);
      }
75
      auto output_dims = pten::make_ddim(out_dim_vec);
76 77 78
      ctx->SetOutputDim("Out", output_dims);
      ctx->ShareLoD("X", /*->*/ "Out");
    }
Z
zchen0211 已提交
79
  }
Y
Yu Yang 已提交
80

81
 protected:
82
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
83
      const framework::ExecutionContext& ctx) const override {
84 85 86
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
Y
Yu Yang 已提交
87
  }
88 89 90
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
91 92 93 94 95
    if (var_name == "Axis") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
96
  }
Z
zchen0211 已提交
97 98 99 100 101 102
};

class GatherGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

103
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
Qiao Longfei 已提交
104
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
Y
Yibing Liu 已提交
105
    ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
Z
zchen0211 已提交
106
  }
Y
Yu Yang 已提交
107

108
 protected:
109
  framework::OpKernelType GetExpectedKernelType(
Y
Yu Yang 已提交
110
      const framework::ExecutionContext& ctx) const override {
111 112 113
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
Y
Yu Yang 已提交
114
  }
115 116 117 118 119 120 121 122 123
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "Axis") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
Z
zchen0211 已提交
124 125 126 127
};

// Declares the inputs, output, attributes and user-facing documentation of
// the gather operator.
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The source input of gather op");
    AddInput("Index", "The index input of gather op");
    AddInput("Axis",
             "The Tensor which contains the axis that we do gather operation.")
        .AsDispensable();
    AddOutput("Out", "The output of gather op");
    // The documented default must match SetDefault(true) below.
    AddAttr<bool>(
        "overwrite",
        "(bool, default: True) "
        "In the backward pass, controls how the gradient is computed when "
        "Index contains the same index more than once. "
        "If true, the gradient at a repeated index is updated in overwrite "
        "mode; if false, it is updated in accumulate mode.")
        .SetDefault(true)
        .AsExtra();
    // `axis` is an int attribute, not a tensor.
    AddAttr<int>("axis",
                 "(int, default: 0) The axis of X along which to gather.")
        .SetDefault(0);
    AddComment(R"DOC(
Gather Operator.

$Out = X[Index]$

By default (axis = 0), Out is obtained by gathering entries of the
outer-most dimension of X indexed by Index and concatenating them together.
For a non-zero axis, entries are gathered along that dimension instead.
Index must be 1-D, or 2-D with a last dimension of 1.

Example:

X = [[1, 2],
     [3, 4],
     [5, 6]]

Index = [1, 2]

Then:

Out = [[3, 4],
       [5, 6]]

)DOC");
  }
};

// Describes the gather_grad op generated for a gather op. It is a template so
// that the same wiring serves every grad-op representation T it is
// instantiated with.
template <typename T>
class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Wires the forward inputs, the upstream gradient Out@GRAD, the produced
  // X@GRAD and a copy of all forward attributes onto the backward op.
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("gather_grad");

    // Forward inputs reused by the backward op ("Axis" may be empty because
    // it is dispensable on the forward op).
    grad_op->SetInput("Index", this->Input("Index"));
    grad_op->SetInput("Axis", this->Input("Axis"));
    grad_op->SetInput("X", this->Input("X"));

    // Gradient flow: Out@GRAD in, X@GRAD out.
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));

    grad_op->SetAttrMap(this->Attrs());
  }
};

// Declares X as a variable whose data buffer gather_grad does not need.
// GatherGradOp::InferShape only reads X's dims and shares its LoD.
// NOTE(review): presumably this lets the framework avoid keeping X's data
// alive for the backward pass — confirm against the framework's
// no-need-buffer handling.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yang Yang 已提交
196
REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
H
hong 已提交
197 198
                  ops::GatherGradOpMaker<paddle::framework::OpDesc>,
                  ops::GatherGradOpMaker<paddle::imperative::OpBase>);
S
sneaxiy 已提交
199
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
200
                  ops::GatherGradNoNeedBufferVarInferer);
201
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
202
                       ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
203
                       ops::GatherOpKernel<uint8_t>,
204
                       ops::GatherOpKernel<int64_t>);
205
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
206
                       ops::GatherGradientOpKernel<double>,
207
                       ops::GatherGradientOpKernel<int>,
208
                       ops::GatherGradientOpKernel<uint8_t>,
209
                       ops::GatherGradientOpKernel<int64_t>);
210
REGISTER_OP_VERSION(gather)
W
wangchaochaohu 已提交
211 212 213
    .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
                   paddle::framework::compatible::OpVersionDesc().NewInput(
                       "Axis", "Specify the axis of gather operation."));