未验证 提交 13cdaab6 编写于 作者: H houj04 提交者: GitHub

[XPU] bkcl_broadcast support int64_t (#53720)

上级 b150b168
...@@ -334,16 +334,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast( ...@@ -334,16 +334,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
<< ", root: " << root << ", numel: " << input.numel() << ", root: " << root << ", numel: " << input.numel()
<< ", dtype: " << input.type() << ", sync_op: " << sync_op << ", dtype: " << input.type() << ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream; << ", use_calc_stream: " << use_calc_stream;
int r = if (framework::TransToProtoVarType(input.dtype()) ==
bkcl_broadcast(comm, framework::proto::VarType::INT64) {
input.data(), // special for int64_t, send as int32_t with DOUBLE NUMEL
output->data(), int r = bkcl_broadcast(
input.numel(), comm,
platform::ToBKCLDataType( input.data(),
framework::TransToProtoVarType(input.type())), output->data(),
root, input.numel() * 2,
stream); platform::ToBKCLDataType(framework::proto::VarType::INT32),
return r; root,
stream);
return r;
} else {
int r =
bkcl_broadcast(comm,
input.data(),
output->data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
root,
stream);
return r;
}
}, },
CommType::BROADCAST, CommType::BROADCAST,
sync_op, sync_op,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册