diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 1d9ce17c504d24c35b05b870fa078e282d201286..21b0687f63fca8396b49ecfd9f1d593d16e0f21a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { bool UseGPU() const; - bool NeedCollectiveForGrad(const std::string &grad_name, - std::vector ops) const; + virtual bool NeedCollectiveForGrad(const std::string &grad_name, + std::vector ops) const; bool IsScaleLossOp(ir::Node *node) const; @@ -117,7 +117,10 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const override {} - bool NeedCollectiveOps() const override { return false; } + bool NeedCollectiveForGrad(const std::string &grad_name, + std::vector ops) const { + return false; + } bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override { if (node->Op()->Type() == "recv") {