Unverified commit bd294378, authored by Yibing Liu, committed by GitHub

Fix gather & stack op (#14355)

* Add int type support for stack_op

* Improve gather op to support index with shape N x 1

test=develop

* Fix stack_op kernel's registry

test=develop
Parent: 9d4425dd
@@ -50,7 +50,9 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1);
+  PADDLE_ENFORCE(index.dims().size() == 1 ||
+                 (index.dims().size() == 2 && index.dims()[1] == 1));
+
   int index_size = index.dims()[0];
 
   auto src_dims = src.dims();
...
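The relaxed check accepts an index of shape [N] or [N, 1]; either way dims[0] is the number of indices, so nothing downstream of the enforce changes. A minimal standalone sketch of the new condition (illustrative names, not Paddle code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for the PADDLE_ENFORCE above: an index tensor is valid if it
// is 1-D (shape [N]) or a column vector (shape [N, 1]).
int64_t ValidIndexSize(const std::vector<int64_t>& dims) {
  assert(dims.size() == 1 || (dims.size() == 2 && dims[1] == 1));
  return dims[0];  // in both layouts, dims[0] is the number of indices
}

int main() {
  assert(ValidIndexSize({5}) == 5);     // shape [5]
  assert(ValidIndexSize({5, 1}) == 5);  // shape [5, 1]
}
```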
@@ -38,7 +38,8 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1);
+  PADDLE_ENFORCE(index.dims().size() == 1 ||
+                 (index.dims().size() == 2 && index.dims()[1] == 1));
   int64_t index_size = index.dims()[0];
 
   auto src_dims = src.dims();
...
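For context, the gather loop both kernels implement copies row index[i] of src into row i of output. Since an [N, 1] index stores its N values contiguously, the same flat loop serves both index shapes. A simplified, self-contained sketch (float rows only, not the actual Paddle kernel):

```cpp
#include <cstring>
#include <iostream>
#include <vector>

// Simplified gather: copy row index[i] of `src` into row i of `output`,
// with index_size rows gathered and slice_size floats per row.
void CpuGather(const float* src, const int* index, int index_size,
               int slice_size, float* output) {
  for (int i = 0; i < index_size; ++i) {
    std::memcpy(output + i * slice_size, src + index[i] * slice_size,
                slice_size * sizeof(float));
  }
}

int main() {
  std::vector<float> src = {0, 0, 1, 1, 2, 2, 3, 3};  // 4 rows x 2 cols
  std::vector<int> index = {3, 1};  // works the same if viewed as [2, 1]
  std::vector<float> out(index.size() * 2);
  CpuGather(src.data(), index.data(), static_cast<int>(index.size()), 2,
            out.data());
  for (float v : out) std::cout << v << " ";  // prints: 3 3 1 1
}
```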
@@ -31,7 +31,8 @@ class GatherOp : public framework::OperatorWithKernel {
                    "Output(Out) of GatherOp should not be null.");
 
     auto index_dims = ctx->GetInputDim("Index");
-    PADDLE_ENFORCE(index_dims.size() == 1);
+    PADDLE_ENFORCE(index_dims.size() == 1 ||
+                   (index_dims.size() == 2 && index_dims[1] == 1));
     int batch_size = ctx->GetInputDim("Index")[0];
     framework::DDim output_dims(ctx->GetInputDim("X"));
     output_dims[0] = batch_size;
@@ -53,6 +54,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
   }
 
  protected:
@@ -75,7 +77,7 @@ Gather Operator.
 
 $Out = X[Index]$
 
 Out is obtained by gathering entries of the outer-most dimension
 of X indexed by Index and concatenate them together.
 
 Example:
...
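GatherOp's InferShape now applies the same relaxed check and still takes the index count from dimension 0. A hypothetical standalone version of that shape rule:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical mirror of GatherOp::InferShape: keep X's trailing dims and
// replace dim 0 with the number of indices (index shape [N] or [N, 1]).
std::vector<int64_t> InferGatherShape(const std::vector<int64_t>& x_dims,
                                      const std::vector<int64_t>& index_dims) {
  assert(index_dims.size() == 1 ||
         (index_dims.size() == 2 && index_dims[1] == 1));
  std::vector<int64_t> out = x_dims;
  out[0] = index_dims[0];  // batch_size
  return out;
}

int main() {
  assert(InferGatherShape({10, 20}, {4, 1}) ==
         std::vector<int64_t>({4, 20}));
}
```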
@@ -51,7 +51,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                       const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1);
+  PADDLE_ENFORCE(index.dims().size() == 1 ||
+                 (index.dims().size() == 2 && index.dims()[1] == 1));
   int index_size = index.dims()[0];
 
   auto src_dims = src.dims();
...
@@ -37,7 +37,8 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                    const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index.dims().size() == 1);
+  PADDLE_ENFORCE(index.dims().size() == 1 ||
+                 (index.dims().size() == 2 && index.dims()[1] == 1));
   int index_size = index.dims()[0];
 
   auto src_dims = src.dims();
...
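ScatterAssign is the inverse write: row i of src lands in row index[i] of output, with identical [N] / [N, 1] index handling. A matching simplified sketch (again not the Paddle kernel):

```cpp
#include <cstring>
#include <vector>

// Simplified scatter-assign: write row i of `src` into row index[i] of
// `output`; the flat index layout serves [N] and [N, 1] alike.
void CpuScatterAssign(const float* src, const int* index, int index_size,
                      int slice_size, float* output) {
  for (int i = 0; i < index_size; ++i) {
    std::memcpy(output + index[i] * slice_size, src + i * slice_size,
                slice_size * sizeof(float));
  }
}

int main() {
  std::vector<float> src = {7, 7, 8, 8};  // 2 rows x 2 cols
  std::vector<int> index = {2, 0};        // write into rows 2 and 0
  std::vector<float> out(3 * 2, 0.f);     // 3 rows x 2 cols, zero-filled
  CpuScatterAssign(src.data(), index.data(), 2, 2, out.data());
  // out is now {8, 8, 0, 0, 7, 7}
}
```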
@@ -21,8 +21,12 @@ REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
 REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
 
 REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>,
-                       ops::StackKernel<plat::CPUDeviceContext, double>);
+                       ops::StackKernel<plat::CPUDeviceContext, double>,
+                       ops::StackKernel<plat::CPUDeviceContext, int>,
+                       ops::StackKernel<plat::CPUDeviceContext, int64_t>);
 
 REGISTER_OP_CPU_KERNEL(stack_grad,
                        ops::StackGradKernel<plat::CPUDeviceContext, float>,
-                       ops::StackGradKernel<plat::CPUDeviceContext, double>);
+                       ops::StackGradKernel<plat::CPUDeviceContext, double>,
+                       ops::StackGradKernel<plat::CPUDeviceContext, int>,
+                       ops::StackGradKernel<plat::CPUDeviceContext, int64_t>);
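The stack fix is purely a registry change: the kernels were already templated on the element type, so supporting int and int64_t only requires the extra instantiations above. A hand-rolled sketch of the pattern (this is not Paddle's actual registry, just an illustration of one templated body plus one entry per (op, dtype) pair):

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <typeindex>
#include <utility>

// One templated kernel body covers every element type T.
template <typename T>
void StackKernelBody() {
  // ... same stacking logic regardless of T ...
}

// Toy registry keyed by (op name, element type).
std::map<std::pair<std::string, std::type_index>, std::function<void()>>
    g_kernels;

template <typename T>
void RegisterKernel(const std::string& op) {
  g_kernels[{op, std::type_index(typeid(T))}] = &StackKernelBody<T>;
}

int main() {
  // Mirrors the four entries REGISTER_OP_CPU_KERNEL now expands to.
  RegisterKernel<float>("stack");
  RegisterKernel<double>("stack");
  RegisterKernel<int>("stack");
  RegisterKernel<int64_t>("stack");
}
```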
@@ -18,8 +18,12 @@ namespace plat = paddle::platform;
 namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel<plat::CUDADeviceContext, float>,
-                        ops::StackKernel<plat::CUDADeviceContext, double>);
+                        ops::StackKernel<plat::CUDADeviceContext, double>,
+                        ops::StackKernel<plat::CUDADeviceContext, int>,
+                        ops::StackKernel<plat::CUDADeviceContext, int64_t>);
 
 REGISTER_OP_CUDA_KERNEL(stack_grad,
                         ops::StackGradKernel<plat::CUDADeviceContext, float>,
-                        ops::StackGradKernel<plat::CUDADeviceContext, double>);
+                        ops::StackGradKernel<plat::CUDADeviceContext, double>,
+                        ops::StackGradKernel<plat::CUDADeviceContext, int>,
+                        ops::StackGradKernel<plat::CUDADeviceContext, int64_t>);