Unverified commit 8e291bf7, authored by jameszhang, committed by GitHub

Fix reduce func bug in process_group_bkcl (#49749)

* Fix reduce func bug in process_group_bkcl

Also catch up with a recent process_group PR that failed to add an XPU branch.
Note that reduce is still implemented via allreduce on XPU; this should be
fixed once the xccl lib is updated.

* Fix compile issue for non-XPU builds
Parent 27aec62b
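For context, here is a minimal standalone sketch (not Paddle code) of the bug this commit fixes: `phi::DenseTensor output_t(*output)` appears to create a shallow copy that shares its allocation with `output`, so the allreduce used to emulate reduce overwrote every rank's output tensor, while copying into a separate buffer first (as the patch now does with `paddle::framework::TensorCopy`) leaves non-destination ranks untouched. `ToyTensor`, `ReduceViaAllReduce`, and the literal values below are hypothetical stand-ins used only for illustration.

// Toy model of shallow-copy vs. deep-copy semantics in a reduce-via-allreduce path.
#include <cassert>
#include <memory>
#include <vector>

struct ToyTensor {
  std::shared_ptr<std::vector<float>> buf;
  explicit ToyTensor(std::vector<float> v)
      : buf(std::make_shared<std::vector<float>>(std::move(v))) {}
  // Copying shares the buffer, modeling `phi::DenseTensor output_t(*output)`.
  ToyTensor(const ToyTensor&) = default;
  // Deep copy, modeling `paddle::framework::TensorCopy`.
  ToyTensor DeepCopy() const { return ToyTensor(*buf); }
};

// Reduce emulated with an allreduce: every rank computes the global sum into
// `scratch`, but only the destination rank should see it in `output`.
void ReduceViaAllReduce(ToyTensor* output,
                        ToyTensor scratch,
                        bool is_dst_rank,
                        const std::vector<float>& global_sum) {
  *scratch.buf = global_sum;                     // allreduce writes into scratch
  if (is_dst_rank) *output->buf = *scratch.buf;  // copy back only on dst rank
}

int main() {
  const std::vector<float> global_sum = {3.f, 3.f};

  // Buggy path: scratch aliases output, so a non-destination rank's tensor
  // is clobbered by the allreduce.
  ToyTensor bad({1.f, 1.f});
  ReduceViaAllReduce(&bad, bad /* shallow copy */, /*is_dst_rank=*/false,
                     global_sum);
  assert((*bad.buf)[0] == 3.f);  // overwritten, which reduce must not do

  // Fixed path: scratch is a separate copy, so the non-destination rank's
  // tensor keeps its original value, matching the new unit-test assertion.
  ToyTensor good({1.f, 1.f});
  ReduceViaAllReduce(&good, good.DeepCopy(), /*is_dst_rank=*/false, global_sum);
  assert((*good.buf)[0] == 1.f);
  return 0;
}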
@@ -277,7 +277,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        phi::DenseTensor output_t(*output);
+        phi::DenseTensor output_t;
+        paddle::framework::TensorCopy(*output, platform::XPUPlace(), &output_t);
         const auto& place = input.place();
         auto* calc_ctx = static_cast<phi::XPUContext*>(
             platform::DeviceContextPool::Instance().Get(place));
...
@@ -129,6 +129,27 @@ void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
   }
 }

+#ifdef PADDLE_WITH_XPU
+template <>
+void ConcatDenseTensorWithType(const phi::XPUContext &dev_ctx,
+                               const std::vector<phi::DenseTensor> &t_list,
+                               phi::DenseTensor *p_out,
+                               phi::DataType type) {
+  switch (type) {
+    case phi::DataType::FLOAT16:
+      ConcatDenseTensor<phi::XPUContext, phi::dtype::float16>()(
+          dev_ctx, t_list, p_out);
+      break;
+    case phi::DataType::FLOAT32:
+      ConcatDenseTensor<phi::XPUContext, float>()(dev_ctx, t_list, p_out);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Data type (%s) is not supported when it concats tensors.", type));
+  }
+}
+#endif
+
 template <typename DeviceContext>
 void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
                               const phi::DenseTensor &t_in,
@@ -170,6 +191,27 @@ void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
   }
 }

+#ifdef PADDLE_WITH_XPU
+template <>
+void SplitDenseTensorWithType(const phi::XPUContext &dev_ctx,
+                              const phi::DenseTensor &t_in,
+                              std::vector<phi::DenseTensor *> *p_list,
+                              phi::DataType type) {
+  switch (type) {
+    case phi::DataType::FLOAT16:
+      SplitDenseTensor<phi::XPUContext, phi::dtype::float16>()(
+          dev_ctx, t_in, p_list);
+      break;
+    case phi::DataType::FLOAT32:
+      SplitDenseTensor<phi::XPUContext, float>()(dev_ctx, t_in, p_list);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Data type (%s) is not supported when it splits tensors.", type));
+  }
+}
+#endif
+
 void ConcatTensor(const phi::DeviceContext &dev_ctx,
                   const std::vector<phi::DenseTensor> &tensor_list,
                   const experimental::Tensor *tensor) {
@@ -187,6 +229,17 @@ void ConcatTensor(const phi::DeviceContext &dev_ctx,
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Paddle can't concat tensor since it's not support GPU, please "
         "recompile or reinstall Paddle with GPU support."));
+#endif
+  } else if (platform::is_xpu_place(place)) {
+#ifdef PADDLE_WITH_XPU
+    ConcatDenseTensorWithType(static_cast<const phi::XPUContext &>(dev_ctx),
+                              tensor_list,
+                              dense_tensor,
+                              tensor->dtype());
+#else
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Paddle can't concat tensor since it's not support XPU, please "
+        "recompile or reinstall Paddle with XPU support."));
 #endif
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
@@ -233,6 +286,17 @@ void SplitTensor(const phi::DeviceContext &dev_ctx,
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Paddle can't split tensor since it's not support GPU, please "
         "recompile or reinstall Paddle with GPU support."));
+#endif
+  } else if (platform::is_xpu_place(place)) {
+#ifdef PADDLE_WITH_XPU
+    SplitDenseTensorWithType(static_cast<const phi::XPUContext &>(dev_ctx),
+                             tensor,
+                             &dense_list,
+                             tensor.dtype());
+#else
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Paddle can't split tensor since it's not compiled with XPU, "
+        "please recompile or reinstall Paddle with XPU support."));
 #endif
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
...
@@ -163,6 +163,7 @@ class TestProcessGroupFp32(unittest.TestCase):
         y = np.random.random(self.shape).astype(self.dtype)
         tensor_x = paddle.to_tensor(x)
         tensor_y = paddle.to_tensor(y)
+        old_tensor_y = paddle.to_tensor(y)
         sum_result = tensor_x + tensor_y
         if pg.rank() == 0:
             task = dist.reduce(tensor_x, 0, sync_op=True)
@@ -174,6 +175,7 @@ class TestProcessGroupFp32(unittest.TestCase):
             paddle.device.xpu.synchronize()
         if pg.rank() == 0:
             assert np.array_equal(tensor_x, sum_result)
+            assert np.array_equal(tensor_y, old_tensor_y)
         sys.stdout.write("rank {}: test reduce sum api ok\n".format(pg.rank()))
...