/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/scatter.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

Z
zchen0211 已提交
27
template <typename T>
Y
Yu Yang 已提交
28
class GatherOpKernel : public framework::OpKernel<T> {
Z
zchen0211 已提交
29
 public:
Z
zchen0211 已提交
30
  void Compute(const framework::ExecutionContext &ctx) const override {
31 32 33
    PADDLE_ENFORCE_EQ(
        platform::is_cpu_place(ctx.GetPlace()), true,
        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
Z
zchen0211 已提交
34 35 36 37 38

    auto *x = ctx.Input<Tensor>("X");
    auto *index = ctx.Input<Tensor>("Index");
    auto *output = ctx.Output<Tensor>("Out");

39 40
    int axis = ctx.Attr<int>("axis");
    // get axis from tensor
41
    if (ctx.HasInput("Axis")) {
42
      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
43 44
      const auto &axis_type = axis_tensor->dtype();
      if (axis_type == phi::DataType::INT32) {
45
        axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
46
      } else if (axis_type == phi::DataType::INT64) {
47
        axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
48
      }
49
    }
50 51
    const auto &index_type = index->dtype();
    auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
52
    if (axis != 0) {
53 54 55 56 57 58
      if (index_type == phi::DataType::INT32) {
        phi::funcs::GatherV2Function<T, int32_t>(dev_ctx, x, index, axis,
                                                 output);
      } else if (index_type == phi::DataType::INT64) {
        phi::funcs::GatherV2Function<T, int64_t>(dev_ctx, x, index, axis,
                                                 output);
59 60 61 62
      }
      return;
    }

Z
zchen0211 已提交
63
    output->mutable_data<T>(ctx.GetPlace());
64
    if (x->numel() == 0) return;
65 66 67 68
    if (index_type == phi::DataType::INT32) {
      phi::funcs::CPUGather<T, int>(dev_ctx, *x, *index, output);
    } else if (index_type == phi::DataType::INT64) {
      phi::funcs::CPUGather<T, int64_t>(dev_ctx, *x, *index, output);
69
    }
Z
zchen0211 已提交
70 71 72
  }
};

Z
zchen0211 已提交
73
template <typename T>
Y
Yu Yang 已提交
74
class GatherGradientOpKernel : public framework::OpKernel<T> {
Z
zchen0211 已提交
75
 public:
Z
zchen0211 已提交
76
  void Compute(const framework::ExecutionContext &ctx) const override {
77 78 79
    PADDLE_ENFORCE_EQ(
        platform::is_cpu_place(ctx.GetPlace()), true,
        platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
Z
zchen0211 已提交
80

81
    auto *index = ctx.Input<Tensor>("Index");
Z
zchen0211 已提交
82 83
    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
Z
zchen0211 已提交
84

85
    int axis = ctx.Attr<int>("axis");
86
    if (ctx.HasInput("Axis")) {
87
      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
88 89
      const auto &axis_type = axis_tensor->dtype();
      if (axis_type == phi::DataType::INT32) {
90
        axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
91
      } else if (axis_type == phi::DataType::INT64) {
92
        axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
93
      }
94
    }
95 96
    const auto &index_type = index->dtype();
    auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
97 98

    if (axis != 0) {
99 100 101 102 103 104
      if (index_type == phi::DataType::INT32) {
        phi::funcs::GatherV2GradFunction<T, int32_t>(dev_ctx, dO, index, axis,
                                                     dX);
      } else if (index_type == phi::DataType::INT64) {
        phi::funcs::GatherV2GradFunction<T, int64_t>(dev_ctx, dO, index, axis,
                                                     dX);
105 106 107 108
      }
      return;
    }

Z
zchen0211 已提交
109
    dX->mutable_data<T>(ctx.GetPlace());
Z
zchen0211 已提交
110
    auto dxt = framework::EigenVector<T>::Flatten(*dX);
111
    auto &place = *dev_ctx.eigen_device();
Z
zchen0211 已提交
112
    dxt.device(place) = dxt.constant(static_cast<T>(0));
113
    if (dO->numel() == 0) return;
114
    bool overwrite = ctx.Attr<bool>("overwrite");
115

116
    if (index_type == phi::DataType::INT32) {
117
      if (overwrite) {
118
        phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *dO, *index, dX);
119
      } else {
120
        phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
121
      }
122
    } else if (index_type == phi::DataType::INT64) {
123
      if (overwrite) {
124
        phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX);
125
      } else {
126
        phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
127
      }
128
    }
Z
zchen0211 已提交
129 130 131 132 133
  }
};

}  // namespace operators
}  // namespace paddle