未验证 提交 1b585b28 编写于 作者: S seemingwang 提交者: GitHub

Move index sample (#39905)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundent threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redandunt graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neigbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs

* fix MultiSlotDataGenerator error

* cache for graph engine

* fix type compare error

* more test&fix thread terminating problem

* remove header

* change time interval of shrink

* use cache when sample nodes

* remove unused function

* change unique_ptr to shared_ptr

* simplify cache template

* cache api on client

* fix

* reduce sample threads when cache is not used

* reduce cache memory

* cache optimization

* remove test function

* remove extra fetch function

* graph-engine data transfer optimization

* support graph_split load&query

* remove logs

* change shards to pointer vector

* use inference

* remove test code

* renorm op

* simplify renorm op

* recover local changes

* recover renorm op kernel

* fix init

* add blanklines in renorm doc

* fix import

* fix import

* add renorm to init.py

* merge

* move index_sample op

* Delete api.h

* Delete api.cc

* fix

* remove logs

* recover infer shape of grad

* recover changes

* change shape

* fix label

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 49677636
...@@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/index_sample_op.h"
#include <vector> #include <vector>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h" #include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker { class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -42,44 +44,6 @@ class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -42,44 +44,6 @@ class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker {
class IndexSampleOp : public framework::OperatorWithKernel { class IndexSampleOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Inputs(Input) of FindByIndex should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Inputs(Index) of FindByIndex should not be null."));
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
input_dims.size(), 2,
platform::errors::InvalidArgument(
"Inputs(X) shape of IndexSample op should be 2-D, but "
"got X's shape = [%s], please check X shape.",
input_dims));
auto index_dims = ctx->GetInputDim("Index");
PADDLE_ENFORCE_EQ(
input_dims.size(), 2,
platform::errors::InvalidArgument(
"Inputs(Index) shape of IndexSample op should be 2-D, but "
"got Index's shape [%s] , please check index shape.",
input_dims));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(input_dims[0], index_dims[0],
platform::errors::InvalidArgument(
"Inputs(X)'s value of dimension 0 must same with "
"Inputs(Index)'s value of dimension 0, but "
"got %d of Inputs(X), and got %d of Inputs(Index), "
"please check Inputs shape.",
input_dims[0], index_dims[0]));
}
ctx->SetOutputDim("Out", index_dims);
auto type = ctx->GetInputsVarType("Index")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("Index", /*->*/ "Out");
}
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
...@@ -136,20 +100,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSampleGradNoNeedBufferVarInferer, "X"); ...@@ -136,20 +100,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSampleGradNoNeedBufferVarInferer, "X");
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(index_sample, IndexSampleInferShapeFunctor,
PT_INFER_META(phi::IndexSampleInferMeta));
REGISTER_OPERATOR(index_sample, ops::IndexSampleOp, ops::IndexSampleOpMaker, REGISTER_OPERATOR(index_sample, ops::IndexSampleOp, ops::IndexSampleOpMaker,
ops::IndexSampleGradMaker<paddle::framework::OpDesc>, ops::IndexSampleGradMaker<paddle::framework::OpDesc>,
ops::IndexSampleGradMaker<paddle::imperative::OpBase>); ops::IndexSampleGradMaker<paddle::imperative::OpBase>,
IndexSampleInferShapeFunctor);
REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp, REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp,
ops::IndexSampleGradNoNeedBufferVarInferer); ops::IndexSampleGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
index_sample,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
index_sample_grad,
ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_sample_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#define PREDEFINED_BLOCK_SIZE_X 512
#define PREDEFINED_BLOCK_SIZE 1024
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace paddle {
namespace operators {
namespace {
void LimitGridDim(const framework::ExecutionContext& ctx, dim3* grid_dim) {
auto max_grid_dim = ctx.template device_context<platform::CUDADeviceContext>()
.GetCUDAMaxGridDimSize();
grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
}
}
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, typename IndexT = int>
__global__ void IndexSampleForward(const IndexT* index, const T* in_data,
T* out_data, size_t index_length,
size_t input_length, size_t batch_size) {
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
IndexT sample_idx = index[index_idx];
out_data[index_idx] = in_data[in_idx - index_i + sample_idx];
}
}
}
template <typename T, typename IndexT = int>
__global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
const T* out_grad, size_t index_length,
size_t input_length, size_t batch_size,
bool same_data_in_row = true) {
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
IndexT sample_idx = index[index_idx];
if (same_data_in_row) {
platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]),
out_grad[sample_idx]);
} else {
in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx];
}
}
}
}
template <typename T>
class IndexSampleKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("X");
auto* index = ctx.Input<LoDTensor>("Index");
auto* output = ctx.Output<LoDTensor>("Out");
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
const auto* in_data = input->data<T>();
auto* out_data = output->mutable_data<T>(ctx.GetPlace());
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
auto input_dim = input->dims();
auto index_dim = index->dims();
size_t batch_size = input_dim[0];
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];
auto block_width = platform::RoundToPowerOfTwo(index_length);
block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X);
int block_height =
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width);
dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
LimitGridDim(ctx, &grid_dim);
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
IndexSampleForward<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length,
batch_size);
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
IndexSampleForward<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length,
batch_size);
}
}
};
template <typename T>
class IndexSampleGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* output_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* input_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto* index = ctx.Input<LoDTensor>("Index");
const auto* output_grad_data = output_grad->data<T>();
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
auto input_num = input_grad->numel();
auto input_dim = input_grad->dims();
auto index_dim = index->dims();
size_t batch_size = index_dim[0];
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];
bool same_data_in_index_row = index_length == 1 ? false : true;
auto block_width = platform::RoundToPowerOfTwo(index_length);
block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X);
auto block_height =
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width);
dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
LimitGridDim(ctx, &grid_dim);
phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
set_zero(dev_ctx, input_grad, static_cast<T>(0));
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
IndexSampleGrad<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size, same_data_in_index_row);
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
IndexSampleGrad<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size, same_data_in_index_row);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
index_sample,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSampleKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_sample_grad,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSampleGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <fstream>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
template <typename T, typename IndexT = int>
void IndexSampleInner(const framework::ExecutionContext &context,
const LoDTensor &input, const LoDTensor &index,
LoDTensor *output) {
auto input_dims = input.dims();
auto index_dims = index.dims();
int batch_size = input_dims[0];
auto value_length = input_dims[1];
auto index_length = index_dims[1];
int index_ids_num = index.numel();
std::vector<T> input_vec;
std::vector<IndexT> index_vec;
paddle::framework::TensorToVector(input, context.device_context(),
&input_vec);
paddle::framework::TensorToVector(index, context.device_context(),
&index_vec);
std::vector<T> res(index_ids_num);
for (int i = 0; i < index_ids_num; i++) {
int b = floor(i / index_length);
PADDLE_ENFORCE_GE(
index_vec[i], 0,
platform::errors::InvalidArgument(
"Variable value (index) of OP(index_sample) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
value_length, index_vec[i]));
PADDLE_ENFORCE_LT(
index_vec[i], value_length,
platform::errors::InvalidArgument(
"Variable value (index) of OP(index_sample) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
value_length, index_vec[i]));
int v_i = b * value_length + static_cast<int>(index_vec[i]);
T v = input_vec[v_i];
VLOG(4) << "Index Sample: batch = " << b << " index = " << v_i
<< " value = " << v;
res[i] = v;
}
auto ddim = phi::make_ddim({batch_size, index_length});
output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(res, context.device_context(), output);
output->Resize(ddim);
}
template <typename DeviceContext, typename T>
class IndexSampleKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *input_var = ctx.InputVar("X");
auto *index_var = ctx.InputVar("Index");
auto &input_tensor = input_var->Get<LoDTensor>();
auto &index_tensor = index_var->Get<LoDTensor>();
auto *out_var = ctx.OutputVar("Out");
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
const auto &index_type =
framework::TransToProtoVarType(index_tensor.dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSampleInner<T, int>(ctx, input_tensor, index_tensor, out_tensor);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSampleInner<T, int64_t>(ctx, input_tensor, index_tensor, out_tensor);
}
}
};
template <typename T, typename IndexT = int>
void IndexSampleGradInner(const framework::ExecutionContext &context,
const LoDTensor &out_grad, const LoDTensor &index,
LoDTensor *x_grad) {
std::vector<T> out_grad_vec;
std::vector<IndexT> index_vec;
paddle::framework::TensorToVector(out_grad, context.device_context(),
&out_grad_vec);
paddle::framework::TensorToVector(index, context.device_context(),
&index_vec);
auto index_dims = index.dims();
auto x_grad_dims = x_grad->dims();
auto value_length = x_grad_dims[1];
auto index_length = index_dims[1];
int index_ids_num = index.numel();
std::vector<T> x_grad_vec(x_grad->numel(), 0);
for (int i = 0; i < index_ids_num; i++) {
int b = floor(i / index_length);
PADDLE_ENFORCE_GE(
index_vec[i], 0,
platform::errors::InvalidArgument(
"Variable value (index) of OP(index_sample_grad) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
value_length, index_vec[i]));
PADDLE_ENFORCE_LT(
index_vec[i], value_length,
platform::errors::InvalidArgument(
"Variable value (index) of OP(index_sample_grad) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
value_length, index_vec[i]));
int v_i = b * value_length + static_cast<int>(index_vec[i]);
x_grad_vec[v_i] += out_grad_vec[i];
}
x_grad->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(x_grad_vec, context.device_context(), x_grad);
x_grad->Resize(x_grad_dims);
}
template <typename DeviceContext, typename T>
class IndexSampleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *index_var = context.InputVar("Index");
auto *x_grad_var = context.OutputVar(framework::GradVarName("X"));
auto *out_grad_var = context.InputVar(framework::GradVarName("Out"));
auto &index_tensor = index_var->Get<LoDTensor>();
auto &out_grad_tensor = out_grad_var->Get<LoDTensor>();
auto *x_grad_tensor = x_grad_var->GetMutable<framework::LoDTensor>();
const auto &index_type =
framework::TransToProtoVarType(index_tensor.dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSampleGradInner<T, int>(context, out_grad_tensor, index_tensor,
x_grad_tensor);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSampleGradInner<T, int64_t>(context, out_grad_tensor, index_tensor,
x_grad_tensor);
}
}
};
} // namespace operators
} // namespace paddle
...@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/index_sample_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle { namespace paddle {
......
...@@ -225,6 +225,41 @@ void HuberLossInferMeta(const MetaTensor& input, ...@@ -225,6 +225,41 @@ void HuberLossInferMeta(const MetaTensor& input,
out->share_lod(input); out->share_lod(input);
} }
void IndexSampleInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
MetaConfig config) {
auto input_dims = x.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
2,
errors::InvalidArgument(
"Inputs(X) shape of IndexSample op should be 2-D, but "
"got X's shape = [%s], please check X shape.",
input_dims));
auto index_dims = y.dims();
PADDLE_ENFORCE_EQ(
index_dims.size(),
2,
errors::InvalidArgument(
"Inputs(Index) shape of IndexSample op should be 2-D, but "
"got Index's shape [%s] , please check index shape.",
input_dims));
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(input_dims[0],
index_dims[0],
errors::InvalidArgument(
"Inputs(X)'s value of dimension 0 must same with "
"Inputs(Index)'s value of dimension 0, but "
"got %d of Inputs(X), and got %d of Inputs(Index), "
"please check Inputs shape.",
input_dims[0],
index_dims[0]));
}
out->set_dtype(x.dtype());
out->set_dims(index_dims);
out->share_lod(y);
}
void CrossInferMeta(const MetaTensor& x, void CrossInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
int axis, int axis,
......
...@@ -53,6 +53,11 @@ void HuberLossInferMeta(const MetaTensor& input_meta, ...@@ -53,6 +53,11 @@ void HuberLossInferMeta(const MetaTensor& input_meta,
MetaTensor* residual, MetaTensor* residual,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void IndexSampleInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
MetaConfig config = MetaConfig());
void CrossInferMeta(const MetaTensor& x, void CrossInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
int axis, int axis,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_sample_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context, typename IndexT = int>
void IndexSampleGradInner(const Context& context,
const DenseTensor& out_grad,
const DenseTensor& index,
DenseTensor* x_grad) {
std::vector<T> out_grad_vec;
std::vector<IndexT> index_vec;
paddle::framework::TensorToVector(out_grad, context, &out_grad_vec);
paddle::framework::TensorToVector(index, context, &index_vec);
auto index_dims = index.dims();
auto x_grad_dims = x_grad->dims();
auto value_length = x_grad_dims[1];
auto index_length = index_dims[1];
int index_ids_num = index.numel();
std::vector<T> x_grad_vec(x_grad->numel(), 0);
for (int i = 0; i < index_ids_num; i++) {
int b = floor(i / index_length);
PADDLE_ENFORCE_GE(
index_vec[i],
0,
errors::InvalidArgument(
"Variable value (index) of OP(index_sample_grad) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
value_length,
index_vec[i]));
PADDLE_ENFORCE_LT(
index_vec[i],
value_length,
errors::InvalidArgument(
"Variable value (index) of OP(index_sample_grad) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
value_length,
index_vec[i]));
int v_i = b * value_length + static_cast<int>(index_vec[i]);
x_grad_vec[v_i] += out_grad_vec[i];
}
context.template Alloc<T>(x_grad);
paddle::framework::TensorFromVector(x_grad_vec, context, x_grad);
x_grad->Resize(x_grad_dims);
}
template <typename T, typename Context>
void IndexSampleGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& index,
DenseTensor* x_grad) {
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
if (index_type == DataType::INT32) {
IndexSampleGradInner<T, Context, int>(ctx, out_grad, index, x_grad);
} else if (index_type == DataType::INT64) {
IndexSampleGradInner<T, Context, int64_t>(ctx, out_grad, index, x_grad);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_sample_grad,
CPU,
ALL_LAYOUT,
phi::IndexSampleGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_sample_kernel.h"
#include <cmath>
#include <fstream>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context, typename IndexT = int>
void IndexSampleInner(const Context &context,
const DenseTensor &input,
const DenseTensor &index,
DenseTensor *output) {
auto input_dims = input.dims();
auto index_dims = index.dims();
int batch_size = input_dims[0];
auto value_length = input_dims[1];
auto index_length = index_dims[1];
int index_ids_num = index.numel();
std::vector<T> input_vec;
std::vector<IndexT> index_vec;
paddle::framework::TensorToVector(input, context, &input_vec);
paddle::framework::TensorToVector(index, context, &index_vec);
std::vector<T> res(index_ids_num);
for (int i = 0; i < index_ids_num; i++) {
int b = floor(i / index_length);
PADDLE_ENFORCE_GE(
index_vec[i],
0,
errors::InvalidArgument(
"Variable value (index) of OP(index_sample) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
value_length,
index_vec[i]));
PADDLE_ENFORCE_LT(
index_vec[i],
value_length,
errors::InvalidArgument(
"Variable value (index) of OP(index_sample) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
value_length,
index_vec[i]));
int v_i = b * value_length + static_cast<int>(index_vec[i]);
T v = input_vec[v_i];
VLOG(4) << "Index Sample: batch = " << b << " index = " << v_i
<< " value = " << v;
res[i] = v;
}
auto ddim = phi::make_ddim({batch_size, index_length});
context.template Alloc<T>(output);
paddle::framework::TensorFromVector(res, context, output);
output->Resize(ddim);
}
template <typename T, typename Context>
void IndexSampleKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &index,
DenseTensor *out) {
ctx.template Alloc<T>(out);
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
if (index_type == DataType::INT32) {
IndexSampleInner<T, Context, int>(ctx, x, index, out);
} else if (index_type == DataType::INT64) {
IndexSampleInner<T, Context, int64_t>(ctx, x, index, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_sample,
CPU,
ALL_LAYOUT,
phi::IndexSampleKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_sample_grad_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace {
template <typename Context>
void LimitGridDim(const Context& ctx, dim3* grid_dim) {
auto max_grid_dim =
reinterpret_cast<const phi::GPUContext&>(ctx).GetCUDAMaxGridDimSize();
grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
}
#define PREDEFINED_BLOCK_SIZE_X 512
#define PREDEFINED_BLOCK_SIZE 1024
#define MIN(a, b) ((a) < (b) ? (a) : (b))
};
template <typename T, typename IndexT = int>
__global__ void IndexSampleGrad(const IndexT* index,
T* in_grad,
const T* out_grad,
size_t index_length,
size_t input_length,
size_t batch_size,
bool same_data_in_row = true) {
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
IndexT sample_idx = index[index_idx];
if (same_data_in_row) {
paddle::platform::CudaAtomicAdd(
&(in_grad[in_idx - index_i + sample_idx]), out_grad[sample_idx]);
} else {
in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx];
}
}
}
}
template <typename T, typename Context>
void IndexSampleGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& index,
DenseTensor* x_grad) {
const T* output_grad_data = out_grad.data<T>();
T* input_grad_data = ctx.template Alloc<T>(x_grad);
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
auto input_num = x.numel();
auto input_dim = x.dims();
auto index_dim = index.dims();
size_t batch_size = index_dim[0];
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];
bool same_data_in_index_row = index_length == 1 ? false : true;
auto block_width = paddle::platform::RoundToPowerOfTwo(index_length);
block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X);
auto block_height =
paddle::platform::RoundToPowerOfTwo(index_length * batch_size) /
block_width;
block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width);
dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
LimitGridDim(ctx, &grid_dim);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, x_grad, static_cast<T>(0));
if (index_type == DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
IndexSampleGrad<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data,
input_grad_data,
output_grad_data,
index_length,
input_length,
batch_size,
same_data_in_index_row);
} else if (index_type == DataType::INT32) {
const int* index_data = index.data<int>();
IndexSampleGrad<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data,
input_grad_data,
output_grad_data,
index_length,
input_length,
batch_size,
same_data_in_index_row);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_sample_grad,
GPU,
ALL_LAYOUT,
phi::IndexSampleGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_sample_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace {
template <typename Context>
void LimitGridDim(const Context& ctx, dim3* grid_dim) {
auto max_grid_dim =
reinterpret_cast<const phi::GPUContext&>(ctx).GetCUDAMaxGridDimSize();
grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
}
#define PREDEFINED_BLOCK_SIZE_X 512
#define PREDEFINED_BLOCK_SIZE 1024
#define MIN(a, b) ((a) < (b) ? (a) : (b))
}
template <typename T, typename IndexT = int>
__global__ void IndexSampleForward(const IndexT* index,
const T* in_data,
T* out_data,
size_t index_length,
size_t input_length,
size_t batch_size) {
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
IndexT sample_idx = index[index_idx];
out_data[index_idx] = in_data[in_idx - index_i + sample_idx];
}
}
}
template <typename T, typename Context>
void IndexSampleKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
DenseTensor* out) {
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
const T* in_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
auto input_dim = x.dims();
auto index_dim = index.dims();
size_t batch_size = input_dim[0];
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];
auto block_width = paddle::platform::RoundToPowerOfTwo(index_length);
block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X);
int block_height =
paddle::platform::RoundToPowerOfTwo(index_length * batch_size) /
block_width;
block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width);
dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
LimitGridDim(ctx, &grid_dim);
if (index_type == DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
IndexSampleForward<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length, batch_size);
} else if (index_type == DataType::INT32) {
const int* index_data = index.data<int>();
IndexSampleForward<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length, batch_size);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_sample,
GPU,
ALL_LAYOUT,
phi::IndexSampleKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IndexSampleGradKernel(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& index,
DenseTensor* in_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IndexSampleKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature IndexSampleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("index_sample_grad",
{GradVarName("Out"), "X", "Index"},
{},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(index_sample_grad,
phi::IndexSampleGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册