From 035712822c8fdf9d8b3a7ad19efbfe775b14aa7a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 18 Apr 2018 22:52:13 +0800 Subject: [PATCH] fix VisitVariable --- .../framework/details/broadcast_op_handle.h | 10 +++++----- .../fluid/framework/details/variable_visitor.cc | 16 ++++++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 2a0d70f8eab..92420f10ac5 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -29,9 +29,7 @@ namespace framework { namespace details { struct BroadcastOpHandle : public OpHandleBase { - const std::vector &local_scopes_; - const std::vector &places_; - + public: BroadcastOpHandle(const std::vector &local_scopes, const std::vector &places); @@ -41,10 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; - void WaitInputVarGenerated(const VarHandle &in_var); -}; + private: + const std::vector &local_scopes_; + const std::vector &places_; +}; } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/variable_visitor.cc b/paddle/fluid/framework/details/variable_visitor.cc index f5f62ed8c44..10bac0fae95 100644 --- a/paddle/fluid/framework/details/variable_visitor.cc +++ b/paddle/fluid/framework/details/variable_visitor.cc @@ -18,22 +18,22 @@ namespace paddle { namespace framework { namespace details { template -static void VisitVariable(Variable* var, Func func) { +static void VisitVariable(Variable* var, Func* func) { if (var->IsType()) { - func(var->GetMutable()); + (*func)(var->GetMutable()); } else if (var->IsType()) { - func(var->GetMutable()); + (*func)(var->GetMutable()); } else { PADDLE_THROW("Not supported type %s", var->Type().name()); } } template -static void VisitVariable(const Variable& var, Func func) { +static void VisitVariable(const Variable& var, Func* func) { if (var.IsType()) { - func(var.Get()); + (*func)(var.Get()); } else if (var.IsType()) { - func(var.Get()); + (*func)(var.Get()); } else { PADDLE_THROW("Not supported type %s", var.Type().name()); } @@ -56,7 +56,7 @@ struct TensorVisitor { Tensor& VariableVisitor::GetMutableTensor(Variable* var) { TensorVisitor vistor; - VisitVariable(var, vistor); + VisitVariable(var, &vistor); return *vistor.result_; } @@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor { void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { ShareDimsAndLoDVisitor visitor{trg}; - VisitVariable(src, visitor); + VisitVariable(src, &visitor); } } // namespace details -- GitLab