From abc17ef7e2481c0d6769a4d59db37c461811e433 Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Fri, 11 Jun 2021 20:25:36 +0800 Subject: [PATCH] Fix gather infer shape using axis (#33413) * fix gather shape bug * fix None * fix topo --- paddle/fluid/operators/gather.cu.h | 26 ++--- paddle/fluid/operators/gather.h | 26 ++--- paddle/fluid/operators/gather_op.cc | 33 +++++- paddle/fluid/operators/gather_op.cu | 108 +++++++----------- paddle/fluid/operators/gather_op.h | 92 +++++---------- .../fluid/tests/unittests/test_gather_op.py | 1 + python/paddle/tensor/manipulation.py | 37 +++--- 7 files changed, 134 insertions(+), 189 deletions(-) diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index 94fe45dac0c..95cb428abdf 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -202,12 +202,11 @@ __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, } } -template +template void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, + const int axis, Tensor* out, const paddle::platform::Place& place, const framework::ExecutionContext& ctx) { - int axis_size = axis->numel(); int index_size = index->numel(); int input_size = input->numel(); auto input_dim = input->dims(); @@ -215,12 +214,8 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, auto* index_data = index->data(); if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - Tensor cpu_axis; - framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); - int axis_index = cpu_axis.data()[0]; + + int axis_index = axis; int index_dim_size = input_dim[axis_index]; int inner_dim_size = 1; @@ -251,26 +246,19 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, index_size, index_dim_size, out_size); } -template +template void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, + const int axis, Tensor* out, const paddle::platform::Place& place, const framework::ExecutionContext& ctx) { auto* index_data = index->data(); - - int axis_size = axis->numel(); int index_size = index->numel(); int input_size = input->numel(); auto input_dim = input->dims(); auto* input_data = input->data(); if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - Tensor cpu_axis; - framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); - int axis_index = cpu_axis.data()[0]; + int axis_index = axis; int input_index_dim_size = input_dim[axis_index]; int inner_dim_size = 1; diff --git a/paddle/fluid/operators/gather.h b/paddle/fluid/operators/gather.h index c12a3b8adc9..8deab709220 100644 --- a/paddle/fluid/operators/gather.h +++ b/paddle/fluid/operators/gather.h @@ -126,24 +126,17 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input, } } -template -void GatherV2Function(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, - const paddle::platform::Place& place) { - auto* axis_data = axis->data(); +template +void GatherV2Function(const Tensor* input, const Tensor* index, int axis, + Tensor* out, const paddle::platform::Place& place) { auto* index_data = index->data(); - - int axis_size = axis->numel(); int index_size = index->numel(); int input_size = input->numel(); auto input_dim = input->dims(); auto* input_data = input->data(); if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - int axis_index = axis_data[0]; + int axis_index = axis; int input_index_dim_size = input_dim[axis_index]; for (int i = 0; i < index_size; i++) { @@ -186,22 +179,17 @@ void GatherV2Function(const Tensor* input, const Tensor* index, } } -template +template void GatherV2GradFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, + const int axis, Tensor* out, const paddle::platform::Place& place) { - auto* axis_data = axis->data(); auto* index_data = index->data(); - int axis_size = axis->numel(); auto input_dim = input->dims(); auto* input_data = input->data(); if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - int axis_index = axis_data[0]; + int axis_index = axis; int input_index_dim_size = input_dim[axis_index]; int inner_dim_size = 1; diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 162766546b3..ea28c204ec9 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/op_version_registry.h" + namespace paddle { namespace operators { @@ -52,11 +53,29 @@ class GatherOp : public framework::OperatorWithKernel { index_dims.size())); } - int batch_size = ctx->GetInputDim("Index")[0]; - framework::DDim output_dims(ctx->GetInputDim("X")); - output_dims[0] = batch_size; - ctx->SetOutputDim("Out", output_dims); - ctx->ShareLoD("X", /*->*/ "Out"); + auto axis = ctx->Attrs().Get("axis"); + auto input_dim = ctx->GetInputDim("X"); + if (ctx->HasInput("Axis") || axis == 0) { + // if HasInput("Axis"), we can not obtain correct shape of output + int batch_size = index_dims[0]; + framework::DDim output_dims(input_dim); + output_dims[0] = batch_size; + ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } else { + int index_size = index_dims[0]; + std::vector out_dim_vec; + for (int i = 0; i < axis; i++) { + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis + 1; i < input_dim.size(); i++) { + out_dim_vec.push_back(input_dim[i]); + } + auto output_dims = framework::make_ddim(out_dim_vec); + ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } } protected: @@ -120,6 +139,10 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { "If true, update the grad using the overwrite mode in same index," "If false, using the accumulate mode in same index.") .SetDefault(true); + AddAttr( + "axis", + "The Tensor which contains the axis that we do gather operation.") + .SetDefault(0); AddComment(R"DOC( Gather Operator. diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 37fbfb21f60..6e27d95e018 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -31,47 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + int axis = ctx.Attr("axis"); + + // get axis from tensor if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); + Tensor cpu_axis; + const Tensor *axis_tensor = ctx.Input("Axis"); + framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(cpu_axis.data()[0]); } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); + } + const auto &place = ctx.GetPlace(); + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(x, index, axis, output, place, ctx); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(x, index, axis, output, place, ctx); } return; } + output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { GPUGather(ctx.device_context(), *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { @@ -91,30 +77,27 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); + const Tensor *axis_tensor = ctx.Input("Axis"); + Tensor cpu_axis; + framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(cpu_axis.data()[0]); } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); + } + + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + ctx.GetPlace(), ctx); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + ctx.GetPlace(), ctx); } return; } @@ -125,19 +108,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { .eigen_device(); dxt.device(place) = dxt.constant(static_cast(0)); if (dO->numel() == 0) return; - - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { GPUScatterAssign(ctx, *dO, *index, dX, ctx.Attr("overwrite")); diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h index 8ec0d6ce0b6..a2570c3e014 100644 --- a/paddle/fluid/operators/gather_op.h +++ b/paddle/fluid/operators/gather_op.h @@ -35,45 +35,30 @@ class GatherOpKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + int axis = ctx.Attr("axis"); + // get axis from tensor if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2Function(x, index, axis, output, place); + const Tensor *axis_tensor = ctx.Input("Axis"); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(axis_tensor->data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(axis_tensor->data()[0]); } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2Function(x, index, axis, output, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2Function(x, index, axis, output, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2Function(x, index, axis, output, place); + } + const auto &place = ctx.GetPlace(); + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2Function(x, index, axis, output, place); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2Function(x, index, axis, output, place); } return; } output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; - - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { CPUGather(ctx.device_context(), *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { @@ -94,26 +79,23 @@ class GatherGradientOpKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradFunction(dO, index, axis, dX, place); + const Tensor *axis_tensor = ctx.Input("Axis"); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(axis_tensor->data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(axis_tensor->data()[0]); } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradFunction(dO, index, axis, dX, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradFunction(dO, index, axis, dX, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradFunction(dO, index, axis, dX, place); + } + const auto &index_type = index->type(); + + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(dO, index, axis, dX, ctx.GetPlace()); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(dO, index, axis, dX, ctx.GetPlace()); } return; } @@ -126,18 +108,6 @@ class GatherGradientOpKernel : public framework::OpKernel { if (dO->numel() == 0) return; bool overwrite = ctx.Attr("overwrite"); - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { if (overwrite) { ScatterAssign(ctx.device_context(), *dO, *index, dX); diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 946027a22f8..2d56441bf3e 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -182,6 +182,7 @@ class TestGatherOp4(TestGatherOp1): self.index_type = "int64" self.axis = [0] self.axis_type = "int32" + self.attrs = {'overwrite': False} class API_TestGather(unittest.TestCase): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 67e6c7f8e44..c3031c41279 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -862,34 +862,39 @@ def gather(x, index, axis=None, name=None): """ if axis is None: axis = 0 - axis_tensor = axis - if not isinstance(axis, Variable) and axis == 0: - return paddle.fluid.layers.gather(input=x, index=index, overwrite=False) - if not isinstance(axis, Variable): - with device_guard("cpu"): - axis_tensor = fill_constant( - shape=[1], dtype='int64', value=axis, force_cpu=True) + if in_dygraph_mode(): - return core.ops.gather(x, index, axis_tensor) + axis = axis.item() if isinstance(axis, paddle.Tensor) else axis + return core.ops.gather(x, index, None, "axis", axis, "overwrite", False) check_variable_and_dtype( x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], 'gather') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') + if isinstance(axis, Variable): check_variable_and_dtype(axis, 'axis', ['int32', 'int64'], 'gather') - else: - check_type(axis, 'axis', (int), 'gather') helper = LayerHelper('gather', **locals()) dtype = helper.input_dtype('x') out = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type="gather", - inputs={"X": x, - "Index": index, - "Axis": axis_tensor}, - outputs={"Out": out}) + if not isinstance(axis, Variable): + helper.append_op( + type="gather", + inputs={"X": x, + "Index": index}, + attrs={'axis': axis, + 'overwrite': False}, + outputs={"Out": out}) + else: + helper.append_op( + type="gather", + inputs={"X": x, + "Index": index, + "Axis": axis}, + attrs={"overwrite": False}, + outputs={"Out": out}) + return out -- GitLab