未验证 提交 f0dab193 编写于 作者: J james 提交者: GitHub

nullptr bugfix for XPU pg mode (#49043)

* nullptr bugfix for XPU pg mode

Also a few kernels is added to xpu whitelist

* increase error msg length
上级 f2a8dd50
......@@ -154,7 +154,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
const auto& place = in_tensor.place();
const auto& key = GetKeyFromPlace(place);
if (!calc_event_) {
if (!calc_event_ ||
(place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end())) {
CreateBKCLEnvCache(place, key);
}
......@@ -170,6 +171,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Collective(
fn(out_tensor, in_tensor, comm_ctx->bkcl_context(), bkcl_stream);
if (!use_calc_stream) {
PADDLE_ENFORCE_NOT_NULL(
comm_ctx.get(), platform::errors::Fatal("comm context is nullptr."));
task->comm_event_->Record(*comm_ctx.get());
}
......@@ -369,6 +372,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective(
&out_tensors[0],
in_tensors[0],
......@@ -406,6 +413,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective(
&out_tensors[0],
in_tensors[0],
......@@ -442,6 +453,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective(
&out_tensors[0],
......@@ -481,6 +496,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
return Collective(
&out_tensors[0],
......@@ -518,6 +537,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
1,
platform::errors::InvalidArgument(
"BKCL only support single tensor collective communication."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in XPUPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInXPUPlace(out_tensors),
true,
......
......@@ -25,6 +25,7 @@ XPUOpMap& get_kl2_ops() {
{"abs", XPUKernelSet({phi::DataType::FLOAT32})},
{"abs_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
{"adadelta", XPUKernelSet({phi::DataType::FLOAT32})},
{"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......@@ -402,6 +403,13 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"reshape_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT32})},
{"resnet_unit",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"resnet_unit_grad",
......@@ -485,6 +493,14 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"squeeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
......
......@@ -378,7 +378,9 @@ PD_REGISTER_KERNEL(multiply,
ALL_LAYOUT,
phi::MultiplyKernel,
phi::dtype::float16,
float) {}
float,
int,
int64_t) {}
PD_REGISTER_KERNEL(subtract,
XPU,
ALL_LAYOUT,
......
......@@ -42,4 +42,6 @@ PD_REGISTER_KERNEL(multiply_raw,
ALL_LAYOUT,
phi::MultiplyRawKernel,
phi::dtype::float16,
float) {}
float,
int,
int64_t) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册