未验证 提交 13cdaab6 编写于 作者: H houj04 提交者: GitHub

[XPU] bkcl_broadcast support int64_t (#53720)

上级 b150b168
......@@ -334,16 +334,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
<< ", root: " << root << ", numel: " << input.numel()
<< ", dtype: " << input.type() << ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
int r =
bkcl_broadcast(comm,
input.data(),
output->data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
root,
stream);
return r;
if (framework::TransToProtoVarType(input.dtype()) ==
framework::proto::VarType::INT64) {
// special for int64_t, send as int32_t with DOUBLE NUMEL
int r = bkcl_broadcast(
comm,
input.data(),
output->data(),
input.numel() * 2,
platform::ToBKCLDataType(framework::proto::VarType::INT32),
root,
stream);
return r;
} else {
int r =
bkcl_broadcast(comm,
input.data(),
output->data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
root,
stream);
return r;
}
},
CommType::BROADCAST,
sync_op,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册