未验证 提交 09258040 编写于 作者: S sneaxiy 提交者: GitHub

Move gather.h/gather.cu.h/scatter.h/scatter.cu.h to the phi library (#40043)

* move gather.h gather.cu.h scatter.h scatter.cu.h to phi library

* fix CI

* fix rocm ci
上级 2e6548a9
...@@ -23,7 +23,6 @@ limitations under the License. */ ...@@ -23,7 +23,6 @@ limitations under the License. */
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
namespace cub = hipcub; namespace cub = hipcub;
#endif #endif
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
......
...@@ -23,11 +23,11 @@ namespace cub = hipcub; ...@@ -23,11 +23,11 @@ namespace cub = hipcub;
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h" #include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -160,9 +160,9 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -160,9 +160,9 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace()); sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
Tensor sorted_batch_id; Tensor sorted_batch_id;
sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace()); sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois); phi::funcs::GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois);
GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t, phi::funcs::GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t,
&sorted_batch_id); &sorted_batch_id);
Tensor batch_index_t; Tensor batch_index_t;
int* batch_idx_in = int* batch_idx_in =
...@@ -190,7 +190,7 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -190,7 +190,7 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num, 0, out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
sizeof(int) * 8, dev_ctx.stream()); sizeof(int) * 8, dev_ctx.stream());
GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois); phi::funcs::GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
Tensor length_lod; Tensor length_lod;
int* length_lod_data = int* length_lod_data =
......
...@@ -21,7 +21,6 @@ limitations under the License.*/ ...@@ -21,7 +21,6 @@ limitations under the License.*/
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -66,7 +65,8 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -66,7 +65,8 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
auto multi_layer_scores = auto multi_layer_scores =
context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores"); context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores");
auto multi_rois_num = context.MultiInput<Tensor>("MultiLevelRoIsNum"); auto multi_rois_num =
context.MultiInput<framework::Tensor>("MultiLevelRoIsNum");
int num_size = multi_rois_num.size(); int num_size = multi_rois_num.size();
auto* fpn_rois = context.Output<paddle::framework::LoDTensor>("FpnRois"); auto* fpn_rois = context.Output<paddle::framework::LoDTensor>("FpnRois");
...@@ -176,7 +176,7 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -176,7 +176,7 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
} }
num_per_batch.emplace_back(post_nms_topN - pre_idx); num_per_batch.emplace_back(post_nms_topN - pre_idx);
if (context.HasOutput("RoisNum")) { if (context.HasOutput("RoisNum")) {
auto* rois_num = context.Output<Tensor>("RoisNum"); auto* rois_num = context.Output<framework::Tensor>("RoisNum");
int* rois_num_data = int* rois_num_data =
rois_num->mutable_data<int>({batch_size}, context.GetPlace()); rois_num->mutable_data<int>({batch_size}, context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
......
...@@ -24,9 +24,9 @@ namespace cub = hipcub; ...@@ -24,9 +24,9 @@ namespace cub = hipcub;
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h" #include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -193,7 +193,8 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -193,7 +193,8 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
start = end; start = end;
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim}, multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace()); dev_ctx.GetPlace());
GPUGather<T>(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]); phi::funcs::GPUGather<T>(dev_ctx, *fpn_rois, sub_idx,
multi_fpn_rois[i]);
} else { } else {
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim}, multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace()); dev_ctx.GetPlace());
......
...@@ -20,7 +20,6 @@ limitations under the License. */ ...@@ -20,7 +20,6 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -28,10 +27,11 @@ namespace operators { ...@@ -28,10 +27,11 @@ namespace operators {
const int kBoxDim = 4; const int kBoxDim = 4;
inline std::vector<size_t> GetLodFromRoisNum(const Tensor* rois_num) { inline std::vector<size_t> GetLodFromRoisNum(
const framework::Tensor* rois_num) {
std::vector<size_t> rois_lod; std::vector<size_t> rois_lod;
auto* rois_num_data = rois_num->data<int>(); auto* rois_num_data = rois_num->data<int>();
Tensor cpu_tensor; framework::Tensor cpu_tensor;
if (platform::is_gpu_place(rois_num->place())) { if (platform::is_gpu_place(rois_num->place())) {
paddle::framework::TensorCopySync(*rois_num, platform::CPUPlace(), paddle::framework::TensorCopySync(*rois_num, platform::CPUPlace(),
&cpu_tensor); &cpu_tensor);
...@@ -93,7 +93,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -93,7 +93,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
std::vector<size_t> fpn_rois_lod; std::vector<size_t> fpn_rois_lod;
int fpn_rois_num; int fpn_rois_num;
if (context.HasInput("RoisNum")) { if (context.HasInput("RoisNum")) {
auto* rois_num = context.Input<Tensor>("RoisNum"); auto* rois_num = context.Input<framework::Tensor>("RoisNum");
fpn_rois_lod = GetLodFromRoisNum(rois_num); fpn_rois_lod = GetLodFromRoisNum(rois_num);
} else { } else {
fpn_rois_lod = fpn_rois->lod().back(); fpn_rois_lod = fpn_rois->lod().back();
...@@ -105,7 +105,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -105,7 +105,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
std::vector<int> num_rois_level(num_level, 0); std::vector<int> num_rois_level(num_level, 0);
std::vector<int> num_rois_level_integral(num_level + 1, 0); std::vector<int> num_rois_level_integral(num_level + 1, 0);
for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) { for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
Tensor fpn_rois_slice = auto fpn_rois_slice =
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]); fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>(); const T* rois_data = fpn_rois_slice.data<T>();
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) { for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
...@@ -140,7 +140,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -140,7 +140,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
std::vector<int> restore_index_inter(fpn_rois_num, -1); std::vector<int> restore_index_inter(fpn_rois_num, -1);
// distribute the rois into different fpn level by target level // distribute the rois into different fpn level by target level
for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) { for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
Tensor fpn_rois_slice = auto fpn_rois_slice =
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]); fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>(); const T* rois_data = fpn_rois_slice.data<T>();
size_t cur_offset = fpn_rois_lod[i]; size_t cur_offset = fpn_rois_lod[i];
...@@ -163,7 +163,8 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> { ...@@ -163,7 +163,8 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
for (int i = 0; i < fpn_rois_num; ++i) { for (int i = 0; i < fpn_rois_num; ++i) {
restore_index_data[restore_index_inter[i]] = i; restore_index_data[restore_index_inter[i]] = i;
} }
auto multi_rois_num = context.MultiOutput<Tensor>("MultiLevelRoIsNum"); auto multi_rois_num =
context.MultiOutput<framework::Tensor>("MultiLevelRoIsNum");
if (multi_rois_num.size() > 0) { if (multi_rois_num.size() > 0) {
int batch_size = fpn_rois_lod.size() - 1; int batch_size = fpn_rois_lod.size() - 1;
for (int i = 0; i < num_level; ++i) { for (int i = 0; i < num_level; ++i) {
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/mask_util.h" #include "paddle/fluid/operators/detection/mask_util.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
......
...@@ -16,8 +16,8 @@ limitations under the License. */ ...@@ -16,8 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -281,22 +281,22 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context, ...@@ -281,22 +281,22 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context,
Tensor fg_boxes, bg_boxes, fg_labels, bg_labels; Tensor fg_boxes, bg_boxes, fg_labels, bg_labels;
fg_boxes.mutable_data<T>({fg_num, kBoxDim}, context.GetPlace()); fg_boxes.mutable_data<T>({fg_num, kBoxDim}, context.GetPlace());
CPUGather<T>(context, boxes, fg_inds_t, &fg_boxes); phi::funcs::CPUGather<T>(context, boxes, fg_inds_t, &fg_boxes);
bg_boxes.mutable_data<T>({bg_num, kBoxDim}, context.GetPlace()); bg_boxes.mutable_data<T>({bg_num, kBoxDim}, context.GetPlace());
CPUGather<T>(context, boxes, bg_inds_t, &bg_boxes); phi::funcs::CPUGather<T>(context, boxes, bg_inds_t, &bg_boxes);
Concat<T>(context, fg_boxes, bg_boxes, sampled_boxes); Concat<T>(context, fg_boxes, bg_boxes, sampled_boxes);
CPUGather<T>(context, gt_boxes, gt_box_inds_t, sampled_gts); phi::funcs::CPUGather<T>(context, gt_boxes, gt_box_inds_t, sampled_gts);
fg_labels.mutable_data<int>({fg_num}, context.GetPlace()); fg_labels.mutable_data<int>({fg_num}, context.GetPlace());
CPUGather<int>(context, gt_classes, gt_label_inds_t, &fg_labels); phi::funcs::CPUGather<int>(context, gt_classes, gt_label_inds_t, &fg_labels);
bg_labels.mutable_data<int>({bg_num}, context.GetPlace()); bg_labels.mutable_data<int>({bg_num}, context.GetPlace());
phi::funcs::set_constant(context, &bg_labels, 0); phi::funcs::set_constant(context, &bg_labels, 0);
Concat<int>(context, fg_labels, bg_labels, sampled_labels); Concat<int>(context, fg_labels, bg_labels, sampled_labels);
Tensor fg_max_overlap, bg_max_overlap; Tensor fg_max_overlap, bg_max_overlap;
fg_max_overlap.mutable_data<T>({fg_num}, context.GetPlace()); fg_max_overlap.mutable_data<T>({fg_num}, context.GetPlace());
CPUGather<T>(context, max_overlap, fg_inds_t, &fg_max_overlap); phi::funcs::CPUGather<T>(context, max_overlap, fg_inds_t, &fg_max_overlap);
bg_max_overlap.mutable_data<T>({bg_num}, context.GetPlace()); bg_max_overlap.mutable_data<T>({bg_num}, context.GetPlace());
CPUGather<T>(context, max_overlap, bg_inds_t, &bg_max_overlap); phi::funcs::CPUGather<T>(context, max_overlap, bg_inds_t, &bg_max_overlap);
Concat<T>(context, fg_max_overlap, bg_max_overlap, sampled_max_overlap); Concat<T>(context, fg_max_overlap, bg_max_overlap, sampled_max_overlap);
} }
...@@ -334,7 +334,7 @@ std::vector<Tensor> SampleRoisForOneImage( ...@@ -334,7 +334,7 @@ std::vector<Tensor> SampleRoisForOneImage(
} else { } else {
proposals_num = keep.numel(); proposals_num = keep.numel();
roi_filter.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace()); roi_filter.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
CPUGather<T>(context, rpn_rois, keep, &roi_filter); phi::funcs::CPUGather<T>(context, rpn_rois, keep, &roi_filter);
} }
T* roi_filter_dt = roi_filter.data<T>(); T* roi_filter_dt = roi_filter.data<T>();
memcpy(rpn_rois_dt, roi_filter_dt, roi_filter.numel() * sizeof(T)); memcpy(rpn_rois_dt, roi_filter_dt, roi_filter.numel() * sizeof(T));
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/nms_util.h" #include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -196,10 +196,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -196,10 +196,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace()); anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace()); var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
CPUGather<T>(ctx, scores_slice, index_t, &scores_sel); phi::funcs::CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel); phi::funcs::CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
CPUGather<T>(ctx, anchors, index_t, &anchor_sel); phi::funcs::CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
CPUGather<T>(ctx, variances, index_t, &var_sel); phi::funcs::CPUGather<T>(ctx, variances, index_t, &var_sel);
Tensor proposals; Tensor proposals;
proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace()); proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
...@@ -223,8 +223,8 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -223,8 +223,8 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
Tensor scores_filter; Tensor scores_filter;
bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace()); bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace()); scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, proposals, keep, &bbox_sel); phi::funcs::CPUGather<T>(ctx, proposals, keep, &bbox_sel);
CPUGather<T>(ctx, scores_sel, keep, &scores_filter); phi::funcs::CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
if (nms_thresh <= 0) { if (nms_thresh <= 0) {
return std::make_pair(bbox_sel, scores_filter); return std::make_pair(bbox_sel, scores_filter);
} }
...@@ -237,8 +237,8 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -237,8 +237,8 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace()); proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace()); scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals); phi::funcs::CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel); phi::funcs::CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
return std::make_pair(proposals, scores_sel); return std::make_pair(proposals, scores_sel);
} }
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/bbox_util.cu.h" #include "paddle/fluid/operators/detection/bbox_util.cu.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -85,8 +86,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage( ...@@ -85,8 +86,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
} }
proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace()); proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace()); scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals, keep_index, &proposals_filter); phi::funcs::GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter); phi::funcs::GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);
if (nms_thresh <= 0) { if (nms_thresh <= 0) {
return std::make_pair(proposals_filter, scores_filter); return std::make_pair(proposals_filter, scores_filter);
...@@ -102,8 +103,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage( ...@@ -102,8 +103,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
Tensor scores_nms, proposals_nms; Tensor scores_nms, proposals_nms;
proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace()); proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace()); scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms); phi::funcs::GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms); phi::funcs::GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);
return std::make_pair(proposals_nms, scores_nms); return std::make_pair(proposals_nms, scores_nms);
} }
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/nms_util.h" #include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -197,10 +197,10 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> { ...@@ -197,10 +197,10 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> {
anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace()); anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace()); var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
CPUGather<T>(ctx, scores_slice, index_t, &scores_sel); phi::funcs::CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel); phi::funcs::CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
CPUGather<T>(ctx, anchors, index_t, &anchor_sel); phi::funcs::CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
CPUGather<T>(ctx, variances, index_t, &var_sel); phi::funcs::CPUGather<T>(ctx, variances, index_t, &var_sel);
Tensor proposals; Tensor proposals;
proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace()); proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
...@@ -227,8 +227,8 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> { ...@@ -227,8 +227,8 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> {
Tensor scores_filter; Tensor scores_filter;
bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace()); bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace()); scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, proposals, keep, &bbox_sel); phi::funcs::CPUGather<T>(ctx, proposals, keep, &bbox_sel);
CPUGather<T>(ctx, scores_sel, keep, &scores_filter); phi::funcs::CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
if (nms_thresh <= 0) { if (nms_thresh <= 0) {
return std::make_pair(bbox_sel, scores_filter); return std::make_pair(bbox_sel, scores_filter);
} }
...@@ -242,8 +242,8 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> { ...@@ -242,8 +242,8 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> {
proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace()); proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace()); scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals); phi::funcs::CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel); phi::funcs::CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
return std::make_pair(proposals, scores_sel); return std::make_pair(proposals, scores_sel);
} }
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/bbox_util.cu.h" #include "paddle/fluid/operators/detection/bbox_util.cu.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -86,8 +87,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage( ...@@ -86,8 +87,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
} }
proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace()); proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace()); scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals, keep_index, &proposals_filter); phi::funcs::GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter); phi::funcs::GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);
if (nms_thresh <= 0) { if (nms_thresh <= 0) {
return std::make_pair(proposals_filter, scores_filter); return std::make_pair(proposals_filter, scores_filter);
...@@ -104,8 +105,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage( ...@@ -104,8 +105,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
Tensor scores_nms, proposals_nms; Tensor scores_nms, proposals_nms;
proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace()); proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace()); scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms); phi::funcs::GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms); phi::funcs::GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);
return std::make_pair(proposals_nms, scores_nms); return std::make_pair(proposals_nms, scores_nms);
} }
......
...@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/gather_nd_op.h" #include "paddle/fluid/operators/gather_nd_op.h"
#include "paddle/fluid/operators/scatter.cu.h" #include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class GatherNdOpCUDAKernel : public framework::OpKernel<T> { class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
...@@ -33,27 +33,25 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> { ...@@ -33,27 +33,25 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace()); output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return; if (x->numel() == 0) return;
const auto &index_type = framework::TransToProtoVarType(index->dtype()); const auto &index_type = index->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index_type_match, true,
"Index holds the wrong type, it holds [%s], but " platform::errors::InvalidArgument(
"desires to be [%s] or [%s].", "Index holds the wrong type, it holds [%s], but "
paddle::framework::DataTypeToString(index_type), "desires to be [%s] or [%s].",
paddle::framework::DataTypeToString( index_type, phi::DataType::INT32, phi::DataType::INT64));
framework::proto::VarType::INT32), auto &dev_ctx = ctx.cuda_device_context();
paddle::framework::DataTypeToString( if (index_type == phi::DataType::INT32) {
framework::proto::VarType::INT64))); phi::funcs::GPUGatherNd<T, int>(dev_ctx, *x, *index, output);
if (index_type == framework::proto::VarType::INT32) { } else if (index_type == phi::DataType::INT64) {
GPUGatherNd<DeviceContext, T, int>(ctx, *x, *index, output); phi::funcs::GPUGatherNd<T, int64_t>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
GPUGatherNd<DeviceContext, T, int64_t>(ctx, *x, *index, output);
} }
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> { class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
...@@ -71,24 +69,22 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -71,24 +69,22 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return; if (dO->numel() == 0) return;
const auto &index_type = framework::TransToProtoVarType(index->dtype()); const auto &index_type = index->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index_type_match, true,
"Index holds the wrong type, it holds [%s]," platform::errors::InvalidArgument(
"but desires to be [%s] or [%s].", "Index holds the wrong type, it holds [%s],"
paddle::framework::DataTypeToString(index_type), "but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString( index_type, phi::DataType::INT32, phi::DataType::INT64));
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) { auto &dev_ctx = ctx.cuda_device_context();
GPUScatterNdAdd<DeviceContext, T, int>(ctx, *dO, *index, dX); if (index_type == phi::DataType::INT32) {
} else if (index_type == framework::proto::VarType::INT64) { phi::funcs::GPUScatterNdAdd<T, int>(dev_ctx, *dO, *index, dX);
GPUScatterNdAdd<DeviceContext, T, int64_t>(ctx, *dO, *index, dX); } else if (index_type == phi::DataType::INT64) {
phi::funcs::GPUScatterNdAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
} }
} }
}; };
...@@ -98,18 +94,16 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -98,18 +94,16 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
using CUDA = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<float>,
REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>, ops::GatherNdOpCUDAKernel<double>,
ops::GatherNdOpCUDAKernel<CUDA, double>, ops::GatherNdOpCUDAKernel<int64_t>,
ops::GatherNdOpCUDAKernel<CUDA, int64_t>, ops::GatherNdOpCUDAKernel<int>,
ops::GatherNdOpCUDAKernel<CUDA, int>, ops::GatherNdOpCUDAKernel<int16_t>,
ops::GatherNdOpCUDAKernel<CUDA, int16_t>, ops::GatherNdOpCUDAKernel<bool>,
ops::GatherNdOpCUDAKernel<CUDA, bool>, ops::GatherNdOpCUDAKernel<plat::float16>);
ops::GatherNdOpCUDAKernel<CUDA, plat::float16>);
REGISTER_OP_CUDA_KERNEL(gather_nd_grad, REGISTER_OP_CUDA_KERNEL(gather_nd_grad, ops::GatherNdGradOpCUDAKernel<float>,
ops::GatherNdGradOpCUDAKernel<CUDA, float>, ops::GatherNdGradOpCUDAKernel<double>,
ops::GatherNdGradOpCUDAKernel<CUDA, double>, ops::GatherNdGradOpCUDAKernel<int64_t>,
ops::GatherNdGradOpCUDAKernel<CUDA, int64_t>, ops::GatherNdGradOpCUDAKernel<int>,
ops::GatherNdGradOpCUDAKernel<CUDA, int>, ops::GatherNdGradOpCUDAKernel<plat::float16>);
ops::GatherNdGradOpCUDAKernel<CUDA, plat::float16>);
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/fluid/operators/scatter.h" #include "paddle/phi/kernels/funcs/scatter.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,22 +38,20 @@ class GatherNdOpKernel : public framework::OpKernel<T> { ...@@ -38,22 +38,20 @@ class GatherNdOpKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace()); output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return; if (x->numel() == 0) return;
const auto &index_type = framework::TransToProtoVarType(index->dtype()); auto index_type = index->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index_type_match, true,
"Index holds the wrong type, it holds [%s]," platform::errors::InvalidArgument(
"but desires to be [%s] or [%s]", "Index holds the wrong type, it holds [%s],"
paddle::framework::DataTypeToString(index_type), "but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString( index_type, phi::DataType::INT32, phi::DataType::INT64));
framework::proto::VarType::INT32), auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
paddle::framework::DataTypeToString( if (index_type == phi::DataType::INT32) {
framework::proto::VarType::INT64))); phi::funcs::CPUGatherNd<T, int>(dev_ctx, *x, *index, output);
if (index_type == framework::proto::VarType::INT32) { } else if (index_type == phi::DataType::INT64) {
CPUGatherNd<T, int>(ctx.device_context(), *x, *index, output); phi::funcs::CPUGatherNd<T, int64_t>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
CPUGatherNd<T, int64_t>(ctx.device_context(), *x, *index, output);
} }
} }
}; };
...@@ -65,6 +63,7 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> { ...@@ -65,6 +63,7 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true, platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU.")); platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *index = ctx.Input<Tensor>("Index"); auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X")); auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
...@@ -75,22 +74,21 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> { ...@@ -75,22 +74,21 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> {
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return; if (dO->numel() == 0) return;
const auto &index_type = framework::TransToProtoVarType(index->dtype()); auto index_type = index->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index_type_match, true,
"Index holds the wrong type, it holds [%s]," platform::errors::InvalidArgument(
"but desires to be [%s] or [%s]", "Index holds the wrong type, it holds [%s],"
paddle::framework::DataTypeToString(index_type), "but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString( index_type, phi::DataType::INT32, phi::DataType::INT64));
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString( auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
framework::proto::VarType::INT64))); if (index_type == phi::DataType::INT32) {
if (index_type == framework::proto::VarType::INT32) { phi::funcs::ScatterNdAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
ScatterNdAdd<T, int32_t>(ctx, *dO, *index, dX); } else if (index_type == phi::DataType::INT64) {
} else if (index_type == framework::proto::VarType::INT64) { phi::funcs::ScatterNdAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
ScatterNdAdd<T, int64_t>(ctx, *dO, *index, dX);
} }
} }
}; };
......
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/gather_op.h" #include "paddle/fluid/operators/gather_op.h"
#include "paddle/fluid/operators/scatter.cu.h" #include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -49,11 +49,14 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> { ...@@ -49,11 +49,14 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
} }
const auto &place = ctx.GetPlace(); const auto &place = ctx.GetPlace();
const auto &index_type = framework::TransToProtoVarType(index->dtype()); const auto &index_type = framework::TransToProtoVarType(index->dtype());
const auto &dev_ctx = ctx.cuda_device_context();
if (axis != 0) { if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) { if (index_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int32_t>(x, index, axis, output, place, ctx); phi::funcs::GatherV2CUDAFunction<T, int32_t>(x, index, axis, output,
dev_ctx);
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int64_t>(x, index, axis, output, place, ctx); phi::funcs::GatherV2CUDAFunction<T, int64_t>(x, index, axis, output,
dev_ctx);
} }
return; return;
} }
...@@ -61,9 +64,9 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> { ...@@ -61,9 +64,9 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace()); output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return; if (x->numel() == 0) return;
if (index_type == framework::proto::VarType::INT32) { if (index_type == framework::proto::VarType::INT32) {
GPUGather<T, int>(ctx.device_context(), *x, *index, output); phi::funcs::GPUGather<T, int>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
GPUGather<T, int64_t>(ctx.device_context(), *x, *index, output); phi::funcs::GPUGather<T, int64_t>(dev_ctx, *x, *index, output);
} }
} }
}; };
...@@ -93,14 +96,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -93,14 +96,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
} }
} }
const auto &dev_ctx = ctx.cuda_device_context();
const auto &index_type = framework::TransToProtoVarType(index->dtype()); const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (axis != 0) { if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) { if (index_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX, phi::funcs::GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
ctx.GetPlace(), ctx); dev_ctx);
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX, phi::funcs::GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
ctx.GetPlace(), ctx); dev_ctx);
} }
return; return;
} }
...@@ -112,11 +116,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -112,11 +116,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return; if (dO->numel() == 0) return;
if (index_type == framework::proto::VarType::INT32) { if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int>(ctx, *dO, *index, dX, phi::funcs::GPUScatterAssign<T, int>(dev_ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite")); ctx.Attr<bool>("overwrite"));
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
GPUScatterAssign<T, int64_t>(ctx, *dO, *index, dX, phi::funcs::GPUScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite")); ctx.Attr<bool>("overwrite"));
} }
} }
}; };
......
...@@ -16,8 +16,8 @@ limitations under the License. */ ...@@ -16,8 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/fluid/operators/scatter.h" #include "paddle/phi/kernels/funcs/scatter.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -40,31 +40,32 @@ class GatherOpKernel : public framework::OpKernel<T> { ...@@ -40,31 +40,32 @@ class GatherOpKernel : public framework::OpKernel<T> {
// get axis from tensor // get axis from tensor
if (ctx.HasInput("Axis")) { if (ctx.HasInput("Axis")) {
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis"); const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
const auto &axis_type = const auto &axis_type = axis_tensor->dtype();
framework::TransToProtoVarType(axis_tensor->dtype()); if (axis_type == phi::DataType::INT32) {
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(axis_tensor->data<int32_t>()[0]); axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) { } else if (axis_type == phi::DataType::INT64) {
axis = static_cast<int>(axis_tensor->data<int64_t>()[0]); axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
} }
} }
const auto &place = ctx.GetPlace(); const auto &index_type = index->dtype();
const auto &index_type = framework::TransToProtoVarType(index->dtype()); auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
if (axis != 0) { if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
GatherV2Function<T, int32_t>(x, index, axis, output, place); phi::funcs::GatherV2Function<T, int32_t>(dev_ctx, x, index, axis,
} else if (index_type == framework::proto::VarType::INT64) { output);
GatherV2Function<T, int64_t>(x, index, axis, output, place); } else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2Function<T, int64_t>(dev_ctx, x, index, axis,
output);
} }
return; return;
} }
output->mutable_data<T>(ctx.GetPlace()); output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return; if (x->numel() == 0) return;
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
CPUGather<T, int>(ctx.device_context(), *x, *index, output); phi::funcs::CPUGather<T, int>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == phi::DataType::INT64) {
CPUGather<T, int64_t>(ctx.device_context(), *x, *index, output); phi::funcs::CPUGather<T, int64_t>(dev_ctx, *x, *index, output);
} }
} }
}; };
...@@ -84,44 +85,45 @@ class GatherGradientOpKernel : public framework::OpKernel<T> { ...@@ -84,44 +85,45 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) { if (ctx.HasInput("Axis")) {
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis"); const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
const auto &axis_type = const auto &axis_type = axis_tensor->dtype();
framework::TransToProtoVarType(axis_tensor->dtype()); if (axis_type == phi::DataType::INT32) {
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(axis_tensor->data<int32_t>()[0]); axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) { } else if (axis_type == phi::DataType::INT64) {
axis = static_cast<int>(axis_tensor->data<int64_t>()[0]); axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
} }
} }
const auto &index_type = framework::TransToProtoVarType(index->dtype()); const auto &index_type = index->dtype();
auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
if (axis != 0) { if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
GatherV2GradFunction<T, int32_t>(dO, index, axis, dX, ctx.GetPlace()); phi::funcs::GatherV2GradFunction<T, int32_t>(dev_ctx, dO, index, axis,
} else if (index_type == framework::proto::VarType::INT64) { dX);
GatherV2GradFunction<T, int64_t>(dO, index, axis, dX, ctx.GetPlace()); } else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2GradFunction<T, int64_t>(dev_ctx, dO, index, axis,
dX);
} }
return; return;
} }
dX->mutable_data<T>(ctx.GetPlace()); dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX); auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto &place = *ctx.template device_context<platform::CPUDeviceContext>() auto &place = *dev_ctx.eigen_device();
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return; if (dO->numel() == 0) return;
bool overwrite = ctx.Attr<bool>("overwrite"); bool overwrite = ctx.Attr<bool>("overwrite");
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
if (overwrite) { if (overwrite) {
ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX); phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *dO, *index, dX);
} else { } else {
ScatterAssignAdd<T, int32_t>(ctx, *dO, *index, dX); phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
} }
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == phi::DataType::INT64) {
if (overwrite) { if (overwrite) {
ScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX); phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX);
} else { } else {
ScatterAssignAdd<T, int64_t>(ctx, *dO, *index, dX); phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
} }
} }
} }
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/kernels/funcs/gather.h"
TEST(Gather, GatherData) { TEST(Gather, GatherData) {
paddle::framework::Tensor* src = new paddle::framework::Tensor(); paddle::framework::Tensor* src = new paddle::framework::Tensor();
...@@ -39,7 +39,7 @@ TEST(Gather, GatherData) { ...@@ -39,7 +39,7 @@ TEST(Gather, GatherData) {
auto* cpu_place = new paddle::platform::CPUPlace(); auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place); paddle::platform::CPUDeviceContext ctx(*cpu_place);
paddle::operators::CPUGather<int>(ctx, *src, *index, output); phi::funcs::CPUGather<int>(ctx, *src, *index, output);
delete cpu_place; delete cpu_place;
cpu_place = NULL; cpu_place = NULL;
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include <utility> #include <utility>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
......
...@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/segment_pooling.h" #include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -379,9 +379,9 @@ class SegmentPoolGradFunctor<platform::CUDADeviceContext, T, IndexT> { ...@@ -379,9 +379,9 @@ class SegmentPoolGradFunctor<platform::CUDADeviceContext, T, IndexT> {
SimpleDiv<T><<<config.block_per_grid.x, config.thread_per_block.x, 0, SimpleDiv<T><<<config.block_per_grid.x, config.thread_per_block.x, 0,
context.stream()>>>(mean_grad.data<T>(), context.stream()>>>(mean_grad.data<T>(),
summed_ids->data<T>(), len, dim); summed_ids->data<T>(), len, dim);
GPUGather<T, IndexT>(context, mean_grad, segments, in_grad); phi::funcs::GPUGather<T, IndexT>(context, mean_grad, segments, in_grad);
} else if (pooltype == "SUM") { } else if (pooltype == "SUM") {
GPUGather<T, IndexT>(context, out_grad, segments, in_grad); phi::funcs::GPUGather<T, IndexT>(context, out_grad, segments, in_grad);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN " "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
......
...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/gather_op.h" #include "paddle/fluid/operators/gather_op.h"
#include "paddle/fluid/operators/scatter.cu.h"
#include "paddle/fluid/operators/scatter_nd_add_op.h" #include "paddle/fluid/operators/scatter_nd_add_op.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -33,22 +33,20 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> { ...@@ -33,22 +33,20 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
auto *Out = ctx.Output<Tensor>("Out"); auto *Out = ctx.Output<Tensor>("Out");
framework::TensorCopySync(*X, ctx.GetPlace(), Out); framework::TensorCopySync(*X, ctx.GetPlace(), Out);
const auto &index_type = framework::TransToProtoVarType(Ids->dtype()); const auto &index_type = Ids->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index_type_match, true,
"Index holds the wrong type, it holds [%s], but " platform::errors::InvalidArgument(
"desires to be [%s] or [%s].", "Index holds the wrong type, it holds [%s], but "
paddle::framework::DataTypeToString(index_type), "desires to be [%s] or [%s].",
paddle::framework::DataTypeToString( index_type, phi::DataType::INT32, phi::DataType::INT64));
framework::proto::VarType::INT32), auto &dev_ctx = ctx.cuda_device_context();
paddle::framework::DataTypeToString( if (index_type == phi::DataType::INT32) {
framework::proto::VarType::INT64))); phi::funcs::GPUScatterNdAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
if (index_type == framework::proto::VarType::INT32) {
GPUScatterNdAdd<DeviceContext, T, int32_t>(ctx, *Updates, *Ids, Out);
} else { } else {
GPUScatterNdAdd<DeviceContext, T, int64_t>(ctx, *Updates, *Ids, Out); phi::funcs::GPUScatterNdAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
} }
} }
}; };
...@@ -69,12 +67,13 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -69,12 +67,13 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
} }
if (dUpdates) { if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.cuda_device_context();
// Gradient by Gather // Gradient by Gather
const auto &index_type = framework::TransToProtoVarType(Ids->dtype()); const auto &index_type = Ids->dtype();
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
GPUGatherNd<DeviceContext, T, int32_t>(ctx, *dOut, *Ids, dUpdates); phi::funcs::GPUGatherNd<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
} else { } else {
GPUGatherNd<DeviceContext, T, int64_t>(ctx, *dOut, *Ids, dUpdates); phi::funcs::GPUGatherNd<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
} }
} }
} }
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/fluid/operators/scatter.h" #include "paddle/phi/kernels/funcs/scatter.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -37,23 +37,21 @@ class ScatterNdAddOpKernel : public framework::OpKernel<T> { ...@@ -37,23 +37,21 @@ class ScatterNdAddOpKernel : public framework::OpKernel<T> {
// In place output: Out = X // In place output: Out = X
framework::TensorCopySync(*X, ctx.GetPlace(), Out); framework::TensorCopySync(*X, ctx.GetPlace(), Out);
const auto &index_type = framework::TransToProtoVarType(Ids->dtype()); const auto &index_type = Ids->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index_type_match, true,
"Index holds the wrong type, it holds [%s], but " platform::errors::InvalidArgument(
"desires to be [%s] or [%s].", "Index holds the wrong type, it holds [%s], but "
paddle::framework::DataTypeToString(index_type), "desires to be [%s] or [%s].",
paddle::framework::DataTypeToString( index_type, phi::DataType::INT32, phi::DataType::INT64));
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) { auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
ScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out); if (index_type == phi::DataType::INT32) {
phi::funcs::ScatterNdAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
} else { } else {
ScatterNdAdd<T, int64_t>(ctx, *Updates, *Ids, Out); phi::funcs::ScatterNdAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
} }
} }
}; };
...@@ -76,11 +74,12 @@ class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> { ...@@ -76,11 +74,12 @@ class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
if (dUpdates) { if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids] // Gradient by Gather: dUpdates = dO[Ids]
const auto &index_type = framework::TransToProtoVarType(Ids->dtype()); const auto &index_type = Ids->dtype();
if (index_type == framework::proto::VarType::INT32) { auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
CPUGatherNd<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates); if (index_type == phi::DataType::INT32) {
phi::funcs::CPUGatherNd<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
} else { } else {
CPUGatherNd<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates); phi::funcs::CPUGatherNd<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
} }
} }
} }
......
...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/gather_op.h" #include "paddle/fluid/operators/gather_op.h"
#include "paddle/fluid/operators/scatter.cu.h"
#include "paddle/fluid/operators/scatter_op.h" #include "paddle/fluid/operators/scatter_op.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -35,23 +35,22 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> { ...@@ -35,23 +35,22 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
framework::TensorCopy(*X, ctx.GetPlace(), Out); framework::TensorCopy(*X, ctx.GetPlace(), Out);
// use template class to support int32_t and int64_t // use template class to support int32_t and int64_t
const auto &index_type = framework::TransToProtoVarType(Ids->dtype()); auto index_type = Ids->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index_type_match, true, index_type_match, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"scatter_op Index holds the wrong type, it holds [%s]," "scatter_op Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].", "but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type), index_type, phi::DataType::INT32, phi::DataType::INT64));
paddle::framework::DataTypeToString( auto &dev_ctx = ctx.cuda_device_context();
framework::proto::VarType::INT32), if (index_type == phi::DataType::INT32) {
paddle::framework::DataTypeToString( phi::funcs::GPUScatterAssign<T, int32_t>(dev_ctx, *Updates, *Ids, Out,
framework::proto::VarType::INT64))); overwrite);
if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
} else { } else {
GPUScatterAssign<T, int64_t>(ctx, *Updates, *Ids, Out, overwrite); phi::funcs::GPUScatterAssign<T, int64_t>(dev_ctx, *Updates, *Ids, Out,
overwrite);
} }
} }
}; };
...@@ -68,36 +67,33 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -68,36 +67,33 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Ids"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
const auto &index_type = framework::TransToProtoVarType(Ids->dtype()); auto index_type = Ids->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index_type_match, true, index_type_match, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"scatter_op index holds the wrong type, it holds [%s]," "scatter_op index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]", "but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString(index_type), index_type, phi::DataType::INT32, phi::DataType::INT64));
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
auto &dev_ctx = ctx.cuda_device_context();
if (dX) { if (dX) {
framework::TensorCopy(*dOut, ctx.GetPlace(), dX); framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
GPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX); phi::funcs::GPUScatterGradForX<T, int32_t>(dev_ctx, *Ids, dX);
} else { } else {
GPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX); phi::funcs::GPUScatterGradForX<T, int64_t>(dev_ctx, *Ids, dX);
} }
} }
if (dUpdates) { if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids] // Gradient by Gather: dUpdates = dO[Ids]
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates); phi::funcs::GPUGather<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
} else { } else {
GPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates); phi::funcs::GPUGather<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
} }
} }
} }
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/fluid/operators/scatter.h" #include "paddle/phi/kernels/funcs/scatter.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -39,29 +39,27 @@ class ScatterOpKernel : public framework::OpKernel<T> { ...@@ -39,29 +39,27 @@ class ScatterOpKernel : public framework::OpKernel<T> {
// In place output: Out = X, Out[Ids] = Updates // In place output: Out = X, Out[Ids] = Updates
framework::TensorCopy(*X, ctx.GetPlace(), Out); framework::TensorCopy(*X, ctx.GetPlace(), Out);
// Apply ScatterUpdate: Out[index] = Updates[:] // Apply ScatterUpdate: Out[index] = Updates[:]
const auto &index_type = framework::TransToProtoVarType(Ids->dtype()); const auto &index_type = Ids->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index_type_match, true,
"Index holds the wrong type, it holds [%s]," platform::errors::InvalidArgument(
"but desires to be [%s] or [%s].", "Index holds the wrong type, it holds [%s],"
paddle::framework::DataTypeToString(index_type), "but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString( index_type, phi::DataType::INT32, phi::DataType::INT64));
framework::proto::VarType::INT32), auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (overwrite) { if (overwrite) {
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
ScatterAssign<T, int32_t>(ctx.device_context(), *Updates, *Ids, Out); phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
} else { } else {
ScatterAssign<T, int64_t>(ctx.device_context(), *Updates, *Ids, Out); phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
} }
} else { } else {
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
ScatterAssignAdd<T, int32_t>(ctx, *Updates, *Ids, Out); phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
} else { } else {
ScatterAssignAdd<T, int64_t>(ctx, *Updates, *Ids, Out); phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
} }
} }
} }
...@@ -79,36 +77,33 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> { ...@@ -79,36 +77,33 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Ids"); auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
const auto &index_type = framework::TransToProtoVarType(Ids->dtype()); const auto &index_type = Ids->dtype();
bool index_type_match = index_type == framework::proto::VarType::INT32 || bool index_type_match = index_type == phi::DataType::INT32 ||
index_type == framework::proto::VarType::INT64; index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index_type_match, true, index_type_match, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"scatter_op index holds the wrong type, it holds [%s]," "scatter_op index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]", "but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString(index_type), index_type, phi::DataType::INT32, phi::DataType::INT64));
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
if (dX) { if (dX) {
framework::TensorCopy(*dOut, ctx.GetPlace(), dX); framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
CPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX); phi::funcs::CPUScatterGradForX<T, int32_t>(dev_ctx, *Ids, dX);
} else { } else {
CPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX); phi::funcs::CPUScatterGradForX<T, int64_t>(dev_ctx, *Ids, dX);
} }
} }
if (dUpdates) { if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids] // Gradient by Gather: dUpdates = dO[Ids]
if (index_type == framework::proto::VarType::INT32) { if (index_type == phi::DataType::INT32) {
CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates); phi::funcs::CPUGather<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
} else { } else {
CPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates); phi::funcs::CPUGather<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
} }
} }
} }
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/scatter.h" #include "paddle/phi/kernels/funcs/scatter.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -43,7 +43,7 @@ TEST(scatter, ScatterUpdate) { ...@@ -43,7 +43,7 @@ TEST(scatter, ScatterUpdate) {
auto* cpu_place = new paddle::platform::CPUPlace(); auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place); paddle::platform::CPUDeviceContext ctx(*cpu_place);
paddle::operators::ScatterAssign<float>(ctx, src, index, &output); phi::funcs::ScatterAssign<float>(ctx, src, index, &output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], 0.0f); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], 0.0f);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f);
......
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/segment_pool_op.h" #include "paddle/fluid/operators/segment_pool_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
......
...@@ -16,8 +16,6 @@ limitations under the License. */ ...@@ -16,8 +16,6 @@ limitations under the License. */
#include <memory> #include <memory>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/scatter.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -15,8 +15,7 @@ limitations under the License. */ ...@@ -15,8 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/phi/kernels/funcs/scatter.h"
#include "paddle/fluid/operators/scatter.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -11,8 +11,8 @@ limitations under the License. */ ...@@ -11,8 +11,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/viterbi_decode_op.h" #include "paddle/fluid/operators/viterbi_decode_op.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#ifdef __NVCC__ #ifdef __NVCC__
#include "cub/cub.cuh" #include "cub/cub.cuh"
...@@ -62,10 +62,11 @@ int64_t ComputeBlockSize(int64_t col) { ...@@ -62,10 +62,11 @@ int64_t ComputeBlockSize(int64_t col) {
template <template <typename T> typename BinaryFunctor, typename T> template <template <typename T> typename BinaryFunctor, typename T>
struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> { struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx, const Tensor& lhs, void operator()(const platform::CUDADeviceContext& dev_ctx,
const Tensor& rhs, Tensor* output) { const framework::Tensor& lhs, const framework::Tensor& rhs,
std::vector<const Tensor*> ins{&lhs, &rhs}; framework::Tensor* output) {
std::vector<Tensor*> outs{output}; std::vector<const framework::Tensor*> ins{&lhs, &rhs};
std::vector<framework::Tensor*> outs{output};
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
T>(dev_ctx, ins, &outs, -1, T>(dev_ctx, ins, &outs, -1,
BinaryFunctor<T>()); BinaryFunctor<T>());
...@@ -75,10 +76,11 @@ struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> { ...@@ -75,10 +76,11 @@ struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
template <template <typename InT, typename OutT> typename CompareFunctor, template <template <typename InT, typename OutT> typename CompareFunctor,
typename T> typename T>
struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> { struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs, void operator()(const framework::ExecutionContext& ctx,
const Tensor& rhs, Tensor* mask) { const framework::Tensor& lhs, const framework::Tensor& rhs,
std::vector<const Tensor*> ins = {&lhs, &rhs}; framework::Tensor* mask) {
std::vector<Tensor*> outs = {mask}; std::vector<const framework::Tensor*> ins = {&lhs, &rhs};
std::vector<framework::Tensor*> outs = {mask};
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>( paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
dev_ctx, ins, &outs, CompareFunctor<int64_t, T>()); dev_ctx, ins, &outs, CompareFunctor<int64_t, T>());
...@@ -131,8 +133,9 @@ struct ARange<platform::CUDADeviceContext> { ...@@ -131,8 +133,9 @@ struct ARange<platform::CUDADeviceContext> {
template <typename T, typename IndType> template <typename T, typename IndType>
struct Argmax<platform::CUDADeviceContext, T, IndType> { struct Argmax<platform::CUDADeviceContext, T, IndType> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& input, void operator()(const framework::ExecutionContext& ctx,
Tensor* out_idx, Tensor* out, int axis) { const framework::Tensor& input, framework::Tensor* out_idx,
framework::Tensor* out, int axis) {
framework::DDim input_dims = input.dims(); framework::DDim input_dims = input.dims();
int64_t numel = input.numel(); int64_t numel = input.numel();
int64_t groups = numel / input_dims[axis]; int64_t groups = numel / input_dims[axis];
...@@ -166,8 +169,8 @@ struct Argmax<platform::CUDADeviceContext, T, IndType> { ...@@ -166,8 +169,8 @@ struct Argmax<platform::CUDADeviceContext, T, IndType> {
template <typename T> template <typename T>
struct GetMaxValue<platform::CUDADeviceContext, T> { struct GetMaxValue<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx, void operator()(const platform::CUDADeviceContext& dev_ctx,
const Tensor& input, T* max_value) { const framework::Tensor& input, T* max_value) {
Tensor out_data; framework::Tensor out_data;
out_data.Resize(phi::make_ddim({1})); out_data.Resize(phi::make_ddim({1}));
out_data.mutable_data<T>(platform::CUDAPlace()); out_data.mutable_data<T>(platform::CUDAPlace());
switch (ComputeBlockSize(input.numel())) { switch (ComputeBlockSize(input.numel())) {
...@@ -177,7 +180,7 @@ struct GetMaxValue<platform::CUDADeviceContext, T> { ...@@ -177,7 +180,7 @@ struct GetMaxValue<platform::CUDADeviceContext, T> {
1, input.numel(), 1, input.data<int64_t>(), nullptr, 1, input.numel(), 1, input.data<int64_t>(), nullptr,
out_data.data<int64_t>())); out_data.data<int64_t>()));
} }
Tensor max_value_tensor; framework::Tensor max_value_tensor;
framework::TensorCopy(out_data, platform::CPUPlace(), &max_value_tensor); framework::TensorCopy(out_data, platform::CPUPlace(), &max_value_tensor);
*max_value = max_value_tensor.data<T>()[0]; *max_value = max_value_tensor.data<T>()[0];
} }
...@@ -185,9 +188,10 @@ struct GetMaxValue<platform::CUDADeviceContext, T> { ...@@ -185,9 +188,10 @@ struct GetMaxValue<platform::CUDADeviceContext, T> {
template <typename T, typename IndexT> template <typename T, typename IndexT>
struct Gather<platform::CUDADeviceContext, T, IndexT> { struct Gather<platform::CUDADeviceContext, T, IndexT> {
void operator()(const platform::CUDADeviceContext& ctx, const Tensor& src, void operator()(const platform::CUDADeviceContext& ctx,
const Tensor& index, Tensor* output) { const framework::Tensor& src, const framework::Tensor& index,
GPUGather<T, IndexT>(ctx, src, index, output); framework::Tensor* output) {
phi::funcs::GPUGather<T, IndexT>(ctx, src, index, output);
} }
}; };
......
...@@ -17,10 +17,10 @@ limitations under the License. */ ...@@ -17,10 +17,10 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h" #include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/unique_op.h" #include "paddle/fluid/operators/unique_op.h"
#include "paddle/phi/kernels/funcs/gather.h"
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#include <omp.h> #include <omp.h>
#endif #endif
...@@ -28,12 +28,11 @@ limitations under the License. */ ...@@ -28,12 +28,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T, typename IndType> template <typename DeviceContext, typename T, typename IndType>
struct Argmax { struct Argmax {
void operator()(const framework::ExecutionContext& ctx, const Tensor& input, void operator()(const framework::ExecutionContext& ctx,
Tensor* out_idx, Tensor* out, int axis) { const framework::Tensor& input, framework::Tensor* out_idx,
framework::Tensor* out, int axis) {
framework::DDim input_dims = input.dims(); framework::DDim input_dims = input.dims();
int64_t pre = 1; int64_t pre = 1;
int64_t post = 1; int64_t post = 1;
...@@ -82,7 +81,7 @@ struct ARange { ...@@ -82,7 +81,7 @@ struct ARange {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct GetMaxValue { struct GetMaxValue {
void operator()(const DeviceContext& dev_ctx, const Tensor& input, void operator()(const DeviceContext& dev_ctx, const framework::Tensor& input,
T* max_value) { T* max_value) {
auto input_ptr = input.data<T>(); auto input_ptr = input.data<T>();
auto num = input.numel(); auto num = input.numel();
...@@ -92,14 +91,15 @@ struct GetMaxValue { ...@@ -92,14 +91,15 @@ struct GetMaxValue {
template <typename DeviceContext, typename T, typename IndexT = int> template <typename DeviceContext, typename T, typename IndexT = int>
struct Gather { struct Gather {
void operator()(const DeviceContext& ctx, const Tensor& src, void operator()(const DeviceContext& ctx, const framework::Tensor& src,
const Tensor& index, Tensor* output) { const framework::Tensor& index, framework::Tensor* output) {
CPUGather<T, IndexT>(ctx, src, index, output); phi::funcs::CPUGather<T, IndexT>(ctx, src, index, output);
} }
}; };
template <typename T, typename Functor, typename OutT = T> template <typename T, typename Functor, typename OutT = T>
void SameDimsBinaryOP(const Tensor& lhs, const Tensor& rhs, Tensor* out) { void SameDimsBinaryOP(const framework::Tensor& lhs,
const framework::Tensor& rhs, framework::Tensor* out) {
const T* lhs_ptr = lhs.data<T>(); const T* lhs_ptr = lhs.data<T>();
const T* rhs_ptr = rhs.data<T>(); const T* rhs_ptr = rhs.data<T>();
OutT* out_ptr = out->data<OutT>(); OutT* out_ptr = out->data<OutT>();
...@@ -116,8 +116,9 @@ template <typename DeviceContext, ...@@ -116,8 +116,9 @@ template <typename DeviceContext,
template <typename InT, typename OutT> typename CompareFunctor, template <typename InT, typename OutT> typename CompareFunctor,
typename T> typename T>
struct GetMask { struct GetMask {
void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs, void operator()(const framework::ExecutionContext& ctx,
const Tensor& rhs, Tensor* mask) { const framework::Tensor& lhs, const framework::Tensor& rhs,
framework::Tensor* mask) {
SameDimsBinaryOP<int64_t, CompareFunctor<int64_t, T>, T>(lhs, rhs, mask); SameDimsBinaryOP<int64_t, CompareFunctor<int64_t, T>, T>(lhs, rhs, mask);
} }
}; };
...@@ -161,8 +162,9 @@ struct GetInputIndex<false> { ...@@ -161,8 +162,9 @@ struct GetInputIndex<false> {
}; };
template <typename T, typename Functor, bool is_multi_threads = false> template <typename T, typename Functor, bool is_multi_threads = false>
void SimpleBroadcastBinaryOP(const Tensor& lhs, const Tensor& rhs, void SimpleBroadcastBinaryOP(const framework::Tensor& lhs,
Tensor* out) { const framework::Tensor& rhs,
framework::Tensor* out) {
const T* lhs_ptr = lhs.data<T>(); const T* lhs_ptr = lhs.data<T>();
const T* rhs_ptr = rhs.data<T>(); const T* rhs_ptr = rhs.data<T>();
T* out_ptr = out->data<T>(); T* out_ptr = out->data<T>();
...@@ -200,8 +202,8 @@ void SimpleBroadcastBinaryOP(const Tensor& lhs, const Tensor& rhs, ...@@ -200,8 +202,8 @@ void SimpleBroadcastBinaryOP(const Tensor& lhs, const Tensor& rhs,
template <typename DeviceContext, template <typename T> typename BinaryFunctor, template <typename DeviceContext, template <typename T> typename BinaryFunctor,
typename T> typename T>
struct BinaryOperation { struct BinaryOperation {
void operator()(const DeviceContext& dev_ctx, const Tensor& lhs, void operator()(const DeviceContext& dev_ctx, const framework::Tensor& lhs,
const Tensor& rhs, Tensor* output) { const framework::Tensor& rhs, framework::Tensor* output) {
if (lhs.dims() == rhs.dims()) { if (lhs.dims() == rhs.dims()) {
SameDimsBinaryOP<T, BinaryFunctor<T>>(lhs, rhs, output); SameDimsBinaryOP<T, BinaryFunctor<T>>(lhs, rhs, output);
} else { } else {
...@@ -222,20 +224,21 @@ struct BinaryOperation { ...@@ -222,20 +224,21 @@ struct BinaryOperation {
class TensorBuffer { class TensorBuffer {
public: public:
explicit TensorBuffer(const LoDTensor& in) : buffer_(in), offset_(0) { explicit TensorBuffer(const framework::LoDTensor& in)
: buffer_(in), offset_(0) {
buffer_.Resize({buffer_.numel()}); buffer_.Resize({buffer_.numel()});
} }
Tensor GetBufferBlock(std::initializer_list<int64_t> shape) { framework::Tensor GetBufferBlock(std::initializer_list<int64_t> shape) {
int64_t size = std::accumulate(shape.begin(), shape.end(), 1, int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>()); std::multiplies<int64_t>());
Tensor block = buffer_.Slice(offset_, offset_ + size); framework::Tensor block = buffer_.Slice(offset_, offset_ + size);
offset_ += size; offset_ += size;
block.Resize(shape); block.Resize(shape);
return block; return block;
} }
private: private:
LoDTensor buffer_; // need to resize 1-D Tensor framework::LoDTensor buffer_; // need to resize 1-D Tensor
int offset_; int offset_;
}; };
...@@ -246,17 +249,17 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> { ...@@ -246,17 +249,17 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
bool include_bos_eos_tag = ctx.Attr<bool>("include_bos_eos_tag"); bool include_bos_eos_tag = ctx.Attr<bool>("include_bos_eos_tag");
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto curr_place = ctx.GetPlace(); auto curr_place = ctx.GetPlace();
auto* input = ctx.Input<Tensor>("Input"); auto* input = ctx.Input<framework::Tensor>("Input");
auto batch_size = static_cast<int>(input->dims()[0]); auto batch_size = static_cast<int>(input->dims()[0]);
auto seq_len = static_cast<int>(input->dims()[1]); auto seq_len = static_cast<int>(input->dims()[1]);
auto n_labels = static_cast<int>(input->dims()[2]); auto n_labels = static_cast<int>(input->dims()[2]);
phi::funcs::SetConstant<DeviceContext, T> float_functor; phi::funcs::SetConstant<DeviceContext, T> float_functor;
phi::funcs::SetConstant<DeviceContext, int64_t> int_functor; phi::funcs::SetConstant<DeviceContext, int64_t> int_functor;
std::vector<Tensor> historys; std::vector<framework::Tensor> historys;
// We create tensor buffer in order to avoid allocating memory frequently // We create tensor buffer in order to avoid allocating memory frequently
// 10 means allocate 10*batch_size bytes memory, such as int_mask, zero... // 10 means allocate 10*batch_size bytes memory, such as int_mask, zero...
int buffer_size = batch_size * (n_labels + 1) * seq_len + 10 * batch_size; int buffer_size = batch_size * (n_labels + 1) * seq_len + 10 * batch_size;
LoDTensor int_buffer; framework::LoDTensor int_buffer;
int_buffer.Resize(phi::make_ddim({buffer_size})); int_buffer.Resize(phi::make_ddim({buffer_size}));
int_buffer.mutable_data<int64_t>(ctx.GetPlace()); int_buffer.mutable_data<int64_t>(ctx.GetPlace());
TensorBuffer int_tensor_buffer(int_buffer); TensorBuffer int_tensor_buffer(int_buffer);
...@@ -264,64 +267,78 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> { ...@@ -264,64 +267,78 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
// 10 means allocate 10*batch_size*n_labels bytes, such as alpha, alpha_max // 10 means allocate 10*batch_size*n_labels bytes, such as alpha, alpha_max
buffer_size = batch_size * (seq_len + 10) * n_labels + buffer_size = batch_size * (seq_len + 10) * n_labels +
(batch_size + 2) * n_labels * n_labels; (batch_size + 2) * n_labels * n_labels;
LoDTensor float_buffer; framework::LoDTensor float_buffer;
float_buffer.Resize(phi::make_ddim({buffer_size})); float_buffer.Resize(phi::make_ddim({buffer_size}));
float_buffer.mutable_data<T>(ctx.GetPlace()); float_buffer.mutable_data<T>(ctx.GetPlace());
TensorBuffer float_tensor_buffer(float_buffer); TensorBuffer float_tensor_buffer(float_buffer);
auto* length = ctx.Input<Tensor>("Length"); auto* length = ctx.Input<framework::Tensor>("Length");
Tensor left_length = int_tensor_buffer.GetBufferBlock({batch_size, 1}); framework::Tensor left_length =
int_tensor_buffer.GetBufferBlock({batch_size, 1});
framework::TensorCopy(*length, curr_place, dev_ctx, &left_length); framework::TensorCopy(*length, curr_place, dev_ctx, &left_length);
int64_t max_seq_len = 0; int64_t max_seq_len = 0;
GetMaxValue<DeviceContext, int64_t> get_max_value; GetMaxValue<DeviceContext, int64_t> get_max_value;
get_max_value(dev_ctx, left_length, &max_seq_len); get_max_value(dev_ctx, left_length, &max_seq_len);
auto* scores = ctx.Output<Tensor>("Scores"); auto* scores = ctx.Output<framework::Tensor>("Scores");
scores->mutable_data<T>(curr_place); scores->mutable_data<T>(curr_place);
auto* path = ctx.Output<Tensor>("Path"); auto* path = ctx.Output<framework::Tensor>("Path");
path->Resize({batch_size, max_seq_len}); path->Resize({batch_size, max_seq_len});
path->mutable_data<int64_t>(curr_place); path->mutable_data<int64_t>(curr_place);
Tensor tpath = int_tensor_buffer.GetBufferBlock({max_seq_len, batch_size}); framework::Tensor tpath =
int_tensor_buffer.GetBufferBlock({max_seq_len, batch_size});
auto batch_path = Unbind(tpath); auto batch_path = Unbind(tpath);
for (auto it = batch_path.begin(); it != batch_path.end(); ++it) { for (auto it = batch_path.begin(); it != batch_path.end(); ++it) {
it->Resize({batch_size}); it->Resize({batch_size});
} }
// create and init required tensor // create and init required tensor
Tensor input_exp = framework::Tensor input_exp =
float_tensor_buffer.GetBufferBlock({seq_len, batch_size, n_labels}); float_tensor_buffer.GetBufferBlock({seq_len, batch_size, n_labels});
TransCompute<DeviceContext, T>(3, dev_ctx, *input, &input_exp, {1, 0, 2}); TransCompute<DeviceContext, T>(3, dev_ctx, *input, &input_exp, {1, 0, 2});
auto* transition = ctx.Input<Tensor>("Transition"); auto* transition = ctx.Input<framework::Tensor>("Transition");
Tensor trans_exp = float_tensor_buffer.GetBufferBlock({n_labels, n_labels}); framework::Tensor trans_exp =
float_tensor_buffer.GetBufferBlock({n_labels, n_labels});
framework::TensorCopy(*transition, curr_place, dev_ctx, &trans_exp); framework::TensorCopy(*transition, curr_place, dev_ctx, &trans_exp);
trans_exp.Resize({1, n_labels, n_labels}); trans_exp.Resize({1, n_labels, n_labels});
Tensor alpha = float_tensor_buffer.GetBufferBlock({batch_size, n_labels}); framework::Tensor alpha =
Tensor zero = int_tensor_buffer.GetBufferBlock({batch_size, 1}); float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
framework::Tensor zero = int_tensor_buffer.GetBufferBlock({batch_size, 1});
int_functor(dev_ctx, &zero, 0); int_functor(dev_ctx, &zero, 0);
Tensor one = int_tensor_buffer.GetBufferBlock({batch_size, 1}); framework::Tensor one = int_tensor_buffer.GetBufferBlock({batch_size, 1});
int_functor(dev_ctx, &one, 1); int_functor(dev_ctx, &one, 1);
Tensor float_one = float_tensor_buffer.GetBufferBlock({batch_size, 1}); framework::Tensor float_one =
float_tensor_buffer.GetBufferBlock({batch_size, 1});
float_functor(dev_ctx, &float_one, static_cast<T>(1.0)); float_functor(dev_ctx, &float_one, static_cast<T>(1.0));
Tensor alpha_trn_sum = framework::Tensor alpha_trn_sum =
float_tensor_buffer.GetBufferBlock({batch_size, n_labels, n_labels}); float_tensor_buffer.GetBufferBlock({batch_size, n_labels, n_labels});
Tensor alpha_max = framework::Tensor alpha_max =
float_tensor_buffer.GetBufferBlock({batch_size, n_labels}); float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
Tensor alpha_argmax = framework::Tensor alpha_argmax =
int_tensor_buffer.GetBufferBlock({seq_len, batch_size, n_labels}); int_tensor_buffer.GetBufferBlock({seq_len, batch_size, n_labels});
auto alpha_argmax_unbind = Unbind(alpha_argmax); auto alpha_argmax_unbind = Unbind(alpha_argmax);
Tensor alpha_nxt = framework::Tensor alpha_nxt =
float_tensor_buffer.GetBufferBlock({batch_size, n_labels}); float_tensor_buffer.GetBufferBlock({batch_size, n_labels});
Tensor int_mask = int_tensor_buffer.GetBufferBlock({batch_size}); framework::Tensor int_mask = int_tensor_buffer.GetBufferBlock({batch_size});
Tensor zero_len_mask = int_tensor_buffer.GetBufferBlock({batch_size}); framework::Tensor zero_len_mask =
Tensor float_mask = float_tensor_buffer.GetBufferBlock({batch_size, 1}); int_tensor_buffer.GetBufferBlock({batch_size});
Tensor stop_trans = float_tensor_buffer.GetBufferBlock({1, 1, n_labels}); framework::Tensor float_mask =
Tensor start_trans = float_tensor_buffer.GetBufferBlock({1, 1, n_labels}); float_tensor_buffer.GetBufferBlock({batch_size, 1});
Tensor rest_trans = framework::Tensor stop_trans =
float_tensor_buffer.GetBufferBlock({1, 1, n_labels});
framework::Tensor start_trans =
float_tensor_buffer.GetBufferBlock({1, 1, n_labels});
framework::Tensor rest_trans =
float_tensor_buffer.GetBufferBlock({1, n_labels - 2, n_labels}); float_tensor_buffer.GetBufferBlock({1, n_labels - 2, n_labels});
Tensor last_ids = int_tensor_buffer.GetBufferBlock({batch_size}); framework::Tensor last_ids = int_tensor_buffer.GetBufferBlock({batch_size});
Tensor last_ids_tmp = int_tensor_buffer.GetBufferBlock({batch_size}); framework::Tensor last_ids_tmp =
Tensor batch_offset = int_tensor_buffer.GetBufferBlock({batch_size}); int_tensor_buffer.GetBufferBlock({batch_size});
Tensor gather_idx = int_tensor_buffer.GetBufferBlock({batch_size}); framework::Tensor batch_offset =
std::vector<const Tensor*> shape{&rest_trans, &stop_trans, &start_trans}; int_tensor_buffer.GetBufferBlock({batch_size});
std::vector<Tensor*> outputs{&rest_trans, &stop_trans, &start_trans}; framework::Tensor gather_idx =
int_tensor_buffer.GetBufferBlock({batch_size});
std::vector<const framework::Tensor*> shape{&rest_trans, &stop_trans,
&start_trans};
std::vector<framework::Tensor*> outputs{&rest_trans, &stop_trans,
&start_trans};
math::SplitFunctor<DeviceContext, T> split_functor; math::SplitFunctor<DeviceContext, T> split_functor;
split_functor(dev_ctx, trans_exp, shape, 1, &outputs); split_functor(dev_ctx, trans_exp, shape, 1, &outputs);
stop_trans.Resize({1, n_labels}); stop_trans.Resize({1, n_labels});
...@@ -346,9 +363,9 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> { ...@@ -346,9 +363,9 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
SubInt(dev_ctx, left_length, one, &left_length); SubInt(dev_ctx, left_length, one, &left_length);
Argmax<DeviceContext, T, int64_t> argmax; Argmax<DeviceContext, T, int64_t> argmax;
for (int64_t i = 1; i < max_seq_len; ++i) { for (int64_t i = 1; i < max_seq_len; ++i) {
Tensor logit = input_exp.Slice(i, i + 1); framework::Tensor logit = input_exp.Slice(i, i + 1);
logit.Resize({batch_size, n_labels}); logit.Resize({batch_size, n_labels});
Tensor& alpha_exp = alpha.Resize({batch_size, n_labels, 1}); framework::Tensor& alpha_exp = alpha.Resize({batch_size, n_labels, 1});
AddFloat(dev_ctx, alpha_exp, trans_exp, &alpha_trn_sum); AddFloat(dev_ctx, alpha_exp, trans_exp, &alpha_trn_sum);
auto alpha_argmax_temp = alpha_argmax_unbind[i - 1]; auto alpha_argmax_temp = alpha_argmax_unbind[i - 1];
alpha_argmax_temp.Resize({batch_size, n_labels}); alpha_argmax_temp.Resize({batch_size, n_labels});
...@@ -395,7 +412,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> { ...@@ -395,7 +412,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
++last_ids_index; ++last_ids_index;
AddInt(dev_ctx, left_length, one, &left_length); AddInt(dev_ctx, left_length, one, &left_length);
AddInt(dev_ctx, batch_offset, last_ids, &gather_idx); AddInt(dev_ctx, batch_offset, last_ids, &gather_idx);
Tensor& last_ids_update = batch_path[actual_len - last_ids_index]; framework::Tensor& last_ids_update =
batch_path[actual_len - last_ids_index];
hist->Resize({batch_size * n_labels}); hist->Resize({batch_size * n_labels});
gather(dev_ctx, *hist, gather_idx, &last_ids_update); gather(dev_ctx, *hist, gather_idx, &last_ids_update);
GetMask<DeviceContext, GreaterThanFunctor, int64_t>()(ctx, left_length, GetMask<DeviceContext, GreaterThanFunctor, int64_t>()(ctx, left_length,
......
...@@ -13,24 +13,25 @@ See the License for the specific language governing permissions and ...@@ -13,24 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/framework/tensor.h" // TODO(paddle-dev): move gpu_primitives.h to phi
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/dim.h" #include "paddle/phi/core/utils/dim.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using framework::Tensor; namespace phi {
using platform::DeviceContext; namespace funcs {
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params, const IndexT* indices, __global__ void GatherCUDAKernel(const T* params,
T* output, size_t index_size, const IndexT* indices,
T* output,
size_t index_size,
size_t slice_size) { size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size; int64_t indices_i = i / slice_size;
...@@ -42,9 +43,12 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices, ...@@ -42,9 +43,12 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
} }
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims, __global__ void GatherNdCUDAKernel(const T* input,
const IndexT* indices, T* output, const int64_t* input_dims,
size_t remain_size, size_t slice_size, const IndexT* indices,
T* output,
size_t remain_size,
size_t slice_size,
size_t end_size) { size_t end_size) {
CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) { CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size; int64_t indices_i = i / slice_size;
...@@ -59,7 +63,8 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims, ...@@ -59,7 +63,8 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
"please check whether the dimensions of index and " "please check whether the dimensions of index and "
"input meet the requirements. It should " "input meet the requirements. It should "
"be less than [%d] and greater than or equal to 0, but received [%d]", "be less than [%d] and greater than or equal to 0, but received [%d]",
input_dims[j], index_value); input_dims[j],
index_value);
gather_i += (index_value * temp); gather_i += (index_value * temp);
temp *= input_dims[j]; temp *= input_dims[j];
} }
...@@ -76,13 +81,16 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims, ...@@ -76,13 +81,16 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
* return: output tensor * return: output tensor
*/ */
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, void GPUGather(const phi::GPUContext& ctx,
const Tensor& index, Tensor* output) { const DenseTensor& src,
const DenseTensor& index,
DenseTensor* output) {
if (index.dims().size() == 2) { if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index.dims()[1],
"If the index's rank of gather_op is 2," 1,
" the second dimension should be 1.")); phi::errors::InvalidArgument("If the index's rank of gather_op is 2,"
" the second dimension should be 1."));
} }
// index size // index size
...@@ -90,7 +98,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -90,7 +98,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
if (index_size == 0) return; if (index_size == 0) return;
auto src_dims = src.dims(); auto src_dims = src.dims();
framework::DDim output_dims(src_dims); phi::DDim output_dims(src_dims);
output_dims[0] = index_size; output_dims[0] = index_size;
// slice size // slice size
...@@ -105,18 +113,17 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -105,18 +113,17 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
int64_t n = slice_size * index_size; int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block; int64_t grid = (n + block - 1) / block;
GatherCUDAKernel<T, IndexT><<< GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, index_size, slice_size); p_src, p_index, p_output, index_size, slice_size);
} }
template <typename DeviceContext, typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void GPUGatherNd(const framework::ExecutionContext& context, void GPUGatherNd(const phi::GPUContext& ctx,
const Tensor& input, const Tensor& index, Tensor* output) { const DenseTensor& input,
const auto& ctx = context.template device_context<DeviceContext>(); const DenseTensor& index,
DenseTensor* output) {
const auto gplace = ctx.GetPlace(); const auto gplace = ctx.GetPlace();
auto cplace = platform::CPUPlace(); auto cplace = phi::CPUPlace();
auto index_dims = index.dims(); auto index_dims = index.dims();
auto index_dims_size = index_dims.size(); auto index_dims_size = index_dims.size();
...@@ -143,29 +150,36 @@ void GPUGatherNd(const framework::ExecutionContext& context, ...@@ -143,29 +150,36 @@ void GPUGatherNd(const framework::ExecutionContext& context,
v_input_dims[i] = input_dims[i]; v_input_dims[i] = input_dims[i];
} }
auto& dev_ctx = context.cuda_device_context(); phi::DenseTensor input_dims_tensor;
input_dims_tensor.Resize({input_dims_size});
auto* g_input_dims = ctx.Alloc<int64_t>(&input_dims_tensor);
int64_t bytes = input_dims_size * sizeof(int64_t); int64_t bytes = input_dims_size * sizeof(int64_t);
auto p_input_dims = memory::Alloc(dev_ctx, bytes);
int64_t* g_input_dims = reinterpret_cast<int64_t*>(p_input_dims->ptr()); paddle::memory::Copy(
memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes, gplace, g_input_dims, cplace, v_input_dims.data(), bytes, ctx.stream());
ctx.stream());
int block = 512; int block = 512;
int64_t n = slice_size * remain_numel; int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block; int64_t grid = (n + block - 1) / block;
GatherNdCUDAKernel<T, IndexT><<< GatherNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(p_input,
grid, block, 0, g_input_dims,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>( p_index,
p_input, g_input_dims, p_index, p_output, remain_numel, slice_size, p_output,
end_size); remain_numel,
slice_size,
end_size);
} }
template <typename T, typename U> template <typename T, typename U>
__global__ void GatherGPUKernel(const T* input, const U* index, T* out, __global__ void GatherGPUKernel(const T* input,
int64_t outer_dim_size, int64_t inner_dim_size, const U* index,
T* out,
int64_t outer_dim_size,
int64_t inner_dim_size,
int64_t out_index_dim_size, int64_t out_index_dim_size,
int64_t input_index_dim_size, int64_t size) { int64_t input_index_dim_size,
int64_t size) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
int64_t outer_size = outer_dim_size * out_index_dim_size; int64_t outer_size = outer_dim_size * out_index_dim_size;
for (; idx < size; idx += blockDim.x * gridDim.x) { for (; idx < size; idx += blockDim.x * gridDim.x) {
...@@ -180,7 +194,8 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out, ...@@ -180,7 +194,8 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
"please check whether the dimensions of index and " "please check whether the dimensions of index and "
"input meet the requirements. It should " "input meet the requirements. It should "
"be less than [%d] and greater than or equal to 0, but received [%d]", "be less than [%d] and greater than or equal to 0, but received [%d]",
input_index_dim_size, index_val); input_index_dim_size,
index_val);
int64_t out_dim_index = next_idx - outer_dim_size * index_dim_index; int64_t out_dim_index = next_idx - outer_dim_size * index_dim_index;
int64_t input_index = int64_t input_index =
...@@ -191,11 +206,14 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out, ...@@ -191,11 +206,14 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
} }
template <typename T, typename U> template <typename T, typename U>
__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, __global__ void GatherGradGPUKernel(const T* input,
const U* index,
T* out,
int64_t outer_dim_size, int64_t outer_dim_size,
int64_t inner_dim_size, int64_t inner_dim_size,
int64_t input_index_dim_size, int64_t input_index_dim_size,
int64_t out_index_dim_size, int64_t size) { int64_t out_index_dim_size,
int64_t size) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) { for (; idx < size; idx += blockDim.x * gridDim.x) {
int64_t inner_dim_index = idx / (outer_dim_size * input_index_dim_size); int64_t inner_dim_index = idx / (outer_dim_size * input_index_dim_size);
...@@ -210,10 +228,11 @@ __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, ...@@ -210,10 +228,11 @@ __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
} }
template <typename T, typename U> template <typename T, typename U>
void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, void GatherV2CUDAFunction(const DenseTensor* input,
const int axis, Tensor* out, const DenseTensor* index,
const paddle::platform::Place& place, const int axis,
const framework::ExecutionContext& ctx) { DenseTensor* out,
const phi::GPUContext& ctx) {
int64_t index_size = index->numel(); int64_t index_size = index->numel();
int64_t input_size = input->numel(); int64_t input_size = input->numel();
auto input_dim = input->dims(); auto input_dim = input->dims();
...@@ -241,24 +260,31 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, ...@@ -241,24 +260,31 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
auto out_dim = phi::make_ddim(out_dim_vec); auto out_dim = phi::make_ddim(out_dim_vec);
out->Resize(out_dim); out->Resize(out_dim);
auto* out_data = out->mutable_data<T>(place); auto* out_data = ctx.Alloc<T>(out);
int64_t out_size = out->numel(); int64_t out_size = out->numel();
if (out_size == 0) return; if (out_size == 0) return;
platform::GpuLaunchConfig config = auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size); auto stream = ctx.stream();
auto stream = ctx.cuda_device_context().stream();
GatherGPUKernel< GatherGPUKernel<
T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>( T,
input_data, index_data, out_data, outer_dim_size, inner_dim_size, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
index_size, index_dim_size, out_size); input_data,
index_data,
out_data,
outer_dim_size,
inner_dim_size,
index_size,
index_dim_size,
out_size);
} }
template <typename T, typename U> template <typename T, typename U>
void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, void GatherV2GradCUDAFunction(const DenseTensor* input,
const int axis, Tensor* out, const DenseTensor* index,
const paddle::platform::Place& place, const int axis,
const framework::ExecutionContext& ctx) { DenseTensor* out,
const phi::GPUContext& ctx) {
auto* index_data = index->data<U>(); auto* index_data = index->data<U>();
int64_t index_size = index->numel(); int64_t index_size = index->numel();
int64_t input_size = input->numel(); int64_t input_size = input->numel();
...@@ -279,19 +305,25 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, ...@@ -279,19 +305,25 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
outer_dim_size *= input_dim[i]; outer_dim_size *= input_dim[i];
} }
auto* out_data = out->mutable_data<T>(place); auto* out_data = ctx.Alloc<T>(out);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto out_dim = out->dims(); auto out_dim = out->dims();
int64_t out_index_dim_size = out_dim[axis_index]; int64_t out_index_dim_size = out_dim[axis_index];
phi::funcs::set_constant(*dev_ctx, out, 0.0); phi::funcs::set_constant(ctx, out, 0.0);
platform::GpuLaunchConfig config = auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, input_size);
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_size); auto stream = ctx.stream();
auto stream = ctx.cuda_device_context().stream();
GatherGradGPUKernel< GatherGradGPUKernel<
T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>( T,
input_data, index_data, out_data, outer_dim_size, inner_dim_size, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
input_index_dim_size, out_index_dim_size, input_size); input_data,
index_data,
out_data,
outer_dim_size,
inner_dim_size,
input_index_dim_size,
out_index_dim_size,
input_size);
} }
} // namespace operators
} // namespace paddle } // namespace funcs
} // namespace phi
...@@ -17,16 +17,13 @@ limitations under the License. */ ...@@ -17,16 +17,13 @@ limitations under the License. */
#include <cstring> #include <cstring>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
using framework::Tensor;
/** /**
* A thin wrapper for gathering on cpu tensor * A thin wrapper for gathering on cpu tensor
...@@ -36,22 +33,23 @@ using framework::Tensor; ...@@ -36,22 +33,23 @@ using framework::Tensor;
* return: output tensor * return: output tensor
*/ */
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void CPUGather(const platform::DeviceContext& ctx, const Tensor& src, void CPUGather(const phi::CPUContext& ctx,
const Tensor& index, Tensor* output) { const DenseTensor& src,
PADDLE_ENFORCE_EQ( const DenseTensor& index,
platform::is_cpu_place(ctx.GetPlace()), true, DenseTensor* output) {
platform::errors::PreconditionNotMet("It should be running on the CPU."));
// check index of shape 1-D // check index of shape 1-D
if (index.dims().size() == 2) { if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index.dims()[1], 1, index.dims()[1],
platform::errors::InvalidArgument( 1,
phi::errors::InvalidArgument(
"index.dims()[1] should be 1 when index.dims().size() = 2" "index.dims()[1] should be 1 when index.dims().size() = 2"
"in gather_op, but received value is [%d].", "in gather_op, but received value is [%d].",
index.dims()[1])); index.dims()[1]));
} else { } else {
PADDLE_ENFORCE_EQ(index.dims().size(), 1, PADDLE_ENFORCE_EQ(index.dims().size(),
platform::errors::InvalidArgument( 1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in gather_op," "index.dims().size() should be 1 or 2 in gather_op,"
"but received shape's size is [%d].", "but received shape's size is [%d].",
index.dims().size())); index.dims().size()));
...@@ -74,29 +72,32 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -74,29 +72,32 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
for (int64_t i = 0; i < index_size; ++i) { for (int64_t i = 0; i < index_size; ++i) {
IndexT index_ = p_index[i]; IndexT index_ = p_index[i];
PADDLE_ENFORCE_LT(p_index[i], input_size, PADDLE_ENFORCE_LT(p_index[i],
platform::errors::OutOfRange( input_size,
phi::errors::OutOfRange(
"The element of Index must be less than the size of " "The element of Index must be less than the size of "
"input dim size of axis which is %d, but received " "input dim size of axis which is %d, but received "
"index element which is %d in the %d index.", "index element which is %d in the %d index.",
input_size, p_index[i], i)); input_size,
PADDLE_ENFORCE_GE(p_index[i], 0, p_index[i],
platform::errors::OutOfRange( i));
PADDLE_ENFORCE_GE(p_index[i],
0,
phi::errors::OutOfRange(
"The element of Index must be greater than or equal " "The element of Index must be greater than or equal "
"to 0, but received index element which is %d in the " "to 0, but received index element which is %d in the "
"%d index.", "%d index.",
p_index[i], i)); p_index[i],
i));
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes); memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
} }
} }
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input, void CPUGatherNd(const phi::CPUContext& ctx,
const Tensor& index, Tensor* output) { const DenseTensor& input,
PADDLE_ENFORCE_EQ( const DenseTensor& index,
platform::is_cpu_place(ctx.GetPlace()), true, DenseTensor* output) {
platform::errors::PreconditionNotMet("It should be running on the CPU."));
auto index_dims = index.dims(); auto index_dims = index.dims();
auto index_dims_size = index_dims.size(); auto index_dims_size = index_dims.size();
auto input_dims = input.dims(); auto input_dims = input.dims();
...@@ -124,25 +125,30 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input, ...@@ -124,25 +125,30 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
for (int64_t j = end_size - 1; j >= 0; --j) { for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = p_index[i * end_size + j]; IndexT index_value = p_index[i * end_size + j];
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
index_value, input_dims[j], index_value,
platform::errors::InvalidArgument( input_dims[j],
phi::errors::InvalidArgument(
"Input(index[-1)] has wrong value, it is [%d]", index_value)); "Input(index[-1)] has wrong value, it is [%d]", index_value));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
index_value, 0, index_value,
platform::errors::InvalidArgument( 0,
phi::errors::InvalidArgument(
"The value of Input(index) must be no less than 0")); "The value of Input(index) must be no less than 0"));
index_ += (index_value * temp); index_ += (index_value * temp);
temp *= input_dims[j]; temp *= input_dims[j];
} }
memcpy(p_output + i * slice_size, p_input + index_ * slice_size, memcpy(
slice_bytes); p_output + i * slice_size, p_input + index_ * slice_size, slice_bytes);
} }
} }
template <typename T, typename U> template <typename T, typename U>
void GatherV2Function(const Tensor* input, const Tensor* index, int axis, void GatherV2Function(const phi::CPUContext& ctx,
Tensor* out, const paddle::platform::Place& place) { const DenseTensor* input,
const DenseTensor* index,
int axis,
DenseTensor* out) {
auto* index_data = index->data<U>(); auto* index_data = index->data<U>();
int64_t index_size = index->numel(); int64_t index_size = index->numel();
int64_t input_size = input->numel(); int64_t input_size = input->numel();
...@@ -154,18 +160,23 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis, ...@@ -154,18 +160,23 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
int64_t input_index_dim_size = input_dim[axis_index]; int64_t input_index_dim_size = input_dim[axis_index];
for (int64_t i = 0; i < index_size; i++) { for (int64_t i = 0; i < index_size; i++) {
PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size, PADDLE_ENFORCE_LT(index_data[i],
platform::errors::OutOfRange( input_index_dim_size,
phi::errors::OutOfRange(
"The element of Index must be less than the size of " "The element of Index must be less than the size of "
"input dim size of axis which is %d, but received " "input dim size of axis which is %d, but received "
"index element which is %d in the %d index.", "index element which is %d in the %d index.",
input_index_dim_size, index_data[i], i)); input_index_dim_size,
PADDLE_ENFORCE_GE(index_data[i], 0, index_data[i],
platform::errors::OutOfRange( i));
PADDLE_ENFORCE_GE(index_data[i],
0,
phi::errors::OutOfRange(
"The element of Index must be greater than or equal " "The element of Index must be greater than or equal "
"to 0, but received index element which is %d in the " "to 0, but received index element which is %d in the "
"%d index.", "%d index.",
index_data[i], i)); index_data[i],
i));
} }
int64_t inner_dim_size = 1; int64_t inner_dim_size = 1;
...@@ -184,7 +195,7 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis, ...@@ -184,7 +195,7 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
auto out_dim = phi::make_ddim(out_dim_vec); auto out_dim = phi::make_ddim(out_dim_vec);
out->Resize(out_dim); out->Resize(out_dim);
auto* out_data = out->mutable_data<T>(place); auto* out_data = ctx.Alloc<T>(out);
int out_index = 0; int out_index = 0;
for (int64_t i = 0; i < inner_dim_size; i++) { for (int64_t i = 0; i < inner_dim_size; i++) {
...@@ -200,9 +211,11 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis, ...@@ -200,9 +211,11 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
} }
template <typename T, typename U> template <typename T, typename U>
void GatherV2GradFunction(const Tensor* input, const Tensor* index, void GatherV2GradFunction(const phi::CPUContext& ctx,
const int axis, Tensor* out, const DenseTensor* input,
const paddle::platform::Place& place) { const DenseTensor* index,
const int axis,
DenseTensor* out) {
auto* index_data = index->data<U>(); auto* index_data = index->data<U>();
auto input_dim = input->dims(); auto input_dim = input->dims();
...@@ -222,11 +235,10 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index, ...@@ -222,11 +235,10 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index,
outer_dim_size *= input_dim[i]; outer_dim_size *= input_dim[i];
} }
auto* out_data = out->mutable_data<T>(place); auto* out_data = ctx.Alloc<T>(out);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto out_dim = out->dims(); auto out_dim = out->dims();
int64_t out_index_dim_size = out_dim[axis_index]; int64_t out_index_dim_size = out_dim[axis_index];
phi::funcs::set_constant(*dev_ctx, out, 0.0); phi::funcs::set_constant(ctx, out, 0.0);
for (int64_t i = 0; i < inner_dim_size; i++) { for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < input_index_dim_size; j++) { for (int64_t j = 0; j < input_index_dim_size; j++) {
...@@ -239,5 +251,5 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index, ...@@ -239,5 +251,5 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index,
} }
} }
} // namespace operators } // namespace funcs
} // namespace paddle } // namespace phi
...@@ -15,20 +15,19 @@ limitations under the License. */ ...@@ -15,20 +15,19 @@ limitations under the License. */
#pragma once #pragma once
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
using Tensor = framework::Tensor;
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output, __global__ void ScatterInitCUDAKernel(const IndexT* indices,
size_t index_size, size_t slice_size) { T* output,
size_t index_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size; int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
...@@ -47,9 +46,12 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output, ...@@ -47,9 +46,12 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
} }
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, __global__ void ScatterCUDAKernel(const T* params,
T* output, size_t index_size, const IndexT* indices,
size_t slice_size, bool overwrite) { T* output,
size_t index_size,
size_t slice_size,
bool overwrite) {
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size; int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
...@@ -72,9 +74,12 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, ...@@ -72,9 +74,12 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
} }
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
__global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices, __global__ void ScatterNdCUDAKernel(const T* update,
T* output, const int64_t* output_dims, const IndexT* indices,
size_t remain_size, size_t slice_size, T* output,
const int64_t* output_dims,
size_t remain_size,
size_t slice_size,
size_t end_size) { size_t end_size) {
CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) { CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size; int64_t indices_i = i / slice_size;
...@@ -90,7 +95,8 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices, ...@@ -90,7 +95,8 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
"please check whether the dimensions of index and " "please check whether the dimensions of index and "
"input meet the requirements. It should " "input meet the requirements. It should "
"be less than [%d] and greater or equal to 0, but received [%d]", "be less than [%d] and greater or equal to 0, but received [%d]",
output_dims[j], index_value); output_dims[j],
index_value);
gather_i += (index_value * temp); gather_i += (index_value * temp);
temp *= output_dims[j]; temp *= output_dims[j];
...@@ -109,21 +115,24 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices, ...@@ -109,21 +115,24 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
* return: output tensor * return: output tensor
*/ */
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void GPUScatterAssign(const framework::ExecutionContext& context, void GPUScatterAssign(const phi::GPUContext& ctx,
const Tensor& src, const Tensor& index, Tensor* output, const DenseTensor& src,
const DenseTensor& index,
DenseTensor* output,
bool overwrite = true) { bool overwrite = true) {
// check index of shape 1-D // check index of shape 1-D
const auto& ctx = context.device_context();
if (index.dims().size() == 2) { if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index.dims()[1],
"index.dims()[1] should be 1 when " 1,
"index.dims().size() = 2 in scatter_op." phi::errors::InvalidArgument("index.dims()[1] should be 1 when "
"But received value is [%d]", "index.dims().size() = 2 in scatter_op."
index.dims()[1])); "But received value is [%d]",
index.dims()[1]));
} else { } else {
PADDLE_ENFORCE_EQ(index.dims().size(), 1, PADDLE_ENFORCE_EQ(index.dims().size(),
platform::errors::InvalidArgument( 1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op." "index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]", "But received value is [%d]",
index.dims().size())); index.dims().size()));
...@@ -131,7 +140,7 @@ void GPUScatterAssign(const framework::ExecutionContext& context, ...@@ -131,7 +140,7 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
int64_t index_size = index.dims()[0]; int64_t index_size = index.dims()[0];
auto src_dims = src.dims(); auto src_dims = src.dims();
framework::DDim output_dims(src_dims); phi::DDim output_dims(src_dims);
output_dims[0] = index_size; output_dims[0] = index_size;
// slice size // slice size
...@@ -150,23 +159,20 @@ void GPUScatterAssign(const framework::ExecutionContext& context, ...@@ -150,23 +159,20 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
// if not overwrite mode, init data // if not overwrite mode, init data
if (!overwrite) { if (!overwrite) {
ScatterInitCUDAKernel<T, IndexT><<< ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_index, p_output, index_size, slice_size); p_index, p_output, index_size, slice_size);
} }
ScatterCUDAKernel<T, IndexT><<< ScatterCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, index_size, slice_size, overwrite); p_src, p_index, p_output, index_size, slice_size, overwrite);
} }
// The function is only for scatter grad x, // The function is only for scatter grad x,
// however update grad use gather // however update grad use gather
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void GPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index, void GPUScatterGradForX(const phi::GPUContext& ctx,
Tensor* output) { const DenseTensor& index,
DenseTensor* output) {
int64_t index_size = index.dims()[0]; int64_t index_size = index.dims()[0];
auto dst_dims = output->dims(); auto dst_dims = output->dims();
// slice size // slice size
...@@ -181,21 +187,18 @@ void GPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index, ...@@ -181,21 +187,18 @@ void GPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
int64_t n = slice_size * index_size; int64_t n = slice_size * index_size;
int64_t height = (n + block - 1) / block; int64_t height = (n + block - 1) / block;
int64_t max_grid_dimx = int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
.GetCUDAMaxGridDimSize()[0];
int64_t grid = height < max_grid_dimx ? height : max_grid_dimx; int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;
ScatterInitCUDAKernel<T, IndexT><<< ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_index, p_output, index_size, slice_size); p_index, p_output, index_size, slice_size);
} }
template <typename DeviceContext, typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void GPUScatterNdAdd(const framework::ExecutionContext& context, void GPUScatterNdAdd(const phi::GPUContext& ctx,
const Tensor& update, const Tensor& index, const DenseTensor& update,
Tensor* output) { const DenseTensor& index,
DenseTensor* output) {
auto index_dims = index.dims(); auto index_dims = index.dims();
auto index_dims_size = index_dims.size(); auto index_dims_size = index_dims.size();
...@@ -219,31 +222,34 @@ void GPUScatterNdAdd(const framework::ExecutionContext& context, ...@@ -219,31 +222,34 @@ void GPUScatterNdAdd(const framework::ExecutionContext& context,
const size_t slice_bytes = slice_size * sizeof(T); const size_t slice_bytes = slice_size * sizeof(T);
// put output_dims int CUDA // put output_dims int CUDA
// gplace and cplace // gplace and cplace
const auto& ctx = context.template device_context<DeviceContext>();
const auto gplace = ctx.GetPlace(); const auto gplace = ctx.GetPlace();
auto cplace = platform::CPUPlace(); auto cplace = phi::CPUPlace();
std::vector<int64_t> v_output_dims(output_dims_size); std::vector<int64_t> v_output_dims(output_dims_size);
for (int i = 0; i < output_dims_size; ++i) { for (int i = 0; i < output_dims_size; ++i) {
v_output_dims[i] = output_dims[i]; v_output_dims[i] = output_dims[i];
} }
auto& dev_ctx = context.cuda_device_context();
phi::DenseTensor out_dims_tensor;
out_dims_tensor.Resize({output_dims_size});
auto* g_output_dims = ctx.Alloc<int64_t>(&out_dims_tensor);
int64_t bytes = output_dims_size * sizeof(int64_t); int64_t bytes = output_dims_size * sizeof(int64_t);
auto output_dims_ptr = memory::Alloc(dev_ctx, bytes); paddle::memory::Copy(
int64_t* g_output_dims = reinterpret_cast<int64_t*>(output_dims_ptr->ptr()); gplace, g_output_dims, cplace, v_output_dims.data(), bytes, ctx.stream());
memory::Copy(gplace, g_output_dims, cplace, v_output_dims.data(), bytes,
ctx.stream());
int block = 512; int block = 512;
int64_t n = slice_size * remain_numel; int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block; int64_t grid = (n + block - 1) / block;
ScatterNdCUDAKernel<T, IndexT><<< ScatterNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
grid, block, 0, p_update,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>( p_index,
p_update, p_index, p_output, g_output_dims, remain_numel, slice_size, p_output,
g_output_dims,
remain_numel,
slice_size,
end_size); end_size);
} }
} // namespace operators } // namespace funcs
} // namespace paddle } // namespace pten
...@@ -15,18 +15,16 @@ limitations under the License. */ ...@@ -15,18 +15,16 @@ limitations under the License. */
#pragma once #pragma once
#include <cstring> #include <cstring>
#include <string> #include <string>
#include <unordered_set>
#include "paddle/fluid/framework/eigen.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "unordered_set" #include "paddle/phi/kernels/funcs/eigen/common.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
using Tensor = framework::Tensor;
/** /**
* Return the updated array pointer, use blas or eigen lib to optimize time * Return the updated array pointer, use blas or eigen lib to optimize time
...@@ -34,24 +32,31 @@ using Tensor = framework::Tensor; ...@@ -34,24 +32,31 @@ using Tensor = framework::Tensor;
*/ */
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
typename std::enable_if<std::is_floating_point<T>::value>::type typename std::enable_if<std::is_floating_point<T>::value>::type
elementwise_inner_add(const framework::ExecutionContext& ctx, elementwise_inner_add(const phi::CPUContext& ctx,
const T* src_pointer, T* dst_pointer, size_t src_index, const T* src_pointer,
IndexT dst_index, size_t slice_size) { T* dst_pointer,
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(ctx); size_t src_index,
blas.VADD(slice_size, src_pointer + src_index * slice_size, IndexT dst_index,
size_t slice_size) {
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(ctx);
blas.VADD(slice_size,
src_pointer + src_index * slice_size,
dst_pointer + dst_index * slice_size, dst_pointer + dst_index * slice_size,
dst_pointer + dst_index * slice_size); dst_pointer + dst_index * slice_size);
} }
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
typename std::enable_if<!std::is_floating_point<T>::value>::type typename std::enable_if<!std::is_floating_point<T>::value>::type
elementwise_inner_add(const framework::ExecutionContext& ctx, elementwise_inner_add(const phi::CPUContext& ctx,
const T* src_pointer, T* dst_pointer, size_t src_index, const T* src_pointer,
IndexT dst_index, size_t slice_size) { T* dst_pointer,
using EigenVector = typename framework::EigenTensor<T, 1>::Type; size_t src_index,
using ConstEigenVector = typename framework::EigenTensor<T, 1>::ConstType; IndexT dst_index,
size_t slice_size) {
framework::EigenDim<1>::Type dim; using EigenVector = typename phi::EigenTensor<T, 1>::Type;
using ConstEigenVector = typename phi::EigenTensor<T, 1>::ConstType;
phi::EigenDim<1>::Type dim;
dim[0] = slice_size; dim[0] = slice_size;
ConstEigenVector eigen_src(src_pointer + src_index * slice_size, dim); ConstEigenVector eigen_src(src_pointer + src_index * slice_size, dim);
...@@ -67,22 +72,23 @@ elementwise_inner_add(const framework::ExecutionContext& ctx, ...@@ -67,22 +72,23 @@ elementwise_inner_add(const framework::ExecutionContext& ctx,
* return: output tensor * return: output tensor
*/ */
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, void ScatterAssign(const phi::CPUContext& ctx,
const Tensor& index, Tensor* output) { const DenseTensor& src,
PADDLE_ENFORCE_EQ( const DenseTensor& index,
platform::is_cpu_place(ctx.GetPlace()), true, DenseTensor* output) {
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
// check index of shape 1-D // check index of shape 1-D
if (index.dims().size() == 2) { if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( index.dims()[1],
"index.dims()[1] should be 1 when " 1,
"index.dims().size() =2 in scatter_op." phi::errors::InvalidArgument("index.dims()[1] should be 1 when "
"But received value is [%d]", "index.dims().size() =2 in scatter_op."
index.dims()[1])); "But received value is [%d]",
index.dims()[1]));
} else { } else {
PADDLE_ENFORCE_EQ(index.dims().size(), 1, PADDLE_ENFORCE_EQ(index.dims().size(),
platform::errors::InvalidArgument( 1,
phi::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op." "index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]", "But received value is [%d]",
index.dims().size())); index.dims().size()));
...@@ -99,12 +105,16 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -99,12 +105,16 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
// check src shape and dst shape should match // check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++) for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
src_dims[i], dst_dims[i], src_dims[i],
platform::errors::InvalidArgument( dst_dims[i],
phi::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should" "The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d," " match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.", "target tensor's %d-th dimension is %d.",
i, src_dims[i], i, dst_dims[i])); i,
src_dims[i],
i,
dst_dims[i]));
// slice size // slice size
size_t slice_size = 1; size_t slice_size = 1;
...@@ -115,8 +125,9 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -115,8 +125,9 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
for (int64_t i = 0; i < index_size; ++i) { for (int64_t i = 0; i < index_size; ++i) {
IndexT index_ = p_index[i]; IndexT index_ = p_index[i];
PADDLE_ENFORCE_GE(index_, 0, PADDLE_ENFORCE_GE(index_,
platform::errors::OutOfRange( 0,
phi::errors::OutOfRange(
"The index is out of bounds, " "The index is out of bounds, "
"please check whether the dimensions of index and " "please check whether the dimensions of index and "
"input meet the requirements. It should " "input meet the requirements. It should "
...@@ -128,20 +139,20 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -128,20 +139,20 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
} }
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, void ScatterAssignAdd(const phi::CPUContext& ctx,
const Tensor& index, Tensor* output) { const DenseTensor& src,
PADDLE_ENFORCE_EQ( const DenseTensor& index,
platform::is_cpu_place(ctx.device_context().GetPlace()), true, DenseTensor* output) {
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
// check index of shape 1-D // check index of shape 1-D
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
index.dims().size() == 1 || index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1), (index.dims().size() == 2 && index.dims()[1] == 1),
true, platform::errors::InvalidArgument( true,
"index's shape is error, " phi::errors::InvalidArgument(
"expect index'dims shape is 1 or 2 and index.dims[1] is 1" "index's shape is error, "
"but got index'dims shape is %d", "expect index'dims shape is 1 or 2 and index.dims[1] is 1"
index.dims().size())); "but got index'dims shape is %d",
index.dims().size()));
int64_t index_size = index.dims()[0]; int64_t index_size = index.dims()[0];
auto src_dims = src.dims(); auto src_dims = src.dims();
...@@ -155,12 +166,16 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, ...@@ -155,12 +166,16 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
// check src shape and dst shape should match // check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++) for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
src_dims[i], dst_dims[i], src_dims[i],
platform::errors::InvalidArgument( dst_dims[i],
phi::errors::InvalidArgument(
"The dimensions of the source tensor and target tensor should" "The dimensions of the source tensor and target tensor should"
" match, but received source tensor's %d-th dimension is %d," " match, but received source tensor's %d-th dimension is %d,"
"target tensor's %d-th dimension is %d.", "target tensor's %d-th dimension is %d.",
i, src_dims[i], i, dst_dims[i])); i,
src_dims[i],
i,
dst_dims[i]));
// slice size // slice size
size_t slice_size = 1; size_t slice_size = 1;
...@@ -172,36 +187,40 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, ...@@ -172,36 +187,40 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
auto max_index = dst_dims[0]; auto max_index = dst_dims[0];
for (int64_t i = 0; i < index_size; ++i) { for (int64_t i = 0; i < index_size; ++i) {
const IndexT& index_val = p_index[i]; const IndexT& index_val = p_index[i];
PADDLE_ENFORCE_GE(index_val, 0, PADDLE_ENFORCE_GE(index_val,
platform::errors::OutOfRange( 0,
phi::errors::OutOfRange(
"The index is out of bounds, " "The index is out of bounds, "
"please check whether the dimensions of index and " "please check whether the dimensions of index and "
"input meet the requirements. It should " "input meet the requirements. It should "
"be greater than or equal to 0, but received [%d]", "be greater than or equal to 0, but received [%d]",
index_val)); index_val));
PADDLE_ENFORCE_LT(index_val, max_index, PADDLE_ENFORCE_LT(index_val,
platform::errors::OutOfRange( max_index,
phi::errors::OutOfRange(
"The index is out of bounds, " "The index is out of bounds, "
"please check whether the dimensions of index and " "please check whether the dimensions of index and "
"input meet the requirements. It should " "input meet the requirements. It should "
"be less than %d, but received %d", "be less than %d, but received %d",
max_index, index_val)); max_index,
index_val));
memset(p_output + slice_size * index_val, 0, slice_bytes); memset(p_output + slice_size * index_val, 0, slice_bytes);
} }
// if not in overwrite mode, need to init output data // if not in overwrite mode, need to init output data
for (int64_t i = 0; i < index_size; ++i) { for (int64_t i = 0; i < index_size; ++i) {
const IndexT& index_val = p_index[i]; const IndexT& index_val = p_index[i];
elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, i, index_val, elementwise_inner_add<T, IndexT>(
slice_size); ctx, p_src, p_output, i, index_val, slice_size);
} }
} }
// The function is only for scatter grad x, // The function is only for scatter grad x,
// however update grad use gather // however update grad use gather
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void CPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index, void CPUScatterGradForX(const phi::CPUContext& ctx,
Tensor* output) { const DenseTensor& index,
DenseTensor* output) {
int64_t index_size = index.dims()[0]; int64_t index_size = index.dims()[0];
auto dst_dims = output->dims(); auto dst_dims = output->dims();
const IndexT* p_index = index.data<IndexT>(); const IndexT* p_index = index.data<IndexT>();
...@@ -216,12 +235,10 @@ void CPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index, ...@@ -216,12 +235,10 @@ void CPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
} }
template <typename T, typename IndexT = int> template <typename T, typename IndexT = int>
void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update, void ScatterNdAdd(const phi::CPUContext& ctx,
const Tensor& index, Tensor* output) { const DenseTensor& update,
PADDLE_ENFORCE_EQ( const DenseTensor& index,
platform::is_cpu_place(ctx.device_context().GetPlace()), true, DenseTensor* output) {
platform::errors::PreconditionNotMet("It should be running on the CPU"));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:] // update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
auto index_dims = index.dims(); auto index_dims = index.dims();
auto index_dims_size = index_dims.size(); auto index_dims_size = index_dims.size();
...@@ -250,21 +267,23 @@ void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update, ...@@ -250,21 +267,23 @@ void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
for (int64_t j = end_size - 1; j >= 0; --j) { for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = p_index[i * end_size + j]; IndexT index_value = p_index[i * end_size + j];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(index_value >= 0 && index_value < output_dims[j]), true, (index_value >= 0 && index_value < output_dims[j]),
platform::errors::OutOfRange( true,
phi::errors::OutOfRange(
"The index is out of bounds, " "The index is out of bounds, "
"please check whether the dimensions of index and " "please check whether the dimensions of index and "
"input meet the requirements. It should " "input meet the requirements. It should "
"be less than [%d] and greater or equal to 0, but received [%d]", "be less than [%d] and greater or equal to 0, but received [%d]",
output_dims[j], index_value)); output_dims[j],
index_value));
index_val += (index_value * temp); index_val += (index_value * temp);
temp *= output_dims[j]; temp *= output_dims[j];
} }
elementwise_inner_add<T, IndexT>(ctx, p_update, p_output, i, index_val, elementwise_inner_add<T, IndexT>(
slice_size); ctx, p_update, p_output, i, index_val, slice_size);
} }
} }
} // namespace operators } // namespace funcs
} // namespace paddle } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册