Unverified · Commit a4e841e0 · authored by ShenLiang · committed by GitHub

[cherry-pick] fix gather bug && fix hang of new_group (#33553)

* Fix gather infer shape using axis (#33413)

* fix gather shape bug

* fix None

* fix topo

* Fix hang of hybrid parallel in new_group (#33141)

* fix hang of hybrid parallel

* fix new_group for hang problem

* fix hang
Parent 036f81fc
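The first patch (#33413) fixes GatherOp's shape inference: the op used to compute the output shape as if the gather always ran along axis 0, so any other value of the `axis` attribute produced a wrong output shape. A minimal sketch of the shape rule the fix implements (the helper is ours for illustration, not Paddle code):

    # out.shape = x.shape[:axis] + [len(index)] + x.shape[axis+1:]
    def gather_out_shape(x_shape, index_len, axis):
        return list(x_shape[:axis]) + [index_len] + list(x_shape[axis + 1:])

    assert gather_out_shape([3, 4, 5], 2, 0) == [2, 4, 5]  # old assumption: axis 0
    assert gather_out_shape([3, 4, 5], 2, 1) == [3, 2, 5]  # now correct for axis 1

The second patch (#33141) restructures `new_group` so that every rank, member of the new group or not, leaves the call through the same global `all_reduce`, removing the cross-creation deadlock described below.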
@@ -202,12 +202,11 @@ __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
   }
 }
 
-template <typename T, typename U, typename V>
+template <typename T, typename U>
 void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
-                          const Tensor* axis, Tensor* out,
+                          const int axis, Tensor* out,
                           const paddle::platform::Place& place,
                           const framework::ExecutionContext& ctx) {
-  int axis_size = axis->numel();
   int index_size = index->numel();
   int input_size = input->numel();
   auto input_dim = input->dims();
@@ -215,12 +214,8 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
   auto* index_data = index->data<U>();
 
   if (input->numel() == 0) return;
-  PADDLE_ENFORCE_EQ(axis_size, 1,
-                    platform::errors::InvalidArgument(
-                        "Axis size should be 1, but received %d", axis_size));
-  Tensor cpu_axis;
-  framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis);
-  int axis_index = cpu_axis.data<V>()[0];
+  int axis_index = axis;
   int index_dim_size = input_dim[axis_index];
 
   int inner_dim_size = 1;
@@ -251,26 +246,19 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
                       index_size, index_dim_size, out_size);
 }
 
-template <typename T, typename U, typename V>
+template <typename T, typename U>
 void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
-                              const Tensor* axis, Tensor* out,
+                              const int axis, Tensor* out,
                               const paddle::platform::Place& place,
                               const framework::ExecutionContext& ctx) {
   auto* index_data = index->data<U>();
-  int axis_size = axis->numel();
   int index_size = index->numel();
   int input_size = input->numel();
   auto input_dim = input->dims();
   auto* input_data = input->data<T>();
 
   if (input->numel() == 0) return;
-  PADDLE_ENFORCE_EQ(axis_size, 1,
-                    platform::errors::InvalidArgument(
-                        "Axis size should be 1, but received %d", axis_size));
-  Tensor cpu_axis;
-  framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis);
-  int axis_index = cpu_axis.data<V>()[0];
+  int axis_index = axis;
   int input_index_dim_size = input_dim[axis_index];
 
   int inner_dim_size = 1;
...
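For orientation: the rewritten GatherV2 functions view the input as an (outer, axis_dim, inner) block and pick rows of the middle dimension; `inner_dim_size` and the outer size are exactly those products. A NumPy sketch of the same computation (our reference only, assuming an integer `axis`):

    import numpy as np

    def gather_v2_ref(x, index, axis):
        # Flatten to (outer, mid, inner), gather the middle axis, restore shape.
        outer = int(np.prod(x.shape[:axis]))
        inner = int(np.prod(x.shape[axis + 1:]))
        mid = x.shape[axis]
        v = x.reshape(outer, mid, inner)[:, index, :]
        return v.reshape(x.shape[:axis] + (len(index),) + x.shape[axis + 1:])

    x = np.arange(24).reshape(2, 3, 4)
    idx = np.array([2, 0])
    assert np.array_equal(gather_v2_ref(x, idx, 1), np.take(x, idx, axis=1))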
@@ -126,24 +126,17 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
   }
 }
 
-template <typename T, typename U, typename V>
-void GatherV2Function(const Tensor* input, const Tensor* index,
-                      const Tensor* axis, Tensor* out,
-                      const paddle::platform::Place& place) {
-  auto* axis_data = axis->data<V>();
+template <typename T, typename U>
+void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
+                      Tensor* out, const paddle::platform::Place& place) {
   auto* index_data = index->data<U>();
-  int axis_size = axis->numel();
   int index_size = index->numel();
   int input_size = input->numel();
   auto input_dim = input->dims();
   auto* input_data = input->data<T>();
 
   if (input->numel() == 0) return;
-  PADDLE_ENFORCE_EQ(axis_size, 1,
-                    platform::errors::InvalidArgument(
-                        "Axis size should be 1, but received %d", axis_size));
-  int axis_index = axis_data[0];
+  int axis_index = axis;
   int input_index_dim_size = input_dim[axis_index];
 
   for (int i = 0; i < index_size; i++) {
@@ -186,22 +179,17 @@ void GatherV2Function(const Tensor* input, const Tensor* index,
   }
 }
 
-template <typename T, typename U, typename V>
+template <typename T, typename U>
 void GatherV2GradFunction(const Tensor* input, const Tensor* index,
-                          const Tensor* axis, Tensor* out,
+                          const int axis, Tensor* out,
                           const paddle::platform::Place& place) {
-  auto* axis_data = axis->data<V>();
   auto* index_data = index->data<U>();
-  int axis_size = axis->numel();
   auto input_dim = input->dims();
   auto* input_data = input->data<T>();
 
   if (input->numel() == 0) return;
-  PADDLE_ENFORCE_EQ(axis_size, 1,
-                    platform::errors::InvalidArgument(
-                        "Axis size should be 1, but received %d", axis_size));
-  int axis_index = axis_data[0];
+  int axis_index = axis;
   int input_index_dim_size = input_dim[axis_index];
 
   int inner_dim_size = 1;
...
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+
 namespace paddle {
 namespace operators {
@@ -52,11 +53,29 @@ class GatherOp : public framework::OperatorWithKernel {
                             index_dims.size()));
     }
 
-    int batch_size = ctx->GetInputDim("Index")[0];
-    framework::DDim output_dims(ctx->GetInputDim("X"));
-    output_dims[0] = batch_size;
-    ctx->SetOutputDim("Out", output_dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
+    auto axis = ctx->Attrs().Get<int>("axis");
+    auto input_dim = ctx->GetInputDim("X");
+    if (ctx->HasInput("Axis") || axis == 0) {
+      // if HasInput("Axis"), we can not obtain correct shape of output
+      int batch_size = index_dims[0];
+      framework::DDim output_dims(input_dim);
+      output_dims[0] = batch_size;
+      ctx->SetOutputDim("Out", output_dims);
+      ctx->ShareLoD("X", /*->*/ "Out");
+    } else {
+      int index_size = index_dims[0];
+      std::vector<int> out_dim_vec;
+      for (int i = 0; i < axis; i++) {
+        out_dim_vec.push_back(input_dim[i]);
+      }
+      out_dim_vec.push_back(index_size);
+      for (int i = axis + 1; i < input_dim.size(); i++) {
+        out_dim_vec.push_back(input_dim[i]);
+      }
+      auto output_dims = framework::make_ddim(out_dim_vec);
+      ctx->SetOutputDim("Out", output_dims);
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
   }
 
  protected:
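Note the asymmetry above: when `Axis` arrives as a tensor input its value is unknown at shape-inference time, so the op keeps the old axis-0 placeholder shape (the in-diff comment says as much) and the kernel produces the real shape at run time; a plain `axis` attribute is folded into the output dims directly. A worked trace of the `out_dim_vec` loops (our sketch):

    # input_dim = [6, 7, 8], index_size = 3, axis = 1
    #   first loop (i < axis)   -> out_dim_vec = [6]
    #   push index_size         -> out_dim_vec = [6, 3]
    #   second loop (i > axis)  -> out_dim_vec = [6, 3, 8]
    input_dim, index_size, axis = [6, 7, 8], 3, 1
    out_dim_vec = input_dim[:axis] + [index_size] + input_dim[axis + 1:]
    assert out_dim_vec == [6, 3, 8]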
@@ -120,6 +139,10 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
              "If true, update the grad using the overwrite mode in same index,"
              "If false, using the accumulate mode in same index.")
         .SetDefault(true);
+    AddAttr<int>(
+        "axis",
+        "The Tensor which contains the axis that we do gather operation.")
+        .SetDefault(0);
     AddComment(R"DOC(
 Gather Operator.
...
@@ -31,47 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
     auto *index = ctx.Input<Tensor>("Index");
     auto *output = ctx.Output<Tensor>("Out");
 
+    int axis = ctx.Attr<int>("axis");
+    // get axis from tensor
     if (ctx.HasInput("Axis")) {
-      const Tensor *axis = ctx.Input<Tensor>("Axis");
-      const auto &index_type = index->type();
-      const auto &axis_type = axis->type();
-      auto place = ctx.GetPlace();
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2CUDAFunction<T, int32_t, int32_t>(x, index, axis, output, place,
-                                                  ctx);
-      }
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2CUDAFunction<T, int32_t, int64_t>(x, index, axis, output, place,
-                                                  ctx);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2CUDAFunction<T, int64_t, int32_t>(x, index, axis, output, place,
-                                                  ctx);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2CUDAFunction<T, int64_t, int64_t>(x, index, axis, output, place,
-                                                  ctx);
-      }
+      Tensor cpu_axis;
+      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
+      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
+      const auto &axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
+      }
+    }
+    const auto &place = ctx.GetPlace();
+    const auto &index_type = index->type();
+    if (axis != 0) {
+      if (index_type == framework::proto::VarType::INT32) {
+        GatherV2CUDAFunction<T, int32_t>(x, index, axis, output, place, ctx);
+      } else if (index_type == framework::proto::VarType::INT64) {
+        GatherV2CUDAFunction<T, int64_t>(x, index, axis, output, place, ctx);
+      }
       return;
     }
 
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUGather<T, int>(ctx.device_context(), *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
@@ -91,30 +77,27 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
+    int axis = ctx.Attr<int>("axis");
     if (ctx.HasInput("Axis")) {
-      const Tensor *axis = ctx.Input<Tensor>("Axis");
-      const auto &index_type = index->type();
-      const auto &axis_type = axis->type();
-      auto place = ctx.GetPlace();
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2GradCUDAFunction<T, int32_t, int32_t>(dO, index, axis, dX,
-                                                      place, ctx);
-      }
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2GradCUDAFunction<T, int32_t, int64_t>(dO, index, axis, dX,
-                                                      place, ctx);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2GradCUDAFunction<T, int64_t, int32_t>(dO, index, axis, dX,
-                                                      place, ctx);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2GradCUDAFunction<T, int64_t, int64_t>(dO, index, axis, dX,
-                                                      place, ctx);
-      }
+      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
+      Tensor cpu_axis;
+      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
+      const auto &axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
+      }
+    }
+
+    const auto &index_type = index->type();
+    if (axis != 0) {
+      if (index_type == framework::proto::VarType::INT32) {
+        GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
+                                             ctx.GetPlace(), ctx);
+      } else if (index_type == framework::proto::VarType::INT64) {
+        GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
+                                             ctx.GetPlace(), ctx);
+      }
       return;
     }
@@ -125,19 +108,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
                     .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
                                ctx.Attr<bool>("overwrite"));
...
@@ -35,45 +35,30 @@ class GatherOpKernel : public framework::OpKernel<T> {
     auto *index = ctx.Input<Tensor>("Index");
     auto *output = ctx.Output<Tensor>("Out");
 
+    int axis = ctx.Attr<int>("axis");
+    // get axis from tensor
     if (ctx.HasInput("Axis")) {
-      const Tensor *axis = ctx.Input<Tensor>("Axis");
-      const auto &index_type = index->type();
-      const auto &axis_type = axis->type();
-      auto place = ctx.GetPlace();
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2Function<T, int32_t, int32_t>(x, index, axis, output, place);
-      }
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2Function<T, int32_t, int64_t>(x, index, axis, output, place);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2Function<T, int64_t, int32_t>(x, index, axis, output, place);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2Function<T, int64_t, int64_t>(x, index, axis, output, place);
-      }
+      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
+      const auto &axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
+      }
+    }
+    const auto &place = ctx.GetPlace();
+    const auto &index_type = index->type();
+    if (axis != 0) {
+      if (index_type == framework::proto::VarType::INT32) {
+        GatherV2Function<T, int32_t>(x, index, axis, output, place);
+      } else if (index_type == framework::proto::VarType::INT64) {
+        GatherV2Function<T, int64_t>(x, index, axis, output, place);
+      }
       return;
     }
 
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       CPUGather<T, int>(ctx.device_context(), *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
@@ -94,26 +79,23 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
+    int axis = ctx.Attr<int>("axis");
     if (ctx.HasInput("Axis")) {
-      const Tensor *axis = ctx.Input<Tensor>("Axis");
-      const auto &index_type = index->type();
-      const auto &axis_type = axis->type();
-      auto place = ctx.GetPlace();
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2GradFunction<T, int32_t, int32_t>(dO, index, axis, dX, place);
-      }
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2GradFunction<T, int32_t, int64_t>(dO, index, axis, dX, place);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2GradFunction<T, int64_t, int32_t>(dO, index, axis, dX, place);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2GradFunction<T, int64_t, int64_t>(dO, index, axis, dX, place);
-      }
+      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
+      const auto &axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
+      }
+    }
+    const auto &index_type = index->type();
+
+    if (axis != 0) {
+      if (index_type == framework::proto::VarType::INT32) {
+        GatherV2GradFunction<T, int32_t>(dO, index, axis, dX, ctx.GetPlace());
+      } else if (index_type == framework::proto::VarType::INT64) {
+        GatherV2GradFunction<T, int64_t>(dO, index, axis, dX, ctx.GetPlace());
+      }
       return;
     }
@@ -126,18 +108,6 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     if (dO->numel() == 0) return;
     bool overwrite = ctx.Attr<bool>("overwrite");
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       if (overwrite) {
         ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
...
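The gradient kernels mirror the forward pass: the gradient of gather is a scatter of `dO` back along the same axis, accumulating duplicates when `overwrite` is false. A NumPy reference for that semantics (ours, not the Paddle implementation):

    import numpy as np

    def gather_grad_ref(d_out, index, x_shape, axis, overwrite=False):
        # Scatter each slice of d_out back to its source position along `axis`.
        d_x = np.zeros(x_shape, dtype=d_out.dtype)
        for k, idx in enumerate(index):
            dst = [slice(None)] * len(x_shape)
            dst[axis] = idx
            if overwrite:
                d_x[tuple(dst)] = np.take(d_out, k, axis=axis)
            else:
                d_x[tuple(dst)] += np.take(d_out, k, axis=axis)
        return d_x

    d_x = gather_grad_ref(np.ones((2, 2, 4)), [2, 0], (2, 3, 4), axis=1)
    assert d_x.shape == (2, 3, 4) and d_x.sum() == 16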
@@ -238,31 +238,39 @@ def new_group(ranks=None, backend=None):
     if global_rank not in ranks:
         gp = Group(-1, -1, ring_id, ranks)
         _group_map[ring_id] = gp
-        return gp
-
-    ranks = sorted(ranks)
-    group_rank = ranks.index(global_rank)
-    group_size = len(ranks)
-    gp = Group(group_rank, group_size, ring_id, ranks)
-    _group_map[ring_id] = gp
-
-    if group_size < 2:
-        return gp
-
-    strategy = core.ParallelStrategy()
-    strategy.nranks = group_size
-    strategy.local_rank = group_rank
-    strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks]
-    strategy.current_endpoint = genv.current_endpoint
-    strategy.nrings = 1
-
-    if core.is_compiled_with_cuda():
-        place = core.CUDAPlace(genv.device_id)
-        core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
     else:
-        assert False, ("no cuda device found")
-    # need to barrier to construct group
-    barrier(gp)
+        ranks = sorted(ranks)
+        group_rank = ranks.index(global_rank)
+        group_size = len(ranks)
+        gp = Group(group_rank, group_size, ring_id, ranks)
+        _group_map[ring_id] = gp
+        if group_size >= 2:
+            strategy = core.ParallelStrategy()
+            strategy.nranks = group_size
+            strategy.local_rank = group_rank
+            strategy.trainer_endpoints = [
+                genv.trainer_endpoints[i] for i in ranks
+            ]
+            strategy.current_endpoint = genv.current_endpoint
+            strategy.nrings = 1
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(genv.device_id)
+                core.NCCLParallelContext(strategy,
+                                         place).init_with_ring_id(ring_id)
+            else:
+                assert False, ("no cuda device found")
+        else:
+            return gp
+
+    # TODO(shenliang03): This is a temporary solution to solve the problem of
+    # hang caused by cross-creation of new_group
+    tmp = paddle.to_tensor(
+        [1], dtype="int32") if in_dygraph_mode() else fill_constant(
+            [0], dtype="int32", value="1")
+    paddle.distributed.all_reduce(tmp, use_calc_stream=True)
+    paddle.distributed.wait(tmp)
+
     return gp
...
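Why this fixes the hang: previously a rank outside `ranks` returned as soon as the Group was registered, while member ranks went on to initialize NCCL and call `barrier(gp)`; when different ranks were busy creating different groups, collectives could pair up on mismatched communicators and deadlock. Now every rank finishes `new_group` with an `all_reduce` on the default global communicator, which acts as a full synchronization point. A usage sketch (our example; assumes 4 trainers launched with `paddle.distributed.launch`):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    # Every rank must execute every new_group call, in the same order,
    # even ranks that are not members; the trailing global all_reduce
    # inside new_group keeps all ranks in lockstep.
    g01 = dist.new_group(ranks=[0, 1])
    g23 = dist.new_group(ranks=[2, 3])

    x = paddle.to_tensor([dist.get_rank()], dtype="int32")
    if dist.get_rank() in (0, 1):
        dist.all_reduce(x, group=g01)  # sums over ranks 0 and 1 only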
@@ -182,6 +182,7 @@ class TestGatherOp4(TestGatherOp1):
         self.index_type = "int64"
         self.axis = [0]
         self.axis_type = "int32"
+        self.attrs = {'overwrite': False}
 
 
 class API_TestGather(unittest.TestCase):
...
@@ -862,34 +862,39 @@ def gather(x, index, axis=None, name=None):
     """
     if axis is None:
         axis = 0
-    axis_tensor = axis
-    if not isinstance(axis, Variable) and axis == 0:
-        return paddle.fluid.layers.gather(input=x, index=index, overwrite=False)
-    if not isinstance(axis, Variable):
-        with device_guard("cpu"):
-            axis_tensor = fill_constant(
-                shape=[1], dtype='int64', value=axis, force_cpu=True)
+
     if in_dygraph_mode():
-        return core.ops.gather(x, index, axis_tensor)
+        axis = axis.item() if isinstance(axis, paddle.Tensor) else axis
+        return core.ops.gather(x, index, None, "axis", axis, "overwrite", False)
 
     check_variable_and_dtype(
         x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
         'gather')
     check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather')
 
     if isinstance(axis, Variable):
         check_variable_and_dtype(axis, 'axis', ['int32', 'int64'], 'gather')
-    else:
-        check_type(axis, 'axis', (int), 'gather')
 
     helper = LayerHelper('gather', **locals())
     dtype = helper.input_dtype('x')
     out = helper.create_variable_for_type_inference(dtype)
-    helper.append_op(
-        type="gather",
-        inputs={"X": x,
-                "Index": index,
-                "Axis": axis_tensor},
-        outputs={"Out": out})
+    if not isinstance(axis, Variable):
+        helper.append_op(
+            type="gather",
+            inputs={"X": x,
+                    "Index": index},
+            attrs={'axis': axis,
+                   'overwrite': False},
+            outputs={"Out": out})
+    else:
+        helper.append_op(
+            type="gather",
+            inputs={"X": x,
+                    "Index": index,
+                    "Axis": axis},
+            attrs={"overwrite": False},
+            outputs={"Out": out})
+
     return out
...
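With `axis` now an attribute on the op, the Python API takes the fast path for a plain int axis and only falls back to the `Axis` tensor input when a Variable is passed. A quick dygraph example of the corrected shapes (our example):

    import paddle

    x = paddle.arange(12, dtype="float32").reshape([3, 4])
    index = paddle.to_tensor([2, 0], dtype="int64")

    out0 = paddle.gather(x, index)          # axis defaults to 0 -> shape [2, 4]
    out1 = paddle.gather(x, index, axis=1)  # gather columns     -> shape [3, 2]
    print(out0.shape, out1.shape)  # [2, 4] [3, 2]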