diff --git a/paddle/fluid/operators/detection/bbox_util.cu.h b/paddle/fluid/operators/detection/bbox_util.cu.h
index b361bc3ab75e8ad84bbf2a353230a90e01b99b74..f170fbbe4b534ed5f6bb97508048a72ac766de90 100644
--- a/paddle/fluid/operators/detection/bbox_util.cu.h
+++ b/paddle/fluid/operators/detection/bbox_util.cu.h
@@ -23,7 +23,6 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 namespace cub = hipcub;
 #endif
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
index ce9ac3de4e78c2aa562718719b111c9c47376bc8..860fdd01794ccc9898332f6f0d0ba4e9c3e296d6 100644
--- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
+++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
@@ -23,11 +23,11 @@ namespace cub = hipcub;
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/operators/strided_memcpy.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
 
 namespace paddle {
 namespace operators {
@@ -160,9 +160,9 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
     sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
     Tensor sorted_batch_id;
     sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
-    GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois);
-    GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t,
-                   &sorted_batch_id);
+    phi::funcs::GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois);
+    phi::funcs::GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t,
+                               &sorted_batch_id);
 
     Tensor batch_index_t;
     int* batch_idx_in =
@@ -190,7 +190,7 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
         out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
         sizeof(int) * 8, dev_ctx.stream());
-    GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
+    phi::funcs::GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
 
     Tensor length_lod;
     int* length_lod_data =
diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.h b/paddle/fluid/operators/detection/collect_fpn_proposals_op.h
index a60f881ebf3e3bd825219dce1fb9f377d90c7a94..e5ae9a6ccbda5acbdb37d1190314c94ca4007c07 100644
--- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.h
+++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.h
@@ -21,7 +21,6 @@ limitations under the License.*/
 #include 
 #include 
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -66,7 +65,8 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
     auto multi_layer_scores =
         context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores");
 
-    auto multi_rois_num = context.MultiInput<Tensor>("MultiLevelRoIsNum");
+    auto multi_rois_num =
+        context.MultiInput<framework::Tensor>("MultiLevelRoIsNum");
     int num_size = multi_rois_num.size();
 
     auto* fpn_rois = context.Output<LoDTensor>("FpnRois");
@@ -176,7 +176,7 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
     }
     num_per_batch.emplace_back(post_nms_topN - pre_idx);
     if (context.HasOutput("RoisNum")) {
-      auto* rois_num = context.Output<Tensor>("RoisNum");
+      auto* rois_num = context.Output<framework::Tensor>("RoisNum");
       int* rois_num_data =
           rois_num->mutable_data<int>({batch_size}, context.GetPlace());
       for (int i = 0; i < batch_size; i++) {
diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
index c117fbd70f52827a724c07213cd020d1b58cce22..7ad25e003b491294287a62433b8bf494086a87c2 100644
--- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
+++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
@@ -24,9 +24,9 @@ namespace cub = hipcub;
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -193,7 +193,8 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
         start = end;
         multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
                                            dev_ctx.GetPlace());
-        GPUGather<T>(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]);
+        phi::funcs::GPUGather<T>(dev_ctx, *fpn_rois, sub_idx,
+                                 multi_fpn_rois[i]);
       } else {
         multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
                                            dev_ctx.GetPlace());
diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
index 628cbcd761186bd060fdcbd2b68fe8defec1bf17..5479e08c2a5efa96e64eca45d75af7a6a60a8862 100644
--- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
+++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
@@ -20,7 +20,6 @@ limitations under the License. */
 #include 
 #include 
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -28,10 +27,11 @@ namespace operators {
 
 const int kBoxDim = 4;
 
-inline std::vector<size_t> GetLodFromRoisNum(const Tensor* rois_num) {
+inline std::vector<size_t> GetLodFromRoisNum(
+    const framework::Tensor* rois_num) {
   std::vector<size_t> rois_lod;
   auto* rois_num_data = rois_num->data<int>();
-  Tensor cpu_tensor;
+  framework::Tensor cpu_tensor;
   if (platform::is_gpu_place(rois_num->place())) {
     paddle::framework::TensorCopySync(*rois_num, platform::CPUPlace(),
                                       &cpu_tensor);
@@ -93,7 +93,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
     std::vector<size_t> fpn_rois_lod;
     int fpn_rois_num;
     if (context.HasInput("RoisNum")) {
-      auto* rois_num = context.Input<Tensor>("RoisNum");
+      auto* rois_num = context.Input<framework::Tensor>("RoisNum");
       fpn_rois_lod = GetLodFromRoisNum(rois_num);
     } else {
       fpn_rois_lod = fpn_rois->lod().back();
@@ -105,7 +105,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
     std::vector<int> num_rois_level(num_level, 0);
     std::vector<int> num_rois_level_integral(num_level + 1, 0);
     for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
-      Tensor fpn_rois_slice =
+      auto fpn_rois_slice =
           fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
       const T* rois_data = fpn_rois_slice.data<T>();
       for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
@@ -140,7 +140,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
     std::vector<int> restore_index_inter(fpn_rois_num, -1);
     // distribute the rois into different fpn level by target level
     for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
-      Tensor fpn_rois_slice =
+      auto fpn_rois_slice =
           fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
       const T* rois_data = fpn_rois_slice.data<T>();
       size_t cur_offset = fpn_rois_lod[i];
@@ -163,7 +163,8 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
     for (int i = 0; i < fpn_rois_num; ++i) {
       restore_index_data[restore_index_inter[i]] = i;
     }
-    auto multi_rois_num = context.MultiOutput<Tensor>("MultiLevelRoIsNum");
+    auto multi_rois_num =
+        context.MultiOutput<framework::Tensor>("MultiLevelRoIsNum");
     if (multi_rois_num.size() > 0) {
       int batch_size = fpn_rois_lod.size() - 1;
       for (int i = 0; i < num_level; ++i) {
diff --git a/paddle/fluid/operators/detection/generate_mask_labels_op.cc b/paddle/fluid/operators/detection/generate_mask_labels_op.cc
index e6af1a5bbf71cf24cd355dc09cb439e0bc9fbfba..c9cc4e722071c69f0bf658ad69363dbdd75b63e4 100644
--- a/paddle/fluid/operators/detection/generate_mask_labels_op.cc
+++ b/paddle/fluid/operators/detection/generate_mask_labels_op.cc
@@ -17,7 +17,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/detection/mask_util.h"
-#include "paddle/fluid/operators/gather.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc
index 424aa0714400d3c8a897f98b9209222aa61acef8..cbf17048400bfd967e311897bf8d6d6e11d6000b 100644
--- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc
+++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc
@@ -16,8 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/detection/bbox_util.h"
-#include "paddle/fluid/operators/gather.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
+#include "paddle/phi/kernels/funcs/gather.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -281,22 +281,22 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context,
   Tensor fg_boxes, bg_boxes, fg_labels, bg_labels;
   fg_boxes.mutable_data<T>({fg_num, kBoxDim}, context.GetPlace());
-  CPUGather<T>(context, boxes, fg_inds_t, &fg_boxes);
+  phi::funcs::CPUGather<T>(context, boxes, fg_inds_t, &fg_boxes);
   bg_boxes.mutable_data<T>({bg_num, kBoxDim}, context.GetPlace());
-  CPUGather<T>(context, boxes, bg_inds_t, &bg_boxes);
+  phi::funcs::CPUGather<T>(context, boxes, bg_inds_t, &bg_boxes);
   Concat<T>(context, fg_boxes, bg_boxes, sampled_boxes);
-  CPUGather<T>(context, gt_boxes, gt_box_inds_t, sampled_gts);
+  phi::funcs::CPUGather<T>(context, gt_boxes, gt_box_inds_t, sampled_gts);
   fg_labels.mutable_data<int>({fg_num}, context.GetPlace());
-  CPUGather<int>(context, gt_classes, gt_label_inds_t, &fg_labels);
+  phi::funcs::CPUGather<int>(context, gt_classes, gt_label_inds_t, &fg_labels);
   bg_labels.mutable_data<int>({bg_num}, context.GetPlace());
   phi::funcs::set_constant(context, &bg_labels, 0);
   Concat<int>(context, fg_labels, bg_labels, sampled_labels);
 
   Tensor fg_max_overlap, bg_max_overlap;
   fg_max_overlap.mutable_data<T>({fg_num}, context.GetPlace());
-  CPUGather<T>(context, max_overlap, fg_inds_t, &fg_max_overlap);
+  phi::funcs::CPUGather<T>(context, max_overlap, fg_inds_t, &fg_max_overlap);
   bg_max_overlap.mutable_data<T>({bg_num}, context.GetPlace());
-  CPUGather<T>(context, max_overlap, bg_inds_t, &bg_max_overlap);
+  phi::funcs::CPUGather<T>(context, max_overlap, bg_inds_t, &bg_max_overlap);
   Concat<T>(context, fg_max_overlap, bg_max_overlap, sampled_max_overlap);
 }
 
@@ -334,7 +334,7 @@ std::vector<Tensor> SampleRoisForOneImage(
   } else {
     proposals_num = keep.numel();
     roi_filter.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
-    CPUGather<T>(context, rpn_rois, keep, &roi_filter);
+    phi::funcs::CPUGather<T>(context, rpn_rois, keep, &roi_filter);
   }
   T* roi_filter_dt = roi_filter.data<T>();
   memcpy(rpn_rois_dt, roi_filter_dt, roi_filter.numel() * sizeof(T));
diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc
index 8c4bd4ac61320356073107b7a109e3c27d6b41a1..d6130823271f05c83e590d28b41c3baf73e054f0 100644
--- a/paddle/fluid/operators/detection/generate_proposals_op.cc
+++ b/paddle/fluid/operators/detection/generate_proposals_op.cc
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/detection/nms_util.h"
-#include "paddle/fluid/operators/gather.h"
+#include "paddle/phi/kernels/funcs/gather.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -196,10 +196,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
     var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
 
-    CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
-    CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
-    CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
-    CPUGather<T>(ctx, variances, index_t, &var_sel);
+    phi::funcs::CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
+    phi::funcs::CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
+    phi::funcs::CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
+    phi::funcs::CPUGather<T>(ctx, variances, index_t, &var_sel);
 
     Tensor proposals;
     proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
@@ -223,8 +223,8 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     Tensor scores_filter;
     bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
     scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
-    CPUGather<T>(ctx, proposals, keep, &bbox_sel);
-    CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
+    phi::funcs::CPUGather<T>(ctx, proposals, keep, &bbox_sel);
+    phi::funcs::CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
     if (nms_thresh <= 0) {
       return std::make_pair(bbox_sel, scores_filter);
     }
@@ -237,8 +237,8 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
 
     proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
     scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
-    CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
-    CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
+    phi::funcs::CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
+    phi::funcs::CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
 
     return std::make_pair(proposals, scores_sel);
   }
diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cu b/paddle/fluid/operators/detection/generate_proposals_op.cu
index 6e3c322c1748353d4f447dd6a927e13c4d04025c..5fb7973fd89e49f1cc19458059bffe0dadb9aa3e 100644
--- a/paddle/fluid/operators/detection/generate_proposals_op.cu
+++ b/paddle/fluid/operators/detection/generate_proposals_op.cu
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/detection/bbox_util.cu.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -85,8 +86,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
   }
   proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
   scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
-  GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
-  GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);
+  phi::funcs::GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
+  phi::funcs::GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);
 
   if (nms_thresh <= 0) {
     return std::make_pair(proposals_filter, scores_filter);
@@ -102,8 +103,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
   Tensor scores_nms, proposals_nms;
   proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
   scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
-  GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
-  GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);
+  phi::funcs::GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
+  phi::funcs::GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);
 
   return std::make_pair(proposals_nms, scores_nms);
 }
diff --git a/paddle/fluid/operators/detection/generate_proposals_v2_op.cc b/paddle/fluid/operators/detection/generate_proposals_v2_op.cc
index 6351ea865cd0eb3891f2b4882a587b2feeb6c67a..1f1802574c5b82281b0a7ecc79d9057df61c37e6 100644
--- a/paddle/fluid/operators/detection/generate_proposals_v2_op.cc
+++ b/paddle/fluid/operators/detection/generate_proposals_v2_op.cc
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/detection/nms_util.h"
-#include "paddle/fluid/operators/gather.h"
+#include "paddle/phi/kernels/funcs/gather.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -197,10 +197,10 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> {
     anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
     var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
 
-    CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
-    CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
-    CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
-    CPUGather<T>(ctx, variances, index_t, &var_sel);
+    phi::funcs::CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
+    phi::funcs::CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
+    phi::funcs::CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
+    phi::funcs::CPUGather<T>(ctx, variances, index_t, &var_sel);
 
     Tensor proposals;
     proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
@@ -227,8 +227,8 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> {
     Tensor scores_filter;
     bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
     scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
-    CPUGather<T>(ctx, proposals, keep, &bbox_sel);
-    CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
+    phi::funcs::CPUGather<T>(ctx, proposals, keep, &bbox_sel);
+    phi::funcs::CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
     if (nms_thresh <= 0) {
       return std::make_pair(bbox_sel, scores_filter);
     }
@@ -242,8 +242,8 @@ class GenerateProposalsV2Kernel : public framework::OpKernel<T> {
 
     proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
     scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace())
-    CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
-    CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
+    phi::funcs::CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
+    phi::funcs::CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
 
     return std::make_pair(proposals, scores_sel);
   }
diff --git a/paddle/fluid/operators/detection/generate_proposals_v2_op.cu b/paddle/fluid/operators/detection/generate_proposals_v2_op.cu
index 93ba3deca5fc4f1b0247f90f21936faaaf9c0b43..005309e8ee577119fd295126c40b46a11a762497 100644
--- a/paddle/fluid/operators/detection/generate_proposals_v2_op.cu
+++ b/paddle/fluid/operators/detection/generate_proposals_v2_op.cu
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/detection/bbox_util.cu.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -86,8 +87,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
   }
   proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
   scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
-  GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
-  GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);
+  phi::funcs::GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
+  phi::funcs::GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);
 
   if (nms_thresh <= 0) {
     return std::make_pair(proposals_filter, scores_filter);
@@ -104,8 +105,8 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
   Tensor scores_nms, proposals_nms;
   proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
   scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
-  GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
-  GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);
+  phi::funcs::GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
+  phi::funcs::GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);
 
   return std::make_pair(proposals_nms, scores_nms);
 }
diff --git a/paddle/fluid/operators/gather_nd_op.cu b/paddle/fluid/operators/gather_nd_op.cu
index 0de2798bf750915e99c9b60ed8ccb94d7d1201ab..338c44116183415ab09881c470e6d34283b015ed 100644
--- a/paddle/fluid/operators/gather_nd_op.cu
+++ b/paddle/fluid/operators/gather_nd_op.cu
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/gather_nd_op.h"
-#include "paddle/fluid/operators/scatter.cu.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
+#include "paddle/phi/kernels/funcs/scatter.cu.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename DeviceContext, typename T>
+template <typename T>
 class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -33,27 +33,25 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s], but "
-                          "desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
-    if (index_type == framework::proto::VarType::INT32) {
-      GPUGatherNd<T, int32_t>(ctx, *x, *index, output);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      GPUGatherNd<T, int64_t>(ctx, *x, *index, output);
+    const auto &index_type = index->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "Index holds the wrong type, it holds [%s], but "
+            "desires to be [%s] or [%s].",
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
+    auto &dev_ctx = ctx.cuda_device_context();
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::GPUGatherNd<T, int32_t>(dev_ctx, *x, *index, output);
+    } else if (index_type == phi::DataType::INT64) {
+      phi::funcs::GPUGatherNd<T, int64_t>(dev_ctx, *x, *index, output);
     }
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T>
 class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -71,24 +69,22 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
+    const auto &index_type = index->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
 
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "Index holds the wrong type, it holds [%s],"
+            "but desires to be [%s] or [%s].",
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
 
-    if (index_type == framework::proto::VarType::INT32) {
-      GPUScatterNdAdd<T, int32_t>(ctx, *dO, *index, dX);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      GPUScatterNdAdd<T, int64_t>(ctx, *dO, *index, dX);
+    auto &dev_ctx = ctx.cuda_device_context();
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::GPUScatterNdAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
+    } else if (index_type == phi::DataType::INT64) {
+      phi::funcs::GPUScatterNdAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
     }
   }
 };
@@ -98,18 +94,16 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-using CUDA = paddle::platform::CUDADeviceContext;
-REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
-                        ops::GatherNdOpCUDAKernel<CUDA, double>,
-                        ops::GatherNdOpCUDAKernel<CUDA, int64_t>,
-                        ops::GatherNdOpCUDAKernel<CUDA, int>,
-                        ops::GatherNdOpCUDAKernel<CUDA, int16_t>,
-                        ops::GatherNdOpCUDAKernel<CUDA, bool>,
-                        ops::GatherNdOpCUDAKernel<CUDA, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<float>,
+                        ops::GatherNdOpCUDAKernel<double>,
+                        ops::GatherNdOpCUDAKernel<int64_t>,
+                        ops::GatherNdOpCUDAKernel<int>,
+                        ops::GatherNdOpCUDAKernel<int16_t>,
+                        ops::GatherNdOpCUDAKernel<bool>,
+                        ops::GatherNdOpCUDAKernel<plat::float16>);
 
-REGISTER_OP_CUDA_KERNEL(gather_nd_grad,
-                        ops::GatherNdGradOpCUDAKernel<CUDA, float>,
-                        ops::GatherNdGradOpCUDAKernel<CUDA, double>,
-                        ops::GatherNdGradOpCUDAKernel<CUDA, int64_t>,
-                        ops::GatherNdGradOpCUDAKernel<CUDA, int>,
-                        ops::GatherNdGradOpCUDAKernel<CUDA, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(gather_nd_grad, ops::GatherNdGradOpCUDAKernel<float>,
+                        ops::GatherNdGradOpCUDAKernel<double>,
+                        ops::GatherNdGradOpCUDAKernel<int64_t>,
+                        ops::GatherNdGradOpCUDAKernel<int>,
+                        ops::GatherNdGradOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/gather_nd_op.h b/paddle/fluid/operators/gather_nd_op.h
index f458c0e18013b4d7a85d960e0e7df1b2d21638fe..d54261008e47b89151248a8372ede4b524d999bf 100644
--- a/paddle/fluid/operators/gather_nd_op.h
+++ b/paddle/fluid/operators/gather_nd_op.h
@@ -15,8 +15,8 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
-#include "paddle/fluid/operators/scatter.h"
+#include "paddle/phi/kernels/funcs/gather.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
 
 namespace paddle {
 namespace operators {
@@ -38,22 +38,20 @@ class GatherNdOpKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s]",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
-    if (index_type == framework::proto::VarType::INT32) {
-      CPUGatherNd<T, int32_t>(ctx.device_context(), *x, *index, output);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      CPUGatherNd<T, int64_t>(ctx.device_context(), *x, *index, output);
+    auto index_type = index->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "Index holds the wrong type, it holds [%s],"
+            "but desires to be [%s] or [%s]",
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
+    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::CPUGatherNd<T, int32_t>(dev_ctx, *x, *index, output);
+    } else if (index_type == phi::DataType::INT64) {
+      phi::funcs::CPUGatherNd<T, int64_t>(dev_ctx, *x, *index, output);
     }
   }
 };
@@ -65,6 +63,7 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(
         platform::is_cpu_place(ctx.GetPlace()), true,
         platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
+
     auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -75,22 +74,21 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> {
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s]",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
-    if (index_type == framework::proto::VarType::INT32) {
-      ScatterNdAdd<T, int32_t>(ctx, *dO, *index, dX);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      ScatterNdAdd<T, int64_t>(ctx, *dO, *index, dX);
+    auto index_type = index->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "Index holds the wrong type, it holds [%s],"
+            "but desires to be [%s] or [%s]",
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
+
+    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::ScatterNdAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
+    } else if (index_type == phi::DataType::INT64) {
+      phi::funcs::ScatterNdAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
     }
   }
 };
diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu
index a502a13040949a34e88a4d585327a58ffe92562c..8f1d9284c503813ef3dd9688891048a5bca57b29 100644
--- a/paddle/fluid/operators/gather_op.cu
+++ b/paddle/fluid/operators/gather_op.cu
@@ -14,9 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/gather_op.h"
-#include "paddle/fluid/operators/scatter.cu.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
+#include "paddle/phi/kernels/funcs/scatter.cu.h"
 
 namespace paddle {
 namespace operators {
@@ -49,11 +49,14 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
     }
     const auto &place = ctx.GetPlace();
     const auto &index_type = framework::TransToProtoVarType(index->dtype());
+    const auto &dev_ctx = ctx.cuda_device_context();
     if (axis != 0) {
      if (index_type == framework::proto::VarType::INT32) {
-        GatherV2CUDAFunction<T, int32_t>(x, index, axis, output, place, ctx);
+        phi::funcs::GatherV2CUDAFunction<T, int32_t>(x, index, axis, output,
+                                                     dev_ctx);
       } else if (index_type == framework::proto::VarType::INT64) {
-        GatherV2CUDAFunction<T, int64_t>(x, index, axis, output, place, ctx);
+        phi::funcs::GatherV2CUDAFunction<T, int64_t>(x, index, axis, output,
+                                                     dev_ctx);
       }
       return;
     }
@@ -61,9 +64,9 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
     if (index_type == framework::proto::VarType::INT32) {
-      GPUGather<T, int32_t>(ctx.device_context(), *x, *index, output);
+      phi::funcs::GPUGather<T, int32_t>(dev_ctx, *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
-      GPUGather<T, int64_t>(ctx.device_context(), *x, *index, output);
+      phi::funcs::GPUGather<T, int64_t>(dev_ctx, *x, *index, output);
     }
   }
 };
@@ -93,14 +96,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
       }
     }
 
+    const auto &dev_ctx = ctx.cuda_device_context();
     const auto &index_type = framework::TransToProtoVarType(index->dtype());
     if (axis != 0) {
       if (index_type == framework::proto::VarType::INT32) {
-        GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
-                                             ctx.GetPlace(), ctx);
+        phi::funcs::GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
+                                                         dev_ctx);
       } else if (index_type == framework::proto::VarType::INT64) {
-        GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
-                                             ctx.GetPlace(), ctx);
+        phi::funcs::GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
+                                                         dev_ctx);
       }
       return;
     }
@@ -112,11 +116,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
     if (index_type == framework::proto::VarType::INT32) {
-      GPUScatterAssign<T, int32_t>(ctx, *dO, *index, dX,
-                                   ctx.Attr<bool>("overwrite"));
+      phi::funcs::GPUScatterAssign<T, int32_t>(dev_ctx, *dO, *index, dX,
+                                               ctx.Attr<bool>("overwrite"));
     } else if (index_type == framework::proto::VarType::INT64) {
-      GPUScatterAssign<T, int64_t>(ctx, *dO, *index, dX,
-                                   ctx.Attr<bool>("overwrite"));
+      phi::funcs::GPUScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX,
+                                               ctx.Attr<bool>("overwrite"));
     }
   }
 };
diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h
index 016c2b398daaad92ec60e37606345e0c6c4e13f5..94de694b2f9bc484cdb60298b60d5a9433dac181 100644
--- a/paddle/fluid/operators/gather_op.h
+++ b/paddle/fluid/operators/gather_op.h
@@ -16,8 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
-#include "paddle/fluid/operators/scatter.h"
+#include "paddle/phi/kernels/funcs/gather.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
 
 namespace paddle {
 namespace operators {
@@ -40,31 +40,32 @@ class GatherOpKernel : public framework::OpKernel<T> {
     // get axis from tensor
     if (ctx.HasInput("Axis")) {
       const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
-      const auto &axis_type =
-          framework::TransToProtoVarType(axis_tensor->dtype());
-      if (axis_type == framework::proto::VarType::INT32) {
+      const auto &axis_type = axis_tensor->dtype();
+      if (axis_type == phi::DataType::INT32) {
         axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
-      } else if (axis_type == framework::proto::VarType::INT64) {
+      } else if (axis_type == phi::DataType::INT64) {
         axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
       }
     }
-    const auto &place = ctx.GetPlace();
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
+    const auto &index_type = index->dtype();
+    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
     if (axis != 0) {
-      if (index_type == framework::proto::VarType::INT32) {
-        GatherV2Function<T, int32_t>(x, index, axis, output, place);
-      } else if (index_type == framework::proto::VarType::INT64) {
-        GatherV2Function<T, int64_t>(x, index, axis, output, place);
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::GatherV2Function<T, int32_t>(dev_ctx, x, index, axis,
+                                                 output);
+      } else if (index_type == phi::DataType::INT64) {
+        phi::funcs::GatherV2Function<T, int64_t>(dev_ctx, x, index, axis,
+                                                 output);
      }
      return;
    }
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    if (index_type == framework::proto::VarType::INT32) {
-      CPUGather<T, int32_t>(ctx.device_context(), *x, *index, output);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      CPUGather<T, int64_t>(ctx.device_context(), *x, *index, output);
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::CPUGather<T, int32_t>(dev_ctx, *x, *index, output);
+    } else if (index_type == phi::DataType::INT64) {
+      phi::funcs::CPUGather<T, int64_t>(dev_ctx, *x, *index, output);
     }
   }
 };
@@ -84,44 +85,45 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     int axis = ctx.Attr<int>("axis");
     if (ctx.HasInput("Axis")) {
       const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
-      const auto &axis_type =
-          framework::TransToProtoVarType(axis_tensor->dtype());
-      if (axis_type == framework::proto::VarType::INT32) {
+      const auto &axis_type = axis_tensor->dtype();
+      if (axis_type == phi::DataType::INT32) {
         axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
-      } else if (axis_type == framework::proto::VarType::INT64) {
+      } else if (axis_type == phi::DataType::INT64) {
         axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
       }
     }
 
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
+    const auto &index_type = index->dtype();
+    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
     if (axis != 0) {
-      if (index_type == framework::proto::VarType::INT32) {
-        GatherV2GradFunction<T, int32_t>(dO, index, axis, dX, ctx.GetPlace());
-      } else if (index_type == framework::proto::VarType::INT64) {
-        GatherV2GradFunction<T, int64_t>(dO, index, axis, dX, ctx.GetPlace());
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::GatherV2GradFunction<T, int32_t>(dev_ctx, dO, index, axis,
+                                                     dX);
+      } else if (index_type == phi::DataType::INT64) {
+        phi::funcs::GatherV2GradFunction<T, int64_t>(dev_ctx, dO, index, axis,
+                                                     dX);
      }
      return;
    }
 
     dX->mutable_data<T>(ctx.GetPlace());
     auto dxt = framework::EigenVector<T>::Flatten(*dX);
-    auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
-                       .eigen_device();
+    auto &place = *dev_ctx.eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
     bool overwrite = ctx.Attr<bool>("overwrite");
 
-    if (index_type == framework::proto::VarType::INT32) {
+    if (index_type == phi::DataType::INT32) {
       if (overwrite) {
-        ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
+        phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *dO, *index, dX);
       } else {
-        ScatterAssignAdd<T, int32_t>(ctx, *dO, *index, dX);
+        phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
       }
-    } else if (index_type == framework::proto::VarType::INT64) {
+    } else if (index_type == phi::DataType::INT64) {
       if (overwrite) {
-        ScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
+        phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX);
       } else {
-        ScatterAssignAdd<T, int64_t>(ctx, *dO, *index, dX);
+        phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
       }
     }
   }
diff --git a/paddle/fluid/operators/gather_test.cc b/paddle/fluid/operators/gather_test.cc
index 0f3dcdadcf897dc05d131225cdffe11f84043c14..c962dd065234f37fe98481c9866f7d2f405db69c 100644
--- a/paddle/fluid/operators/gather_test.cc
+++ b/paddle/fluid/operators/gather_test.cc
@@ -15,8 +15,8 @@ limitations under the License. */
 #include <gtest/gtest.h>
 
 #include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/gather.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/phi/kernels/funcs/gather.h"
 
 TEST(Gather, GatherData) {
   paddle::framework::Tensor* src = new paddle::framework::Tensor();
@@ -39,7 +39,7 @@ TEST(Gather, GatherData) {
 
   auto* cpu_place = new paddle::platform::CPUPlace();
   paddle::platform::CPUDeviceContext ctx(*cpu_place);
-  paddle::operators::CPUGather<int>(ctx, *src, *index, output);
+  phi::funcs::CPUGather<int>(ctx, *src, *index, output);
   delete cpu_place;
   cpu_place = NULL;
   for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h
index 8f3c6660f51c4de80e5a98370eae0381abe333a6..93e96694270a458844bbcabf78f2559975902c2f 100644
--- a/paddle/fluid/operators/grid_sampler_op.h
+++ b/paddle/fluid/operators/grid_sampler_op.h
@@ -18,7 +18,6 @@ limitations under the License. */
 #include 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
diff --git a/paddle/fluid/operators/math/segment_pooling.cu b/paddle/fluid/operators/math/segment_pooling.cu
index bb6d8756bd0a35a2243d8a336c171e2cee51d9b5..fbdcb99c02ab97335e595d69b6215cd6a018a33a 100644
--- a/paddle/fluid/operators/math/segment_pooling.cu
+++ b/paddle/fluid/operators/math/segment_pooling.cu
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include 
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/math/segment_pooling.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
@@ -379,9 +379,9 @@ class SegmentPoolGradFunctor {
       SimpleDiv<<>>(mean_grad.data(), summed_ids->data(), len, dim);
-      GPUGather<T>(context, mean_grad, segments, in_grad);
+      phi::funcs::GPUGather<T>(context, mean_grad, segments, in_grad);
     } else if (pooltype == "SUM") {
-      GPUGather<T>(context, out_grad, segments, in_grad);
+      phi::funcs::GPUGather<T>(context, out_grad, segments, in_grad);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
diff --git a/paddle/fluid/operators/scatter_nd_add_op.cu b/paddle/fluid/operators/scatter_nd_add_op.cu
index 6448f8cc4056d2c11806c1c342df57d597e606ba..2fe3fcb759d348b36cd6a7a2609bea210d24705f 100644
--- a/paddle/fluid/operators/scatter_nd_add_op.cu
+++ b/paddle/fluid/operators/scatter_nd_add_op.cu
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/gather_op.h"
-#include "paddle/fluid/operators/scatter.cu.h"
 #include "paddle/fluid/operators/scatter_nd_add_op.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
+#include "paddle/phi/kernels/funcs/scatter.cu.h"
 
 namespace paddle {
 namespace operators {
@@ -33,22 +33,20 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
     auto *Out = ctx.Output<Tensor>("Out");
 
     framework::TensorCopySync(*X, ctx.GetPlace(), Out);
-    const auto &index_type = framework::TransToProtoVarType(Ids->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s], but "
-                          "desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
-    if (index_type == framework::proto::VarType::INT32) {
-      GPUScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
+    const auto &index_type = Ids->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "Index holds the wrong type, it holds [%s], but "
+            "desires to be [%s] or [%s].",
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
+    auto &dev_ctx = ctx.cuda_device_context();
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::GPUScatterNdAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
     } else {
-      GPUScatterNdAdd<T, int64_t>(ctx, *Updates, *Ids, Out);
+      phi::funcs::GPUScatterNdAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
     }
   }
 };
@@ -69,12 +67,13 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
     }
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
+      auto &dev_ctx = ctx.cuda_device_context();
       // Gradient by Gather
-      const auto &index_type = framework::TransToProtoVarType(Ids->dtype());
-      if (index_type == framework::proto::VarType::INT32) {
-        GPUGatherNd<T, int32_t>(ctx, *dOut, *Ids, dUpdates);
+      const auto &index_type = Ids->dtype();
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::GPUGatherNd<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
       } else {
-        GPUGatherNd<T, int64_t>(ctx, *dOut, *Ids, dUpdates);
+        phi::funcs::GPUGatherNd<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
       }
     }
   }
diff --git a/paddle/fluid/operators/scatter_nd_add_op.h b/paddle/fluid/operators/scatter_nd_add_op.h
index 2bdf9ec58a850ea59f7f0697bc5d0eadde0adc99..81c95fe55abaad2e126a52ac7ab97dea24fe67f0 100644
--- a/paddle/fluid/operators/scatter_nd_add_op.h
+++ b/paddle/fluid/operators/scatter_nd_add_op.h
@@ -15,8 +15,8 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
-#include "paddle/fluid/operators/scatter.h"
+#include "paddle/phi/kernels/funcs/gather.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
 
 namespace paddle {
 namespace operators {
@@ -37,23 +37,21 @@ class ScatterNdAddOpKernel : public framework::OpKernel<T> {
     // In place output: Out = X
     framework::TensorCopySync(*X, ctx.GetPlace(), Out);
-    const auto &index_type = framework::TransToProtoVarType(Ids->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s], but "
-                          "desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
+    const auto &index_type = Ids->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "Index holds the wrong type, it holds [%s], but "
+            "desires to be [%s] or [%s].",
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
 
-    if (index_type == framework::proto::VarType::INT32) {
-      ScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
+    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::ScatterNdAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
     } else {
-      ScatterNdAdd<T, int64_t>(ctx, *Updates, *Ids, Out);
+      phi::funcs::ScatterNdAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
     }
   }
 };
@@ -76,11 +74,12 @@ class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      const auto &index_type = framework::TransToProtoVarType(Ids->dtype());
-      if (index_type == framework::proto::VarType::INT32) {
-        CPUGatherNd<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      const auto &index_type = Ids->dtype();
+      auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::CPUGatherNd<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
       } else {
-        CPUGatherNd<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+        phi::funcs::CPUGatherNd<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
       }
     }
   }
diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu
index 549e30803b4647e3e107b0d16147c472c0dcb226..7755e376bc1956a1f9e09dc2eb8aead9fa083157 100644
--- a/paddle/fluid/operators/scatter_op.cu
+++ b/paddle/fluid/operators/scatter_op.cu
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/gather_op.h"
-#include "paddle/fluid/operators/scatter.cu.h"
 #include "paddle/fluid/operators/scatter_op.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
+#include "paddle/phi/kernels/funcs/scatter.cu.h"
 
 namespace paddle {
 namespace operators {
@@ -35,23 +35,22 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
     framework::TensorCopy(*X, ctx.GetPlace(), Out);
     // use template class to support int32_t and int64_t
-    const auto &index_type = framework::TransToProtoVarType(Ids->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
+    auto index_type = Ids->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
     PADDLE_ENFORCE_EQ(
         index_type_match, true,
         platform::errors::InvalidArgument(
             "scatter_op Index holds the wrong type, it holds [%s],"
             "but desires to be [%s] or [%s].",
-            paddle::framework::DataTypeToString(index_type),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT32),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT64)));
-    if (index_type == framework::proto::VarType::INT32) {
-      GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
+    auto &dev_ctx = ctx.cuda_device_context();
+    if (index_type == phi::DataType::INT32) {
+      phi::funcs::GPUScatterAssign<T, int32_t>(dev_ctx, *Updates, *Ids, Out,
+                                               overwrite);
    } else {
-      GPUScatterAssign<T, int64_t>(ctx, *Updates, *Ids, Out, overwrite);
+      phi::funcs::GPUScatterAssign<T, int64_t>(dev_ctx, *Updates, *Ids, Out,
+                                               overwrite);
    }
  }
 };
@@ -68,36 +67,33 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
-    const auto &index_type = framework::TransToProtoVarType(Ids->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
+    auto index_type = Ids->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
     PADDLE_ENFORCE_EQ(
         index_type_match, true,
         platform::errors::InvalidArgument(
             "scatter_op index holds the wrong type, it holds [%s],"
             "but desires to be [%s] or [%s]",
-            paddle::framework::DataTypeToString(index_type),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT32),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT64)));
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
+    auto &dev_ctx = ctx.cuda_device_context();
     if (dX) {
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-      if (index_type == framework::proto::VarType::INT32) {
-        GPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::GPUScatterGradForX<T, int32_t>(dev_ctx, *Ids, dX);
       } else {
-        GPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+        phi::funcs::GPUScatterGradForX<T, int64_t>(dev_ctx, *Ids, dX);
       }
     }
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      if (index_type == framework::proto::VarType::INT32) {
-        GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::GPUGather<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
       } else {
-        GPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+        phi::funcs::GPUGather<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
       }
     }
   }
diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h
index 69ab6c7135cd55468bbe8a4c65d45a466b8eaa75..7733181a93fb60c116ff3da964336b0a85d9a84c 100644
--- a/paddle/fluid/operators/scatter_op.h
+++ b/paddle/fluid/operators/scatter_op.h
@@ -15,8 +15,8 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
-#include "paddle/fluid/operators/scatter.h"
+#include "paddle/phi/kernels/funcs/gather.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
 
 namespace paddle {
 namespace operators {
@@ -39,29 +39,27 @@ class ScatterOpKernel : public framework::OpKernel<T> {
     // In place output: Out = X, Out[Ids] = Updates
     framework::TensorCopy(*X, ctx.GetPlace(), Out);
     // Apply ScatterUpdate: Out[index] = Updates[:]
-    const auto &index_type = framework::TransToProtoVarType(Ids->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
+    const auto &index_type = Ids->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "Index holds the wrong type, it holds [%s],"
+            "but desires to be [%s] or [%s].",
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
+    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
     if (overwrite) {
-      if (index_type == framework::proto::VarType::INT32) {
-        ScatterAssign<T, int32_t>(ctx.device_context(), *Updates, *Ids, Out);
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
       } else {
-        ScatterAssign<T, int64_t>(ctx.device_context(), *Updates, *Ids, Out);
+        phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
       }
     } else {
-      if (index_type == framework::proto::VarType::INT32) {
-        ScatterAssignAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *Updates, *Ids, Out);
       } else {
-        ScatterAssignAdd<T, int64_t>(ctx, *Updates, *Ids, Out);
+        phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *Updates, *Ids, Out);
      }
    }
  }
@@ -79,36 +77,33 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
-    const auto &index_type = framework::TransToProtoVarType(Ids->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
+    const auto &index_type = Ids->dtype();
+    bool index_type_match = index_type == phi::DataType::INT32 ||
+                            index_type == phi::DataType::INT64;
     PADDLE_ENFORCE_EQ(
         index_type_match, true,
         platform::errors::InvalidArgument(
             "scatter_op index holds the wrong type, it holds [%s],"
             "but desires to be [%s] or [%s]",
-            paddle::framework::DataTypeToString(index_type),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT32),
-            paddle::framework::DataTypeToString(
-                framework::proto::VarType::INT64)));
+            index_type, phi::DataType::INT32, phi::DataType::INT64));
+    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
     if (dX) {
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-      if (index_type == framework::proto::VarType::INT32) {
-        CPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::CPUScatterGradForX<T, int32_t>(dev_ctx, *Ids, dX);
       } else {
-        CPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+        phi::funcs::CPUScatterGradForX<T, int64_t>(dev_ctx, *Ids, dX);
      }
    }
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      if (index_type == framework::proto::VarType::INT32) {
-        CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      if (index_type == phi::DataType::INT32) {
+        phi::funcs::CPUGather<T, int32_t>(dev_ctx, *dOut, *Ids, dUpdates);
       } else {
-        CPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+        phi::funcs::CPUGather<T, int64_t>(dev_ctx, *dOut, *Ids, dUpdates);
      }
    }
  }
diff --git a/paddle/fluid/operators/scatter_test.cc b/paddle/fluid/operators/scatter_test.cc
index 0a4cab5fac1abe92b2b2457098d71a7dc3624910..93f2d60e5f232767f8e604ca98e3c39fc00caf8b 100644
--- a/paddle/fluid/operators/scatter_test.cc
+++ b/paddle/fluid/operators/scatter_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/scatter.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
 
 #include <gtest/gtest.h>
 
@@ -43,7 +43,7 @@ TEST(scatter, ScatterUpdate) {
 
   auto* cpu_place = new paddle::platform::CPUPlace();
   paddle::platform::CPUDeviceContext ctx(*cpu_place);
-  paddle::operators::ScatterAssign<float>(ctx, src, index, &output);
+  phi::funcs::ScatterAssign<float>(ctx, src, index, &output);
 
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], 0.0f);
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f);
diff --git a/paddle/fluid/operators/segment_pool_op.cu b/paddle/fluid/operators/segment_pool_op.cu
index 4e20844dc3275f840ff93029abb222e2ef02e0fa..e147e62a98354087121ca1443b20d9163ef00f73 100644
--- a/paddle/fluid/operators/segment_pool_op.cu
+++ b/paddle/fluid/operators/segment_pool_op.cu
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/segment_pool_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
index 2d4730635fd2aeb2e20aa5f4a637f94bce075566..25c12ab565a141f48d254d51bfca64f7422f1f42 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
@@ -16,8 +16,6 @@ limitations under the License. */
 #include 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
-#include "paddle/fluid/operators/scatter.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.h b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.h
index 365381abc4683580b9dffb94ace9876933de495b..2960b77d5ac0f81e4dd026d9de3448cac1459645 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.h
@@ -15,8 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/gather.h"
-#include "paddle/fluid/operators/scatter.h"
+#include "paddle/phi/kernels/funcs/scatter.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/viterbi_decode_op.cu b/paddle/fluid/operators/viterbi_decode_op.cu
index 3c546dd8156a2bdffc9615d171d4630faf3bb7fb..68628fb2748c424996e7f0ae24594ff04649f8d6 100644
--- a/paddle/fluid/operators/viterbi_decode_op.cu
+++ b/paddle/fluid/operators/viterbi_decode_op.cu
@@ -11,8 +11,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/viterbi_decode_op.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
 
 #ifdef __NVCC__
 #include "cub/cub.cuh"
@@ -62,10 +62,11 @@ int64_t ComputeBlockSize(int64_t col) {
 
 template
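
For reference, a minimal usage sketch of the relocated API (not part of the patch; the function name and tensor contents are illustrative). It mirrors the gather_test.cc hunk above: the gather helpers keep their fluid signature of (device context, source, index, output) but now live in phi::funcs.

// Illustrative sketch only, modeled on the migrated gather_test.cc above.
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/kernels/funcs/gather.h"

void GatherRowsExample() {
  paddle::platform::CPUPlace place;
  paddle::framework::Tensor src, index, output;

  // src: a 4x2 tensor holding 0..7 in row-major order.
  int* p_src = src.mutable_data<int>({4, 2}, place);
  for (int i = 0; i < 8; ++i) p_src[i] = i;

  // index: gather rows 1 and 3 into a 2x2 output.
  int* p_index = index.mutable_data<int>({2}, place);
  p_index[0] = 1;
  p_index[1] = 3;
  output.mutable_data<int>({2, 2}, place);

  paddle::platform::CPUDeviceContext ctx(place);
  // Formerly paddle::operators::CPUGather; now provided by phi::funcs.
  phi::funcs::CPUGather<int>(ctx, src, index, &output);
  // output now holds {2, 3, 6, 7}.
}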