提交 03571282 编写于 作者: C chengduoZH

fix VisitVariable

上级 fbb75c6b
...@@ -29,9 +29,7 @@ namespace framework { ...@@ -29,9 +29,7 @@ namespace framework {
namespace details { namespace details {
struct BroadcastOpHandle : public OpHandleBase { struct BroadcastOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_; public:
const std::vector<platform::Place> &places_;
BroadcastOpHandle(const std::vector<Scope *> &local_scopes, BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places); const std::vector<platform::Place> &places);
...@@ -41,10 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -41,10 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var); void WaitInputVarGenerated(const VarHandle &in_var);
};
private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -18,22 +18,22 @@ namespace paddle { ...@@ -18,22 +18,22 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
template <typename Func> template <typename Func>
static void VisitVariable(Variable* var, Func func) { static void VisitVariable(Variable* var, Func* func) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
func(var->GetMutable<LoDTensor>()); (*func)(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
func(var->GetMutable<SelectedRows>()); (*func)(var->GetMutable<SelectedRows>());
} else { } else {
PADDLE_THROW("Not supported type %s", var->Type().name()); PADDLE_THROW("Not supported type %s", var->Type().name());
} }
} }
template <typename Func> template <typename Func>
static void VisitVariable(const Variable& var, Func func) { static void VisitVariable(const Variable& var, Func* func) {
if (var.IsType<LoDTensor>()) { if (var.IsType<LoDTensor>()) {
func(var.Get<LoDTensor>()); (*func)(var.Get<LoDTensor>());
} else if (var.IsType<SelectedRows>()) { } else if (var.IsType<SelectedRows>()) {
func(var.Get<SelectedRows>()); (*func)(var.Get<SelectedRows>());
} else { } else {
PADDLE_THROW("Not supported type %s", var.Type().name()); PADDLE_THROW("Not supported type %s", var.Type().name());
} }
...@@ -56,7 +56,7 @@ struct TensorVisitor { ...@@ -56,7 +56,7 @@ struct TensorVisitor {
Tensor& VariableVisitor::GetMutableTensor(Variable* var) { Tensor& VariableVisitor::GetMutableTensor(Variable* var) {
TensorVisitor vistor; TensorVisitor vistor;
VisitVariable(var, vistor); VisitVariable(var, &vistor);
return *vistor.result_; return *vistor.result_;
} }
...@@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor { ...@@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor {
void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
ShareDimsAndLoDVisitor visitor{trg}; ShareDimsAndLoDVisitor visitor{trg};
VisitVariable(src, visitor); VisitVariable(src, &visitor);
} }
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册