From 1b585b2896c05f08a69c6513ba16fd6817739118 Mon Sep 17 00:00:00 2001 From: seemingwang Date: Mon, 28 Feb 2022 22:50:21 +0800 Subject: [PATCH] Move index sample (#39905) * graph engine demo * upload unsaved changes * fix dependency error * fix shard_num problem * py client * remove lock and graph-type * add load direct graph * add load direct graph * add load direct graph * batch random_sample * batch_sample_k * fix num_nodes size * batch brpc * batch brpc * add test * add test * add load_nodes; change add_node function * change sample return type to pair * resolve conflict * resolved conflict * resolved conflict * separate server and client * merge pair type * fix * resolved conflict * fixed segment fault; high-level VLOG for load edges and load nodes * random_sample return 0 * rm useless loop * test:load edge * fix ret -1 * test: rm sample * rm sample * random_sample return future * random_sample return int * test fake node * fixed here * memory leak * remove test code * fix return problem * add common_graph_table * random sample node &test & change data-structure from linkedList to vector * add common_graph_table * sample with srand * add node_types * optimize nodes sample * recover test * random sample * destruct weighted sampler * GraphEdgeBlob * WeightedGraphEdgeBlob to GraphEdgeBlob * WeightedGraphEdgeBlob to GraphEdgeBlob * pybind sample nodes api * pull nodes with step * fixed pull_graph_list bug; add test for pull_graph_list by step * add graph table;name * add graph table;name * add pybind * add pybind * add FeatureNode * add FeatureNode * add FeatureNode Serialize * add FeatureNode Serialize * get_feat_node * avoid local rpc * fix get_node_feat * fix get_node_feat * remove log * get_node_feat return py:bytes * merge develop with graph_engine * fix threadpool.h head * fix * fix typo * resolve conflict * fix conflict * recover lost content * fix pybind of FeatureNode * recover cmake * recover tools * resolve conflict * resolve linking problem * code style * change test_server port * fix code problems * remove shard_num config * remove redundent threads * optimize start server * remove logs * fix code problems by reviewers' suggestions * move graph files into a folder * code style change * remove graph operations from base table * optimize get_feat function of graph engine * fix long long count problem * remove redandunt graph files * remove unused shell * recover dropout_op_pass.h * fix potential stack overflow when request number is too large & node add & node clear & node remove * when sample k is larger than neigbor num, return directly * using random seed generator of paddle to speed up * fix bug of random sample k * fix code style * fix code style * add remove graph to fleet_py.cc * fix blocking_queue problem * fix style * fix * recover capacity check * add remove graph node; add set_feature * add remove graph node; add set_feature * add remove graph node; add set_feature * add remove graph node; add set_feature * fix distributed op combining problems * optimize * remove logs * fix MultiSlotDataGenerator error * cache for graph engine * fix type compare error * more test&fix thread terminating problem * remove header * change time interval of shrink * use cache when sample nodes * remove unused function * change unique_ptr to shared_ptr * simplify cache template * cache api on client * fix * reduce sample threads when cache is not used * reduce cache memory * cache optimization * remove test function * remove extra fetch function * graph-engine data transfer optimization * support graph_split load&query * remove logs * change shards to pointer vector * use inference * remove test code * renorm op * simplify renorm op * recover local changes * recover renorm op kernel * fix init * add blanklines in renorm doc * fix import * fix import * add renorm to init.py * merge * move index_sample op * Delete api.h * Delete api.cc * fix * remove logs * recover infer shape of grad * recover changes * change shape * fix label * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix Co-authored-by: Huang Zhengjie <270018958@qq.com> Co-authored-by: Weiyue Su Co-authored-by: suweiyue Co-authored-by: luobin06 Co-authored-by: liweibin02 Co-authored-by: tangwei12 --- paddle/fluid/operators/index_sample_op.cc | 61 +---- paddle/fluid/operators/index_sample_op.cu | 215 ------------------ paddle/fluid/operators/index_sample_op.h | 198 ---------------- paddle/fluid/operators/index_sample_op_npu.cc | 3 +- paddle/phi/infermeta/binary.cc | 35 +++ paddle/phi/infermeta/binary.h | 5 + .../kernels/cpu/index_sample_grad_kernel.cc | 106 +++++++++ paddle/phi/kernels/cpu/index_sample_kernel.cc | 118 ++++++++++ .../kernels/gpu/index_sample_grad_kernel.cu | 146 ++++++++++++ paddle/phi/kernels/gpu/index_sample_kernel.cu | 119 ++++++++++ paddle/phi/kernels/index_sample_grad_kernel.h | 28 +++ paddle/phi/kernels/index_sample_kernel.h | 27 +++ paddle/phi/ops/compat/index_sample_sig.cc | 30 +++ 13 files changed, 623 insertions(+), 468 deletions(-) delete mode 100644 paddle/fluid/operators/index_sample_op.cu delete mode 100644 paddle/fluid/operators/index_sample_op.h create mode 100644 paddle/phi/kernels/cpu/index_sample_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/index_sample_kernel.cc create mode 100644 paddle/phi/kernels/gpu/index_sample_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/index_sample_kernel.cu create mode 100644 paddle/phi/kernels/index_sample_grad_kernel.h create mode 100644 paddle/phi/kernels/index_sample_kernel.h create mode 100644 paddle/phi/ops/compat/index_sample_sig.cc diff --git a/paddle/fluid/operators/index_sample_op.cc b/paddle/fluid/operators/index_sample_op.cc index 2d97797cfec..68d002fceea 100644 --- a/paddle/fluid/operators/index_sample_op.cc +++ b/paddle/fluid/operators/index_sample_op.cc @@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/index_sample_op.h" #include #include "paddle/fluid/framework/no_need_buffer_vars_inference.h" -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker { @@ -42,44 +44,6 @@ class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker { class IndexSampleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::InvalidArgument( - "Inputs(Input) of FindByIndex should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true, - platform::errors::InvalidArgument( - "Inputs(Index) of FindByIndex should not be null.")); - - auto input_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ( - input_dims.size(), 2, - platform::errors::InvalidArgument( - "Inputs(X) shape of IndexSample op should be 2-D, but " - "got X's shape = [%s], please check X shape.", - input_dims)); - - auto index_dims = ctx->GetInputDim("Index"); - PADDLE_ENFORCE_EQ( - input_dims.size(), 2, - platform::errors::InvalidArgument( - "Inputs(Index) shape of IndexSample op should be 2-D, but " - "got Index's shape [%s] , please check index shape.", - input_dims)); - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(input_dims[0], index_dims[0], - platform::errors::InvalidArgument( - "Inputs(X)'s value of dimension 0 must same with " - "Inputs(Index)'s value of dimension 0, but " - "got %d of Inputs(X), and got %d of Inputs(Index), " - "please check Inputs shape.", - input_dims[0], index_dims[0])); - } - ctx->SetOutputDim("Out", index_dims); - auto type = ctx->GetInputsVarType("Index")[0]; - if (type == framework::proto::VarType::LOD_TENSOR) { - ctx->ShareLoD("Index", /*->*/ "Out"); - } - } protected: framework::OpKernelType GetExpectedKernelType( @@ -136,20 +100,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSampleGradNoNeedBufferVarInferer, "X"); } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(index_sample, IndexSampleInferShapeFunctor, + PT_INFER_META(phi::IndexSampleInferMeta)); REGISTER_OPERATOR(index_sample, ops::IndexSampleOp, ops::IndexSampleOpMaker, ops::IndexSampleGradMaker, - ops::IndexSampleGradMaker); + ops::IndexSampleGradMaker, + IndexSampleInferShapeFunctor); REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp, ops::IndexSampleGradNoNeedBufferVarInferer); -REGISTER_OP_CPU_KERNEL( - index_sample, - ops::IndexSampleKernel, - ops::IndexSampleKernel, - ops::IndexSampleKernel, - ops::IndexSampleKernel); -REGISTER_OP_CPU_KERNEL( - index_sample_grad, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel); diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu deleted file mode 100644 index e8acbfb8be9..00000000000 --- a/paddle/fluid/operators/index_sample_op.cu +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/index_sample_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -#define PREDEFINED_BLOCK_SIZE_X 512 -#define PREDEFINED_BLOCK_SIZE 1024 -#define MIN(a, b) ((a) < (b) ? (a) : (b)) - -namespace paddle { -namespace operators { - -namespace { -void LimitGridDim(const framework::ExecutionContext& ctx, dim3* grid_dim) { - auto max_grid_dim = ctx.template device_context() - .GetCUDAMaxGridDimSize(); - grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; - grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1]; -} -} - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -template -__global__ void IndexSampleForward(const IndexT* index, const T* in_data, - T* out_data, size_t index_length, - size_t input_length, size_t batch_size) { - unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; - unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; - for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { - index_i = blockDim.x * blockIdx.x + threadIdx.x; - for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { - unsigned int index_idx = index_j * index_length + index_i; - unsigned int in_idx = index_j * input_length + index_i; - IndexT sample_idx = index[index_idx]; - out_data[index_idx] = in_data[in_idx - index_i + sample_idx]; - } - } -} - -template -__global__ void IndexSampleGrad(const IndexT* index, T* in_grad, - const T* out_grad, size_t index_length, - size_t input_length, size_t batch_size, - bool same_data_in_row = true) { - unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; - unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; - - for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { - index_i = blockDim.x * blockIdx.x + threadIdx.x; - for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { - unsigned int index_idx = index_j * index_length + index_i; - unsigned int in_idx = index_j * input_length + index_i; - IndexT sample_idx = index[index_idx]; - if (same_data_in_row) { - platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]), - out_grad[sample_idx]); - } else { - in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx]; - } - } - } -} - -template -class IndexSampleKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* index = ctx.Input("Index"); - auto* output = ctx.Output("Out"); - - const auto& index_type = framework::TransToProtoVarType(index->dtype()); - bool index_type_match = index_type == framework::proto::VarType::INT64 || - index_type == framework::proto::VarType::INT32; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Input(Index) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - const auto* in_data = input->data(); - auto* out_data = output->mutable_data(ctx.GetPlace()); - auto stream = - ctx.template device_context().stream(); - - auto input_dim = input->dims(); - auto index_dim = index->dims(); - size_t batch_size = input_dim[0]; - size_t input_length = input_dim[1]; - size_t index_length = index_dim[1]; - - auto block_width = platform::RoundToPowerOfTwo(index_length); - block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X); - int block_height = - platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; - block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width); - dim3 block_dim(block_width, block_height); - dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, - (batch_size + block_dim.y - 1) / block_dim.y); - LimitGridDim(ctx, &grid_dim); - - if (index_type == framework::proto::VarType::INT64) { - const int64_t* index_data = index->data(); - IndexSampleForward<<>>( - index_data, in_data, out_data, index_length, input_length, - batch_size); - } else if (index_type == framework::proto::VarType::INT32) { - const int* index_data = index->data(); - IndexSampleForward<<>>( - index_data, in_data, out_data, index_length, input_length, - batch_size); - } - } -}; - -template -class IndexSampleGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* output_grad = ctx.Input(framework::GradVarName("Out")); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* index = ctx.Input("Index"); - - const auto* output_grad_data = output_grad->data(); - auto* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - - const auto& index_type = framework::TransToProtoVarType(index->dtype()); - bool index_type_match = index_type == framework::proto::VarType::INT64 || - index_type == framework::proto::VarType::INT32; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Input(Index) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - - auto stream = - ctx.template device_context().stream(); - auto input_num = input_grad->numel(); - auto input_dim = input_grad->dims(); - auto index_dim = index->dims(); - size_t batch_size = index_dim[0]; - size_t input_length = input_dim[1]; - size_t index_length = index_dim[1]; - bool same_data_in_index_row = index_length == 1 ? false : true; - - auto block_width = platform::RoundToPowerOfTwo(index_length); - block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X); - auto block_height = - platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; - block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width); - dim3 block_dim(block_width, block_height); - dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, - (batch_size + block_dim.y - 1) / block_dim.y); - LimitGridDim(ctx, &grid_dim); - - phi::funcs::SetConstant set_zero; - auto& dev_ctx = ctx.template device_context(); - set_zero(dev_ctx, input_grad, static_cast(0)); - - if (index_type == framework::proto::VarType::INT64) { - const int64_t* index_data = index->data(); - IndexSampleGrad<<>>( - index_data, input_grad_data, output_grad_data, index_length, - input_length, batch_size, same_data_in_index_row); - } else if (index_type == framework::proto::VarType::INT32) { - const int* index_data = index->data(); - IndexSampleGrad<<>>( - index_data, input_grad_data, output_grad_data, index_length, - input_length, batch_size, same_data_in_index_row); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - index_sample, - ops::IndexSampleKernel, - ops::IndexSampleKernel, - ops::IndexSampleKernel, - ops::IndexSampleKernel); -REGISTER_OP_CUDA_KERNEL( - index_sample_grad, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel); diff --git a/paddle/fluid/operators/index_sample_op.h b/paddle/fluid/operators/index_sample_op.h deleted file mode 100644 index 6cc8ff04c54..00000000000 --- a/paddle/fluid/operators/index_sample_op.h +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "gflags/gflags.h" -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using DDim = framework::DDim; - -template -void IndexSampleInner(const framework::ExecutionContext &context, - const LoDTensor &input, const LoDTensor &index, - LoDTensor *output) { - auto input_dims = input.dims(); - auto index_dims = index.dims(); - - int batch_size = input_dims[0]; - auto value_length = input_dims[1]; - auto index_length = index_dims[1]; - int index_ids_num = index.numel(); - - std::vector input_vec; - std::vector index_vec; - paddle::framework::TensorToVector(input, context.device_context(), - &input_vec); - paddle::framework::TensorToVector(index, context.device_context(), - &index_vec); - - std::vector res(index_ids_num); - for (int i = 0; i < index_ids_num; i++) { - int b = floor(i / index_length); - PADDLE_ENFORCE_GE( - index_vec[i], 0, - platform::errors::InvalidArgument( - "Variable value (index) of OP(index_sample) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - value_length, index_vec[i])); - PADDLE_ENFORCE_LT( - index_vec[i], value_length, - platform::errors::InvalidArgument( - "Variable value (index) of OP(index_sample) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - value_length, index_vec[i])); - - int v_i = b * value_length + static_cast(index_vec[i]); - T v = input_vec[v_i]; - VLOG(4) << "Index Sample: batch = " << b << " index = " << v_i - << " value = " << v; - res[i] = v; - } - - auto ddim = phi::make_ddim({batch_size, index_length}); - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(res, context.device_context(), output); - output->Resize(ddim); -} - -template -class IndexSampleKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *input_var = ctx.InputVar("X"); - auto *index_var = ctx.InputVar("Index"); - - auto &input_tensor = input_var->Get(); - auto &index_tensor = index_var->Get(); - - auto *out_var = ctx.OutputVar("Out"); - auto *out_tensor = out_var->GetMutable(); - - const auto &index_type = - framework::TransToProtoVarType(index_tensor.dtype()); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Input(Index) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - if (index_type == framework::proto::VarType::INT32) { - IndexSampleInner(ctx, input_tensor, index_tensor, out_tensor); - } else if (index_type == framework::proto::VarType::INT64) { - IndexSampleInner(ctx, input_tensor, index_tensor, out_tensor); - } - } -}; - -template -void IndexSampleGradInner(const framework::ExecutionContext &context, - const LoDTensor &out_grad, const LoDTensor &index, - LoDTensor *x_grad) { - std::vector out_grad_vec; - std::vector index_vec; - paddle::framework::TensorToVector(out_grad, context.device_context(), - &out_grad_vec); - paddle::framework::TensorToVector(index, context.device_context(), - &index_vec); - - auto index_dims = index.dims(); - auto x_grad_dims = x_grad->dims(); - - auto value_length = x_grad_dims[1]; - auto index_length = index_dims[1]; - int index_ids_num = index.numel(); - - std::vector x_grad_vec(x_grad->numel(), 0); - - for (int i = 0; i < index_ids_num; i++) { - int b = floor(i / index_length); - PADDLE_ENFORCE_GE( - index_vec[i], 0, - platform::errors::InvalidArgument( - "Variable value (index) of OP(index_sample_grad) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - value_length, index_vec[i])); - PADDLE_ENFORCE_LT( - index_vec[i], value_length, - platform::errors::InvalidArgument( - "Variable value (index) of OP(index_sample_grad) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - value_length, index_vec[i])); - int v_i = b * value_length + static_cast(index_vec[i]); - x_grad_vec[v_i] += out_grad_vec[i]; - } - x_grad->mutable_data(context.GetPlace()); - framework::TensorFromVector(x_grad_vec, context.device_context(), x_grad); - x_grad->Resize(x_grad_dims); -} - -template -class IndexSampleGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *index_var = context.InputVar("Index"); - auto *x_grad_var = context.OutputVar(framework::GradVarName("X")); - auto *out_grad_var = context.InputVar(framework::GradVarName("Out")); - - auto &index_tensor = index_var->Get(); - auto &out_grad_tensor = out_grad_var->Get(); - auto *x_grad_tensor = x_grad_var->GetMutable(); - - const auto &index_type = - framework::TransToProtoVarType(index_tensor.dtype()); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Input(Index) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); - if (index_type == framework::proto::VarType::INT32) { - IndexSampleGradInner(context, out_grad_tensor, index_tensor, - x_grad_tensor); - } else if (index_type == framework::proto::VarType::INT64) { - IndexSampleGradInner(context, out_grad_tensor, index_tensor, - x_grad_tensor); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/index_sample_op_npu.cc b/paddle/fluid/operators/index_sample_op_npu.cc index f460d0622bc..38eb5b45149 100644 --- a/paddle/fluid/operators/index_sample_op_npu.cc +++ b/paddle/fluid/operators/index_sample_op_npu.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/index_sample_op.h" - +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index dfaabf7cae2..1905e33bd03 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -225,6 +225,41 @@ void HuberLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void IndexSampleInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out, + MetaConfig config) { + auto input_dims = x.dims(); + PADDLE_ENFORCE_EQ(input_dims.size(), + 2, + errors::InvalidArgument( + "Inputs(X) shape of IndexSample op should be 2-D, but " + "got X's shape = [%s], please check X shape.", + input_dims)); + + auto index_dims = y.dims(); + PADDLE_ENFORCE_EQ( + index_dims.size(), + 2, + errors::InvalidArgument( + "Inputs(Index) shape of IndexSample op should be 2-D, but " + "got Index's shape [%s] , please check index shape.", + input_dims)); + if (config.is_runtime) { + PADDLE_ENFORCE_EQ(input_dims[0], + index_dims[0], + errors::InvalidArgument( + "Inputs(X)'s value of dimension 0 must same with " + "Inputs(Index)'s value of dimension 0, but " + "got %d of Inputs(X), and got %d of Inputs(Index), " + "please check Inputs shape.", + input_dims[0], + index_dims[0])); + } + out->set_dtype(x.dtype()); + out->set_dims(index_dims); + out->share_lod(y); +} void CrossInferMeta(const MetaTensor& x, const MetaTensor& y, int axis, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 02750482dcc..a0140c9a579 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -53,6 +53,11 @@ void HuberLossInferMeta(const MetaTensor& input_meta, MetaTensor* residual, MetaConfig config = MetaConfig()); +void IndexSampleInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void CrossInferMeta(const MetaTensor& x, const MetaTensor& y, int axis, diff --git a/paddle/phi/kernels/cpu/index_sample_grad_kernel.cc b/paddle/phi/kernels/cpu/index_sample_grad_kernel.cc new file mode 100644 index 00000000000..006711ceef7 --- /dev/null +++ b/paddle/phi/kernels/cpu/index_sample_grad_kernel.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_sample_grad_kernel.h" +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +namespace phi { +template +void IndexSampleGradInner(const Context& context, + const DenseTensor& out_grad, + const DenseTensor& index, + DenseTensor* x_grad) { + std::vector out_grad_vec; + std::vector index_vec; + paddle::framework::TensorToVector(out_grad, context, &out_grad_vec); + paddle::framework::TensorToVector(index, context, &index_vec); + + auto index_dims = index.dims(); + auto x_grad_dims = x_grad->dims(); + + auto value_length = x_grad_dims[1]; + auto index_length = index_dims[1]; + int index_ids_num = index.numel(); + + std::vector x_grad_vec(x_grad->numel(), 0); + + for (int i = 0; i < index_ids_num; i++) { + int b = floor(i / index_length); + PADDLE_ENFORCE_GE( + index_vec[i], + 0, + errors::InvalidArgument( + "Variable value (index) of OP(index_sample_grad) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + value_length, + index_vec[i])); + PADDLE_ENFORCE_LT( + index_vec[i], + value_length, + errors::InvalidArgument( + "Variable value (index) of OP(index_sample_grad) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + value_length, + index_vec[i])); + int v_i = b * value_length + static_cast(index_vec[i]); + x_grad_vec[v_i] += out_grad_vec[i]; + } + context.template Alloc(x_grad); + paddle::framework::TensorFromVector(x_grad_vec, context, x_grad); + x_grad->Resize(x_grad_dims); +} + +template +void IndexSampleGradKernel(const Context& ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& index, + DenseTensor* x_grad) { + auto index_type = index.dtype(); + bool index_type_match = + index_type == DataType::INT32 || index_type == DataType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, + true, + errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType(index_type)), + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType(DataType::INT32)), + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType((DataType::INT64))))); + if (index_type == DataType::INT32) { + IndexSampleGradInner(ctx, out_grad, index, x_grad); + } else if (index_type == DataType::INT64) { + IndexSampleGradInner(ctx, out_grad, index, x_grad); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(index_sample_grad, + CPU, + ALL_LAYOUT, + phi::IndexSampleGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/index_sample_kernel.cc b/paddle/phi/kernels/cpu/index_sample_kernel.cc new file mode 100644 index 00000000000..21bf9faee13 --- /dev/null +++ b/paddle/phi/kernels/cpu/index_sample_kernel.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_sample_kernel.h" +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +namespace phi { +template +void IndexSampleInner(const Context &context, + const DenseTensor &input, + const DenseTensor &index, + DenseTensor *output) { + auto input_dims = input.dims(); + auto index_dims = index.dims(); + + int batch_size = input_dims[0]; + auto value_length = input_dims[1]; + auto index_length = index_dims[1]; + int index_ids_num = index.numel(); + + std::vector input_vec; + std::vector index_vec; + paddle::framework::TensorToVector(input, context, &input_vec); + paddle::framework::TensorToVector(index, context, &index_vec); + + std::vector res(index_ids_num); + for (int i = 0; i < index_ids_num; i++) { + int b = floor(i / index_length); + PADDLE_ENFORCE_GE( + index_vec[i], + 0, + errors::InvalidArgument( + "Variable value (index) of OP(index_sample) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + value_length, + index_vec[i])); + PADDLE_ENFORCE_LT( + index_vec[i], + value_length, + errors::InvalidArgument( + "Variable value (index) of OP(index_sample) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + value_length, + index_vec[i])); + + int v_i = b * value_length + static_cast(index_vec[i]); + T v = input_vec[v_i]; + VLOG(4) << "Index Sample: batch = " << b << " index = " << v_i + << " value = " << v; + res[i] = v; + } + + auto ddim = phi::make_ddim({batch_size, index_length}); + context.template Alloc(output); + paddle::framework::TensorFromVector(res, context, output); + output->Resize(ddim); +} + +template +void IndexSampleKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + DenseTensor *out) { + ctx.template Alloc(out); + auto index_type = index.dtype(); + bool index_type_match = + index_type == DataType::INT32 || index_type == DataType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, + true, + errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType(index_type)), + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType(DataType::INT32)), + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType((DataType::INT64))))); + if (index_type == DataType::INT32) { + IndexSampleInner(ctx, x, index, out); + } else if (index_type == DataType::INT64) { + IndexSampleInner(ctx, x, index, out); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(index_sample, + CPU, + ALL_LAYOUT, + phi::IndexSampleKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu b/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu new file mode 100644 index 00000000000..8b1ef964124 --- /dev/null +++ b/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu @@ -0,0 +1,146 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_sample_grad_kernel.h" + +#include +#include +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +namespace { +template +void LimitGridDim(const Context& ctx, dim3* grid_dim) { + auto max_grid_dim = + reinterpret_cast(ctx).GetCUDAMaxGridDimSize(); + grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; + grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1]; +} +#define PREDEFINED_BLOCK_SIZE_X 512 +#define PREDEFINED_BLOCK_SIZE 1024 +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +}; + +template +__global__ void IndexSampleGrad(const IndexT* index, + T* in_grad, + const T* out_grad, + size_t index_length, + size_t input_length, + size_t batch_size, + bool same_data_in_row = true) { + unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; + unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; + + for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { + index_i = blockDim.x * blockIdx.x + threadIdx.x; + for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { + unsigned int index_idx = index_j * index_length + index_i; + unsigned int in_idx = index_j * input_length + index_i; + IndexT sample_idx = index[index_idx]; + if (same_data_in_row) { + paddle::platform::CudaAtomicAdd( + &(in_grad[in_idx - index_i + sample_idx]), out_grad[sample_idx]); + } else { + in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx]; + } + } + } +} + +template +void IndexSampleGradKernel(const Context& ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& index, + DenseTensor* x_grad) { + const T* output_grad_data = out_grad.data(); + T* input_grad_data = ctx.template Alloc(x_grad); + auto index_type = index.dtype(); + bool index_type_match = + index_type == DataType::INT32 || index_type == DataType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, + true, + errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType(index_type)), + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType(DataType::INT32)), + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType((DataType::INT64))))); + + auto stream = reinterpret_cast(ctx).stream(); + auto input_num = x.numel(); + auto input_dim = x.dims(); + auto index_dim = index.dims(); + size_t batch_size = index_dim[0]; + size_t input_length = input_dim[1]; + size_t index_length = index_dim[1]; + bool same_data_in_index_row = index_length == 1 ? false : true; + + auto block_width = paddle::platform::RoundToPowerOfTwo(index_length); + block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X); + auto block_height = + paddle::platform::RoundToPowerOfTwo(index_length * batch_size) / + block_width; + block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width); + dim3 block_dim(block_width, block_height); + dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, + (batch_size + block_dim.y - 1) / block_dim.y); + LimitGridDim(ctx, &grid_dim); + + phi::funcs::SetConstant set_zero; + set_zero(ctx, x_grad, static_cast(0)); + + if (index_type == DataType::INT64) { + const int64_t* index_data = index.data(); + IndexSampleGrad<<>>( + index_data, + input_grad_data, + output_grad_data, + index_length, + input_length, + batch_size, + same_data_in_index_row); + } else if (index_type == DataType::INT32) { + const int* index_data = index.data(); + IndexSampleGrad<<>>( + index_data, + input_grad_data, + output_grad_data, + index_length, + input_length, + batch_size, + same_data_in_index_row); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(index_sample_grad, + GPU, + ALL_LAYOUT, + phi::IndexSampleGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/index_sample_kernel.cu b/paddle/phi/kernels/gpu/index_sample_kernel.cu new file mode 100644 index 00000000000..0e042089e1e --- /dev/null +++ b/paddle/phi/kernels/gpu/index_sample_kernel.cu @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_sample_kernel.h" + +#include +#include +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +namespace { +template +void LimitGridDim(const Context& ctx, dim3* grid_dim) { + auto max_grid_dim = + reinterpret_cast(ctx).GetCUDAMaxGridDimSize(); + grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; + grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1]; +} +#define PREDEFINED_BLOCK_SIZE_X 512 +#define PREDEFINED_BLOCK_SIZE 1024 +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +} + +template +__global__ void IndexSampleForward(const IndexT* index, + const T* in_data, + T* out_data, + size_t index_length, + size_t input_length, + size_t batch_size) { + unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; + unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; + for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { + index_i = blockDim.x * blockIdx.x + threadIdx.x; + for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { + unsigned int index_idx = index_j * index_length + index_i; + unsigned int in_idx = index_j * input_length + index_i; + IndexT sample_idx = index[index_idx]; + out_data[index_idx] = in_data[in_idx - index_i + sample_idx]; + } + } +} + +template +void IndexSampleKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& index, + DenseTensor* out) { + auto index_type = index.dtype(); + bool index_type_match = + index_type == DataType::INT32 || index_type == DataType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, + true, + errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType(index_type)), + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType(DataType::INT32)), + paddle::framework::DataTypeToString( + paddle::framework::TransToProtoVarType((DataType::INT64))))); + const T* in_data = x.data(); + T* out_data = ctx.template Alloc(out); + auto stream = reinterpret_cast(ctx).stream(); + auto input_dim = x.dims(); + auto index_dim = index.dims(); + size_t batch_size = input_dim[0]; + size_t input_length = input_dim[1]; + size_t index_length = index_dim[1]; + + auto block_width = paddle::platform::RoundToPowerOfTwo(index_length); + block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X); + int block_height = + paddle::platform::RoundToPowerOfTwo(index_length * batch_size) / + block_width; + block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width); + dim3 block_dim(block_width, block_height); + dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, + (batch_size + block_dim.y - 1) / block_dim.y); + LimitGridDim(ctx, &grid_dim); + + if (index_type == DataType::INT64) { + const int64_t* index_data = index.data(); + IndexSampleForward<<>>( + index_data, in_data, out_data, index_length, input_length, batch_size); + } else if (index_type == DataType::INT32) { + const int* index_data = index.data(); + IndexSampleForward<<>>( + index_data, in_data, out_data, index_length, input_length, batch_size); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(index_sample, + GPU, + ALL_LAYOUT, + phi::IndexSampleKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/index_sample_grad_kernel.h b/paddle/phi/kernels/index_sample_grad_kernel.h new file mode 100644 index 00000000000..5c6e101f1b4 --- /dev/null +++ b/paddle/phi/kernels/index_sample_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void IndexSampleGradKernel(const Context& ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& index, + DenseTensor* in_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/index_sample_kernel.h b/paddle/phi/kernels/index_sample_kernel.h new file mode 100644 index 00000000000..fb43c0c6c5f --- /dev/null +++ b/paddle/phi/kernels/index_sample_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void IndexSampleKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& index, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/index_sample_sig.cc b/paddle/phi/ops/compat/index_sample_sig.cc new file mode 100644 index 00000000000..0d2aed68a72 --- /dev/null +++ b/paddle/phi/ops/compat/index_sample_sig.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature IndexSampleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("index_sample_grad", + {GradVarName("Out"), "X", "Index"}, + {}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(index_sample_grad, + phi::IndexSampleGradOpArgumentMapping); -- GitLab