未验证 提交 bd294378 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix gather & stack op (#14355)

* Add int type support for stack_op

* Improve gather op to support index with shape N x 1

test=develop

* Fix stack_op kernel's registry

test=develop
上级 9d4425dd
......@@ -50,7 +50,9 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
PADDLE_ENFORCE(index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1));
int index_size = index.dims()[0];
auto src_dims = src.dims();
......
......@@ -38,7 +38,8 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
PADDLE_ENFORCE(index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1));
int64_t index_size = index.dims()[0];
auto src_dims = src.dims();
......
......@@ -31,7 +31,8 @@ class GatherOp : public framework::OperatorWithKernel {
"Output(Out) of GatherOp should not be null.");
auto index_dims = ctx->GetInputDim("Index");
PADDLE_ENFORCE(index_dims.size() == 1);
PADDLE_ENFORCE(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1));
int batch_size = ctx->GetInputDim("Index")[0];
framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size;
......@@ -53,6 +54,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
}
protected:
......
......@@ -51,7 +51,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
PADDLE_ENFORCE(index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1));
int index_size = index.dims()[0];
auto src_dims = src.dims();
......
......@@ -37,7 +37,8 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
PADDLE_ENFORCE(index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1));
int index_size = index.dims()[0];
auto src_dims = src.dims();
......
......@@ -21,8 +21,12 @@ REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>,
ops::StackKernel<plat::CPUDeviceContext, double>);
ops::StackKernel<plat::CPUDeviceContext, double>,
ops::StackKernel<plat::CPUDeviceContext, int>,
ops::StackKernel<plat::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(stack_grad,
ops::StackGradKernel<plat::CPUDeviceContext, float>,
ops::StackGradKernel<plat::CPUDeviceContext, double>);
ops::StackGradKernel<plat::CPUDeviceContext, double>,
ops::StackGradKernel<plat::CPUDeviceContext, int>,
ops::StackGradKernel<plat::CPUDeviceContext, int64_t>);
......@@ -18,8 +18,12 @@ namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel<plat::CUDADeviceContext, float>,
ops::StackKernel<plat::CUDADeviceContext, double>);
ops::StackKernel<plat::CUDADeviceContext, double>,
ops::StackKernel<plat::CUDADeviceContext, int>,
ops::StackKernel<plat::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(stack_grad,
ops::StackGradKernel<plat::CUDADeviceContext, float>,
ops::StackGradKernel<plat::CUDADeviceContext, double>);
ops::StackGradKernel<plat::CUDADeviceContext, double>,
ops::StackGradKernel<plat::CUDADeviceContext, int>,
ops::StackGradKernel<plat::CUDADeviceContext, int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册