提交 03571282 编写于 作者: C chengduoZH

fix VisitVariable

上级 fbb75c6b
......@@ -29,9 +29,7 @@ namespace framework {
namespace details {
struct BroadcastOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
public:
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
......@@ -41,10 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var);
};
private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -18,22 +18,22 @@ namespace paddle {
namespace framework {
namespace details {
template <typename Func>
static void VisitVariable(Variable* var, Func func) {
static void VisitVariable(Variable* var, Func* func) {
if (var->IsType<LoDTensor>()) {
func(var->GetMutable<LoDTensor>());
(*func)(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
func(var->GetMutable<SelectedRows>());
(*func)(var->GetMutable<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var->Type().name());
}
}
template <typename Func>
static void VisitVariable(const Variable& var, Func func) {
static void VisitVariable(const Variable& var, Func* func) {
if (var.IsType<LoDTensor>()) {
func(var.Get<LoDTensor>());
(*func)(var.Get<LoDTensor>());
} else if (var.IsType<SelectedRows>()) {
func(var.Get<SelectedRows>());
(*func)(var.Get<SelectedRows>());
} else {
PADDLE_THROW("Not supported type %s", var.Type().name());
}
......@@ -56,7 +56,7 @@ struct TensorVisitor {
Tensor& VariableVisitor::GetMutableTensor(Variable* var) {
TensorVisitor vistor;
VisitVariable(var, vistor);
VisitVariable(var, &vistor);
return *vistor.result_;
}
......@@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor {
void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
ShareDimsAndLoDVisitor visitor{trg};
VisitVariable(src, visitor);
VisitVariable(src, &visitor);
}
} // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册