提交 f25eb9a7 编写于 作者: X Xin Pan

fix some tests.

test=develop
上级 adf5615e
...@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle { ...@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
std::vector<Scope*> local_scopes_; std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_; std::vector<Scope*> param_scopes_;
Scope g_scope_; Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase* op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase*> vars_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
std::vector<p::Place> place_list_; std::vector<p::Place> place_list_;
bool use_gpu_; bool use_gpu_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle { ...@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
} }
void InitBroadcastOp(size_t input_scope_idx) { void InitBroadcastOp(size_t input_scope_idx) {
nodes_.clear();
for (size_t j = 0; j < place_list_.size(); ++j) { for (size_t j = 0; j < place_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope(); Scope& local_scope = local_scopes_.back()->NewScope();
...@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle { ...@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
} }
param_scopes_[input_scope_idx]->Var("input"); param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n = nodes_.emplace_back(
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation); ir::CreateNodeForTest("node0", ir::Node::Type::kOperation));
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_, nccl_ctxs_.get())); place_list_, nccl_ctxs_.get());
#else #else
PADDLE_THROW("CUDA is not support."); PADDLE_THROW("CUDA is not support.");
#endif #endif
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_, nccl_ctxs_.get())); place_list_, nccl_ctxs_.get());
#else #else
op_handle_.reset( op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
new BroadcastOpHandle(n.get(), local_scopes_, place_list_)); place_list_);
#endif #endif
} }
std::unique_ptr<ir::Node> v = nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable); ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
place_list_[input_scope_idx]); "input", place_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v2 = nodes_.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable); ir::CreateNodeForTest("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v2.get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
dummy_var_handle->ClearGeneratedOp(); dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle); op_handle_->AddInput(dummy_var_handle);
...@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle { ...@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
if (!use_gpu_) { if (!use_gpu_) {
op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get());
} }
std::unique_ptr<ir::Node> v3 = nodes_.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable); ir::CreateNodeForTest("node3", ir::Node::Type::kVariable));
VarHandle* out_var_handle = VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", place_list_[j]); new VarHandle(nodes_.back().get(), 2, j, "out", place_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
} }
// add dummy var // add dummy var
std::unique_ptr<ir::Node> v4 = nodes_.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable); ir::CreateNodeForTest("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(v4.get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* out_dummy_var_handle = DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
out_dummy_var_handle->ClearGeneratedOp(); out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle); op_handle_->AddOutput(out_dummy_var_handle);
} }
......
...@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, ...@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
offset_(offset), offset_(offset),
local_scopes_(local_scopes) {} local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() { FetchOpHandle::~FetchOpHandle() {}
for (auto *input_var : inputs_) {
input_var->RemoveOutput(this, this->Node());
}
}
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
......
...@@ -22,8 +22,10 @@ namespace details { ...@@ -22,8 +22,10 @@ namespace details {
struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
std::vector<std::string> out_varnames_; std::vector<std::string> out_varnames_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) { void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
nodes_.clear();
// initialize scope and var // initialize scope and var
for (size_t i = 0; i < place_list_.size(); ++i) { for (size_t i = 0; i < place_list_.size(); ++i) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
...@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { ...@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
} }
// create op handle node // create op handle node
std::unique_ptr<ir::Node> n = nodes_.emplace_back(
ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation); ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation));
if (use_gpu_) { if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new FusedBroadcastOpHandle( op_handle_ = new FusedBroadcastOpHandle(
n.get(), local_scopes_, place_list_, nccl_ctxs_.get())); nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
#else #else
PADDLE_THROW("CUDA is not supported."); PADDLE_THROW("CUDA is not supported.");
#endif #endif
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
op_handle_.reset(new FusedBroadcastOpHandle( op_handle_ = new FusedBroadcastOpHandle(
n.get(), local_scopes_, place_list_, nccl_ctxs_.get())); nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
#else #else
op_handle_.reset( op_handle_ = new FusedBroadcastOpHandle(nodes_.back().get(),
new FusedBroadcastOpHandle(n.get(), local_scopes_, place_list_)); local_scopes_, place_list_);
#endif #endif
} }
for (size_t i = 0; i < input_scope_idxes.size(); ++i) { for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
// add input var handle // add input var handle
std::unique_ptr<ir::Node> in_node = nodes_.emplace_back(
ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable); ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable));
VarHandle* in_var_handle = VarHandle* in_var_handle =
new VarHandle(in_node.get(), 1, input_scope_idxes[i], "in_var" + i, new VarHandle(nodes_.back().get(), 1, input_scope_idxes[i],
place_list_[input_scope_idxes[i]]); "in_var" + i, place_list_[input_scope_idxes[i]]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
// add output var handle // add output var handle
for (size_t j = 0; j < place_list_.size(); ++j) { for (size_t j = 0; j < place_list_.size(); ++j) {
std::unique_ptr<ir::Node> out_node = nodes_.emplace_back(
ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable); ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable));
VarHandle* out_var_handle = VarHandle* out_var_handle = new VarHandle(
new VarHandle(out_node.get(), 2, j, "out_var" + i, place_list_[j]); nodes_.back().get(), 2, j, "out_var" + i, place_list_[j]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
} }
......
...@@ -34,6 +34,7 @@ struct TestGatherOpHandle { ...@@ -34,6 +34,7 @@ struct TestGatherOpHandle {
OpHandleBase* op_handle_; OpHandleBase* op_handle_;
std::vector<VarHandleBase*> vars_; std::vector<VarHandleBase*> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
void WaitAll() { void WaitAll() {
for (size_t j = 0; j < ctxs_.size(); ++j) { for (size_t j = 0; j < ctxs_.size(); ++j) {
...@@ -70,7 +71,7 @@ struct TestGatherOpHandle { ...@@ -70,7 +71,7 @@ struct TestGatherOpHandle {
} }
void InitGatherOp(size_t input_scope_idx) { void InitGatherOp(size_t input_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes; nodes_.clear();
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope(); Scope& local_scope = local_scopes_.back()->NewScope();
...@@ -82,42 +83,43 @@ struct TestGatherOpHandle { ...@@ -82,42 +83,43 @@ struct TestGatherOpHandle {
} }
param_scopes_[input_scope_idx]->Var("out"); param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release()); ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
op_handle_ = op_handle_ =
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_); new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_);
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
auto* in_var_handle = auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); new VarHandle(nodes_.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
} }
// add dummy var // add dummy var
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* in_dummy_var_handle = DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back()); static_cast<DummyVarHandle*>(vars_.back());
in_dummy_var_handle->ClearGeneratedOp(); in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, auto* out_var_handle =
"out", gpu_list_[input_scope_idx]); new VarHandle(nodes_.back().get(), 2, input_scope_idx, "out",
gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back()); static_cast<DummyVarHandle*>(vars_.back());
op_handle_->AddOutput(dummy_var_handle); op_handle_->AddOutput(dummy_var_handle);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -45,9 +46,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -45,9 +46,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
insert_pending_var(var); insert_pending_var(var);
} }
for (ir::Node *node : graph->Nodes()) { for (OpHandleBase *op : ir::GetFilteredNodes<OpHandleBase>(*graph)) {
if (!node->IsWrappedBy<OpHandleBase>()) continue;
OpHandleBase *op = &node->Wrapper<OpHandleBase>();
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op); ready_ops.insert(op);
} else { } else {
......
...@@ -35,7 +35,7 @@ namespace details { ...@@ -35,7 +35,7 @@ namespace details {
// The outside vector is the device vector. Each element of this vector is a // The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name, // map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the // will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles. // `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>> typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars"; const char kGraphVars[] = "vars";
......
...@@ -30,8 +30,8 @@ struct TestReduceOpHandle { ...@@ -30,8 +30,8 @@ struct TestReduceOpHandle {
Scope g_scope_; Scope g_scope_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<Scope *> param_scopes_; std::vector<Scope *> param_scopes_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase *op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase *> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_; std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册