Unverified · Commit 3fec7a6e · Authored by Wen Sun, committed by GitHub

fix: gloo compatible (#49084)

Parent d808f160
@@ -310,6 +310,16 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
   }
 };
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const AllreduceOptions& opts,
+    bool sync_op) {
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return AllReduce(in_wrapper, out_wrapper, opts, true);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
     std::vector<phi::DenseTensor>& inputs,
     std::vector<phi::DenseTensor>& outputs,
@@ -393,8 +403,8 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
 std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
     phi::DenseTensor* out_tensor,
     const phi::DenseTensor& in_tensor,
-    int64_t offset,  // for compatibility, no use now
-    int64_t numel,  // for compatibility, no use now
+    int64_t /*offset*/,
+    int64_t /*numel*/,
     bool sync_op) {
   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
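The hunks above add a single-tensor `AllReduce` overload that adapts the newer tensor-pointer API to the existing vector-based implementation: both tensors are wrapped in one-element vectors and the call is forwarded, with `sync_op` passed as `true` (presumably because the gloo backend runs its collectives synchronously, so the flag has no effect). Below is a minimal sketch of this forwarding pattern; `Tensor`, `Task`, `AllreduceOptions`, and `GlooGroup` are simplified stand-ins for the Paddle types, not the real classes:

```cpp
#include <memory>
#include <vector>

struct Tensor {};            // stand-in for phi::DenseTensor
struct Task {};              // stand-in for ProcessGroup::Task
struct AllreduceOptions {};  // stand-in for the Paddle options struct

struct GlooGroup {
  // Existing vector-based entry point that does the real work.
  std::shared_ptr<Task> AllReduce(std::vector<Tensor>& inputs,
                                  std::vector<Tensor>& outputs,
                                  const AllreduceOptions& opts,
                                  bool sync_op) {
    (void)inputs; (void)outputs; (void)opts; (void)sync_op;
    return std::make_shared<Task>();
  }

  // New single-tensor overload: wrap both sides in one-element
  // vectors and forward, always running synchronously.
  std::shared_ptr<Task> AllReduce(Tensor* out_tensor,
                                  const Tensor& in_tensor,
                                  const AllreduceOptions& opts,
                                  bool /*sync_op*/) {
    std::vector<Tensor> in_wrapper{in_tensor};
    std::vector<Tensor> out_wrapper{*out_tensor};
    return AllReduce(in_wrapper, out_wrapper, opts, /*sync_op=*/true);
  }
};
```

The matching declaration added to the header follows.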
@@ -116,8 +116,14 @@ class ProcessGroupGloo : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> AllGather(
       phi::DenseTensor* out_tensor,
       const phi::DenseTensor& in_tensor,
-      int64_t offset,  // for compatibility, no use now
-      int64_t numel,  // for compatibility, no use now
+      int64_t /*offset*/,  // for compatibility, no use now
+      int64_t /*numel*/,  // for compatibility, no use now
       bool sync_op) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllReduce(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const AllreduceOptions& opts,
+      bool sync_op) override;
+
   std::shared_ptr<ProcessGroup::Task> Broadcast(
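The header hunk also switches to the commented-out parameter-name idiom: `int64_t /*offset*/` keeps the parameter in the signature (so the override still matches the base-class declaration) while silencing unused-parameter warnings. A small self-contained illustration of the idiom, using hypothetical `Base`/`Derived` names rather than the Paddle classes:

```cpp
#include <cstdint>

struct Base {
  virtual ~Base() = default;
  // offset and numel are part of the interface contract.
  virtual void AllGather(int64_t offset, int64_t numel) = 0;
};

struct Derived : Base {
  // Commenting out the names keeps the signature intact for the
  // override while avoiding -Wunused-parameter warnings, since these
  // parameters exist only for interface compatibility.
  void AllGather(int64_t /*offset*/, int64_t /*numel*/) override {}
};
```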