diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index 94fe45dac0ce782d6e8f81c737de10b5aefdaaa5..95cb428abdf34edaff7ab2ad77d2a6ac9cd438f3 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -202,12 +202,11 @@ __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out, } } -template +template void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, + const int axis, Tensor* out, const paddle::platform::Place& place, const framework::ExecutionContext& ctx) { - int axis_size = axis->numel(); int index_size = index->numel(); int input_size = input->numel(); auto input_dim = input->dims(); @@ -215,12 +214,8 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, auto* index_data = index->data(); if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - Tensor cpu_axis; - framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); - int axis_index = cpu_axis.data()[0]; + + int axis_index = axis; int index_dim_size = input_dim[axis_index]; int inner_dim_size = 1; @@ -251,26 +246,19 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, index_size, index_dim_size, out_size); } -template +template void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, + const int axis, Tensor* out, const paddle::platform::Place& place, const framework::ExecutionContext& ctx) { auto* index_data = index->data(); - - int axis_size = axis->numel(); int index_size = index->numel(); int input_size = input->numel(); auto input_dim = input->dims(); auto* input_data = input->data(); if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - Tensor cpu_axis; - framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis); - int axis_index = cpu_axis.data()[0]; + int axis_index = axis; int input_index_dim_size = input_dim[axis_index]; int inner_dim_size = 1; diff --git a/paddle/fluid/operators/gather.h b/paddle/fluid/operators/gather.h index c12a3b8adc97893f523b307a56c0e6b04ea8d675..8deab709220d7f6d5988b58f1d4cbb5540836ac8 100644 --- a/paddle/fluid/operators/gather.h +++ b/paddle/fluid/operators/gather.h @@ -126,24 +126,17 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input, } } -template -void GatherV2Function(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, - const paddle::platform::Place& place) { - auto* axis_data = axis->data(); +template +void GatherV2Function(const Tensor* input, const Tensor* index, int axis, + Tensor* out, const paddle::platform::Place& place) { auto* index_data = index->data(); - - int axis_size = axis->numel(); int index_size = index->numel(); int input_size = input->numel(); auto input_dim = input->dims(); auto* input_data = input->data(); if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - int axis_index = axis_data[0]; + int axis_index = axis; int input_index_dim_size = input_dim[axis_index]; for (int i = 0; i < index_size; i++) { @@ -186,22 +179,17 @@ void GatherV2Function(const Tensor* input, const Tensor* index, } } -template +template void GatherV2GradFunction(const Tensor* input, const Tensor* index, - const Tensor* axis, Tensor* out, + const int axis, Tensor* out, const paddle::platform::Place& place) { - auto* axis_data = axis->data(); auto* index_data = index->data(); - int axis_size = axis->numel(); auto input_dim = input->dims(); auto* input_data = input->data(); if (input->numel() == 0) return; - PADDLE_ENFORCE_EQ(axis_size, 1, - platform::errors::InvalidArgument( - "Axis size should be 1, but received %d", axis_size)); - int axis_index = axis_data[0]; + int axis_index = axis; int input_index_dim_size = input_dim[axis_index]; int inner_dim_size = 1; diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 162766546b3c264ebaf6d833adf9b04c38251f8e..ea28c204ec9cf9e63f1dace5c4a9188b0f1c1719 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/op_version_registry.h" + namespace paddle { namespace operators { @@ -52,11 +53,29 @@ class GatherOp : public framework::OperatorWithKernel { index_dims.size())); } - int batch_size = ctx->GetInputDim("Index")[0]; - framework::DDim output_dims(ctx->GetInputDim("X")); - output_dims[0] = batch_size; - ctx->SetOutputDim("Out", output_dims); - ctx->ShareLoD("X", /*->*/ "Out"); + auto axis = ctx->Attrs().Get("axis"); + auto input_dim = ctx->GetInputDim("X"); + if (ctx->HasInput("Axis") || axis == 0) { + // if HasInput("Axis"), we can not obtain correct shape of output + int batch_size = index_dims[0]; + framework::DDim output_dims(input_dim); + output_dims[0] = batch_size; + ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } else { + int index_size = index_dims[0]; + std::vector out_dim_vec; + for (int i = 0; i < axis; i++) { + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis + 1; i < input_dim.size(); i++) { + out_dim_vec.push_back(input_dim[i]); + } + auto output_dims = framework::make_ddim(out_dim_vec); + ctx->SetOutputDim("Out", output_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } } protected: @@ -120,6 +139,10 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { "If true, update the grad using the overwrite mode in same index," "If false, using the accumulate mode in same index.") .SetDefault(true); + AddAttr( + "axis", + "The Tensor which contains the axis that we do gather operation.") + .SetDefault(0); AddComment(R"DOC( Gather Operator. diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 37fbfb21f60a0568390c6798dc305c91fc8af886..6e27d95e01855ce6aa15e51b5a4768509be440f6 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -31,47 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + int axis = ctx.Attr("axis"); + + // get axis from tensor if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); - } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); + Tensor cpu_axis; + const Tensor *axis_tensor = ctx.Input("Axis"); + framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(cpu_axis.data()[0]); } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2CUDAFunction(x, index, axis, output, place, - ctx); + } + const auto &place = ctx.GetPlace(); + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2CUDAFunction(x, index, axis, output, place, ctx); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2CUDAFunction(x, index, axis, output, place, ctx); } return; } + output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { GPUGather(ctx.device_context(), *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { @@ -91,30 +77,27 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); + const Tensor *axis_tensor = ctx.Input("Axis"); + Tensor cpu_axis; + framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(cpu_axis.data()[0]); } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradCUDAFunction(dO, index, axis, dX, - place, ctx); + } + + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + ctx.GetPlace(), ctx); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2GradCUDAFunction(dO, index, axis, dX, + ctx.GetPlace(), ctx); } return; } @@ -125,19 +108,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { .eigen_device(); dxt.device(place) = dxt.constant(static_cast(0)); if (dO->numel() == 0) return; - - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { GPUScatterAssign(ctx, *dO, *index, dX, ctx.Attr("overwrite")); diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h index 8ec0d6ce0b69c791f9bff58f1681f8d4543c57dd..a2570c3e014e11ec10bc98d22607572e2b92d6e5 100644 --- a/paddle/fluid/operators/gather_op.h +++ b/paddle/fluid/operators/gather_op.h @@ -35,45 +35,30 @@ class GatherOpKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + int axis = ctx.Attr("axis"); + // get axis from tensor if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2Function(x, index, axis, output, place); + const Tensor *axis_tensor = ctx.Input("Axis"); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(axis_tensor->data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(axis_tensor->data()[0]); } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2Function(x, index, axis, output, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2Function(x, index, axis, output, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2Function(x, index, axis, output, place); + } + const auto &place = ctx.GetPlace(); + const auto &index_type = index->type(); + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2Function(x, index, axis, output, place); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2Function(x, index, axis, output, place); } return; } output->mutable_data(ctx.GetPlace()); if (x->numel() == 0) return; - - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { CPUGather(ctx.device_context(), *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { @@ -94,26 +79,23 @@ class GatherGradientOpKernel : public framework::OpKernel { auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); + int axis = ctx.Attr("axis"); if (ctx.HasInput("Axis")) { - const Tensor *axis = ctx.Input("Axis"); - const auto &index_type = index->type(); - const auto &axis_type = axis->type(); - auto place = ctx.GetPlace(); - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradFunction(dO, index, axis, dX, place); + const Tensor *axis_tensor = ctx.Input("Axis"); + const auto &axis_type = axis_tensor->type(); + if (axis_type == framework::proto::VarType::INT32) { + axis = static_cast(axis_tensor->data()[0]); + } else if (axis_type == framework::proto::VarType::INT64) { + axis = static_cast(axis_tensor->data()[0]); } - if (index_type == framework::proto::VarType::INT32 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradFunction(dO, index, axis, dX, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT32) { - GatherV2GradFunction(dO, index, axis, dX, place); - } - if (index_type == framework::proto::VarType::INT64 && - axis_type == framework::proto::VarType::INT64) { - GatherV2GradFunction(dO, index, axis, dX, place); + } + const auto &index_type = index->type(); + + if (axis != 0) { + if (index_type == framework::proto::VarType::INT32) { + GatherV2GradFunction(dO, index, axis, dX, ctx.GetPlace()); + } else if (index_type == framework::proto::VarType::INT64) { + GatherV2GradFunction(dO, index, axis, dX, ctx.GetPlace()); } return; } @@ -126,18 +108,6 @@ class GatherGradientOpKernel : public framework::OpKernel { if (dO->numel() == 0) return; bool overwrite = ctx.Attr("overwrite"); - const auto &index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Index holds the wrong type, it holds [%s]," - "but desires to be [%s] or [%s].", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { if (overwrite) { ScatterAssign(ctx.device_context(), *dO, *index, dX); diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 55f86959c59f255c323b3d9f0d5ca8099a956a3e..1a8e9a0bf55d091986795ef577dc65be31020b5e 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -238,31 +238,39 @@ def new_group(ranks=None, backend=None): if global_rank not in ranks: gp = Group(-1, -1, ring_id, ranks) _group_map[ring_id] = gp - return gp - - ranks = sorted(ranks) - group_rank = ranks.index(global_rank) - group_size = len(ranks) - gp = Group(group_rank, group_size, ring_id, ranks) - _group_map[ring_id] = gp - - if group_size < 2: - return gp - - strategy = core.ParallelStrategy() - strategy.nranks = group_size - strategy.local_rank = group_rank - strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks] - strategy.current_endpoint = genv.current_endpoint - strategy.nrings = 1 - - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(genv.device_id) - core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id) else: - assert False, ("no cuda device found") - # need to barrier to construct group - barrier(gp) + ranks = sorted(ranks) + group_rank = ranks.index(global_rank) + group_size = len(ranks) + gp = Group(group_rank, group_size, ring_id, ranks) + _group_map[ring_id] = gp + + if group_size >= 2: + strategy = core.ParallelStrategy() + strategy.nranks = group_size + strategy.local_rank = group_rank + strategy.trainer_endpoints = [ + genv.trainer_endpoints[i] for i in ranks + ] + strategy.current_endpoint = genv.current_endpoint + strategy.nrings = 1 + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(genv.device_id) + core.NCCLParallelContext(strategy, + place).init_with_ring_id(ring_id) + else: + assert False, ("no cuda device found") + else: + return gp + + # TODO(shenliang03): This is a temporary solution to solve the problem of + # hang caused by cross-creation of new_group + tmp = paddle.to_tensor( + [1], dtype="int32") if in_dygraph_mode() else fill_constant( + [0], dtype="int32", value="1") + paddle.distributed.all_reduce(tmp, use_calc_stream=True) + paddle.distributed.wait(tmp) return gp diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 946027a22f88384a2bc968b8595ee1ed416a6439..2d56441bf3efff373d3e118692b879d229e8b9c4 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -182,6 +182,7 @@ class TestGatherOp4(TestGatherOp1): self.index_type = "int64" self.axis = [0] self.axis_type = "int32" + self.attrs = {'overwrite': False} class API_TestGather(unittest.TestCase): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 67e6c7f8e44d740f179961c7e183efdced9ff805..c3031c41279c3cdb26c8cc740a0a401344b8bf52 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -862,34 +862,39 @@ def gather(x, index, axis=None, name=None): """ if axis is None: axis = 0 - axis_tensor = axis - if not isinstance(axis, Variable) and axis == 0: - return paddle.fluid.layers.gather(input=x, index=index, overwrite=False) - if not isinstance(axis, Variable): - with device_guard("cpu"): - axis_tensor = fill_constant( - shape=[1], dtype='int64', value=axis, force_cpu=True) + if in_dygraph_mode(): - return core.ops.gather(x, index, axis_tensor) + axis = axis.item() if isinstance(axis, paddle.Tensor) else axis + return core.ops.gather(x, index, None, "axis", axis, "overwrite", False) check_variable_and_dtype( x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], 'gather') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') + if isinstance(axis, Variable): check_variable_and_dtype(axis, 'axis', ['int32', 'int64'], 'gather') - else: - check_type(axis, 'axis', (int), 'gather') helper = LayerHelper('gather', **locals()) dtype = helper.input_dtype('x') out = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type="gather", - inputs={"X": x, - "Index": index, - "Axis": axis_tensor}, - outputs={"Out": out}) + if not isinstance(axis, Variable): + helper.append_op( + type="gather", + inputs={"X": x, + "Index": index}, + attrs={'axis': axis, + 'overwrite': False}, + outputs={"Out": out}) + else: + helper.append_op( + type="gather", + inputs={"X": x, + "Index": index, + "Axis": axis}, + attrs={"overwrite": False}, + outputs={"Out": out}) + return out