diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 2a0d70f8eabd8c387231e61bfd9e07a865c0c052..92420f10ac5972b7924d83b43bb28234079e5072 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -29,9 +29,7 @@ namespace framework { namespace details { struct BroadcastOpHandle : public OpHandleBase { - const std::vector &local_scopes_; - const std::vector &places_; - + public: BroadcastOpHandle(const std::vector &local_scopes, const std::vector &places); @@ -41,10 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; - void WaitInputVarGenerated(const VarHandle &in_var); -}; + private: + const std::vector &local_scopes_; + const std::vector &places_; +}; } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/variable_visitor.cc b/paddle/fluid/framework/details/variable_visitor.cc index f5f62ed8c4483b433af37d3b7a6f6d2589677277..10bac0fae9504215fab11dd8cca7c278feaa4bda 100644 --- a/paddle/fluid/framework/details/variable_visitor.cc +++ b/paddle/fluid/framework/details/variable_visitor.cc @@ -18,22 +18,22 @@ namespace paddle { namespace framework { namespace details { template -static void VisitVariable(Variable* var, Func func) { +static void VisitVariable(Variable* var, Func* func) { if (var->IsType()) { - func(var->GetMutable()); + (*func)(var->GetMutable()); } else if (var->IsType()) { - func(var->GetMutable()); + (*func)(var->GetMutable()); } else { PADDLE_THROW("Not supported type %s", var->Type().name()); } } template -static void VisitVariable(const Variable& var, Func func) { +static void VisitVariable(const Variable& var, Func* func) { if (var.IsType()) { - func(var.Get()); + (*func)(var.Get()); } else if (var.IsType()) { - func(var.Get()); + (*func)(var.Get()); } else { PADDLE_THROW("Not supported type %s", var.Type().name()); } @@ -56,7 +56,7 @@ struct TensorVisitor { Tensor& VariableVisitor::GetMutableTensor(Variable* var) { TensorVisitor vistor; - VisitVariable(var, vistor); + VisitVariable(var, &vistor); return *vistor.result_; } @@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor { void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { ShareDimsAndLoDVisitor visitor{trg}; - VisitVariable(src, visitor); + VisitVariable(src, &visitor); } } // namespace details