未验证 提交 f0dab193 编写于 作者: J james 提交者: GitHub

nullptr bugfix for XPU pg mode (#49043)

* nullptr bugfix for XPU pg mode

Also, a few kernels are added to the XPU whitelist

* increase error msg length
上级 f2a8dd50
...@@ -154,7 +154,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective( ...@@ -154,7 +154,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
const auto& place = in_tensor.place(); const auto& place = in_tensor.place();
const auto& key = GetKeyFromPlace(place); const auto& key = GetKeyFromPlace(place);
if (!calc_event_) { if (!calc_event_ ||
(place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end())) {
CreateBKCLEnvCache(place, key); CreateBKCLEnvCache(place, key);
} }
...@@ -170,6 +171,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective( ...@@ -170,6 +171,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
fn(out_tensor, in_tensor, comm_ctx->bkcl_context(), bkcl_stream); fn(out_tensor, in_tensor, comm_ctx->bkcl_context(), bkcl_stream);
if (!use_calc_stream) { if (!use_calc_stream) {
PADDLE_ENFORCE_NOT_NULL(
comm_ctx.get(), platform::errors::Fatal("comm context is nullptr."));
task->comm_event_->Record(*comm_ctx.get()); task->comm_event_->Record(*comm_ctx.get());
} }
...@@ -369,6 +372,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce( ...@@ -369,6 +372,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
1, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication.")); "BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective( return Collective(
&out_tensors[0], &out_tensors[0],
in_tensors[0], in_tensors[0],
...@@ -406,6 +413,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce( ...@@ -406,6 +413,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
1, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication.")); "BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective( return Collective(
&out_tensors[0], &out_tensors[0],
in_tensors[0], in_tensors[0],
...@@ -442,6 +453,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast( ...@@ -442,6 +453,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
1, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication.")); "BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective( return Collective(
&out_tensors[0], &out_tensors[0],
...@@ -481,6 +496,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast( ...@@ -481,6 +496,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
1, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication.")); "BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective( return Collective(
&out_tensors[0], &out_tensors[0],
...@@ -518,6 +537,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather( ...@@ -518,6 +537,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
1, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication.")); "BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(out_tensors), CheckTensorsInXPUPlace(out_tensors),
true, true,
......
...@@ -25,6 +25,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -25,6 +25,7 @@ XPUOpMap& get_kl2_ops() {
{"abs", XPUKernelSet({phi::DataType::FLOAT32})}, {"abs", XPUKernelSet({phi::DataType::FLOAT32})},
{"abs_grad", {"abs_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
{"adadelta", XPUKernelSet({phi::DataType::FLOAT32})}, {"adadelta", XPUKernelSet({phi::DataType::FLOAT32})},
{"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
...@@ -402,6 +403,13 @@ XPUOpMap& get_kl2_ops() { ...@@ -402,6 +403,13 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::BOOL, phi::DataType::BOOL,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"reshape_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"resnet_unit", {"resnet_unit",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"resnet_unit_grad", {"resnet_unit_grad",
...@@ -485,6 +493,14 @@ XPUOpMap& get_kl2_ops() { ...@@ -485,6 +493,14 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8, phi::DataType::INT8,
phi::DataType::UINT8, phi::DataType::UINT8,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"squeeze_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze_grad", {"squeeze_grad",
XPUKernelSet({phi::DataType::FLOAT64, XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64, phi::DataType::INT64,
......
...@@ -378,7 +378,9 @@ PD_REGISTER_KERNEL(multiply, ...@@ -378,7 +378,9 @@ PD_REGISTER_KERNEL(multiply,
ALL_LAYOUT, ALL_LAYOUT,
phi::MultiplyKernel, phi::MultiplyKernel,
phi::dtype::float16, phi::dtype::float16,
float) {} float,
int,
int64_t) {}
PD_REGISTER_KERNEL(subtract, PD_REGISTER_KERNEL(subtract,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
......
...@@ -42,4 +42,6 @@ PD_REGISTER_KERNEL(multiply_raw, ...@@ -42,4 +42,6 @@ PD_REGISTER_KERNEL(multiply_raw,
ALL_LAYOUT, ALL_LAYOUT,
phi::MultiplyRawKernel, phi::MultiplyRawKernel,
phi::dtype::float16, phi::dtype::float16,
float) {} float,
int,
int64_t) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册