未验证 提交 8e291bf7 编写于 作者: J jameszhang 提交者: GitHub

Fix reduce func bug in process_group_bkcl (#49749)

* Fix reduce func bug in process_group_bkcl

Also catch up with a recent process_group PR that failed to add XPU branch.
Note that reduce is still implemented via allreduce for XPU; this should be
fixed once the xccl library adds a native reduce primitive.

* fix compile issue for non-XPU
上级 27aec62b
......@@ -277,7 +277,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
phi::DenseTensor output_t(*output);
phi::DenseTensor output_t;
paddle::framework::TensorCopy(*output, platform::XPUPlace(), &output_t);
const auto& place = input.place();
auto* calc_ctx = static_cast<phi::XPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
......
......@@ -129,6 +129,27 @@ void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
}
}
#ifdef PADDLE_WITH_XPU
// XPU specialization: concatenate a list of dense tensors on an XPU device,
// dispatching to the typed ConcatDenseTensor functor. Only float16 and
// float32 are wired up for XPU at the moment; any other dtype is rejected.
template <>
void ConcatDenseTensorWithType(const phi::XPUContext &dev_ctx,
                               const std::vector<phi::DenseTensor> &t_list,
                               phi::DenseTensor *p_out,
                               phi::DataType type) {
  if (type == phi::DataType::FLOAT16) {
    ConcatDenseTensor<phi::XPUContext, phi::dtype::float16>()(
        dev_ctx, t_list, p_out);
    return;
  }
  if (type == phi::DataType::FLOAT32) {
    ConcatDenseTensor<phi::XPUContext, float>()(dev_ctx, t_list, p_out);
    return;
  }
  // Unsupported dtype on XPU — surface a clear error instead of silently
  // producing garbage.
  PADDLE_THROW(platform::errors::Unimplemented(
      "Data type (%s) is not supported when it concats tensors.", type));
}
#endif
template <typename DeviceContext>
void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
const phi::DenseTensor &t_in,
......@@ -170,6 +191,27 @@ void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
}
}
#ifdef PADDLE_WITH_XPU
// XPU specialization: split one dense tensor into the pre-allocated output
// tensors in *p_list via the typed SplitDenseTensor functor. Mirrors the
// XPU concat specialization: only float16 and float32 are supported.
template <>
void SplitDenseTensorWithType(const phi::XPUContext &dev_ctx,
                              const phi::DenseTensor &t_in,
                              std::vector<phi::DenseTensor *> *p_list,
                              phi::DataType type) {
  if (type == phi::DataType::FLOAT16) {
    SplitDenseTensor<phi::XPUContext, phi::dtype::float16>()(
        dev_ctx, t_in, p_list);
    return;
  }
  if (type == phi::DataType::FLOAT32) {
    SplitDenseTensor<phi::XPUContext, float>()(dev_ctx, t_in, p_list);
    return;
  }
  // Any other dtype is not implemented for XPU — fail loudly.
  PADDLE_THROW(platform::errors::Unimplemented(
      "Data type (%s) is not supported when it splits tensors.", type));
}
#endif
void ConcatTensor(const phi::DeviceContext &dev_ctx,
const std::vector<phi::DenseTensor> &tensor_list,
const experimental::Tensor *tensor) {
......@@ -187,6 +229,17 @@ void ConcatTensor(const phi::DeviceContext &dev_ctx,
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't concat tensor since it's not support GPU, please "
"recompile or reinstall Paddle with GPU support."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
ConcatDenseTensorWithType(static_cast<const phi::XPUContext &>(dev_ctx),
tensor_list,
dense_tensor,
tensor->dtype());
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't concat tensor since it's not support XPU, please "
"recompile or reinstall Paddle with XPU support."));
#endif
} else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
......@@ -233,6 +286,17 @@ void SplitTensor(const phi::DeviceContext &dev_ctx,
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split tensor since it's not support GPU, please "
"recompile or reinstall Paddle with GPU support."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
SplitDenseTensorWithType(static_cast<const phi::XPUContext &>(dev_ctx),
tensor,
&dense_list,
tensor.dtype());
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split tensor since it's not compiled with XPU, "
"please recompile or reinstall Paddle with XPU support."));
#endif
} else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
......
......@@ -163,6 +163,7 @@ class TestProcessGroupFp32(unittest.TestCase):
y = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
old_tensor_y = paddle.to_tensor(y)
sum_result = tensor_x + tensor_y
if pg.rank() == 0:
task = dist.reduce(tensor_x, 0, sync_op=True)
......@@ -174,6 +175,7 @@ class TestProcessGroupFp32(unittest.TestCase):
paddle.device.xpu.synchronize()
if pg.rank() == 0:
assert np.array_equal(tensor_x, sum_result)
assert np.array_equal(tensor_y, old_tensor_y)
sys.stdout.write("rank {}: test reduce sum api ok\n".format(pg.rank()))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册