From 089cc11df48c8b29b34eda8ea19328a090d4c9f6 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Mon, 9 Oct 2017 03:30:53 +0000 Subject: [PATCH] clean up && fix #4624 --- paddle/framework/block_desc.cc | 6 ++ paddle/framework/executor.cc | 37 +++------ paddle/framework/executor_test.cc | 129 ++++++++++++------------------ 3 files changed, 68 insertions(+), 104 deletions(-) diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 01f50e1393..509aa235d3 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -74,6 +74,12 @@ void BlockDescBind::Sync() { for (auto &op_desc : ops_) { op_field.AddAllocated(op_desc->Proto()); } + auto &var_field = *this->desc_->mutable_vars(); + var_field.Clear(); + var_field.Reserve(static_cast(vars_.size())); + for (auto &var_desc : vars_) { + var_field.AddAllocated(var_desc.second->Proto()); + } need_update_ = false; } } diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 9391e18ded..c6c9d13469 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -54,39 +54,33 @@ Executor::~Executor() { void Executor::Run(const ProgramDesc& pdesc, Scope* scope) { // TODO(tonyyang-svail): - // - only runs the first block - // - only runs on the first device - // - test on gpu + // - only runs the first block (i.e. no RNN support) + // - only runs on the first device (i.e. no interdevice communication) auto& block = pdesc.blocks(0); auto& device = device_contexts_[0]; - // TODO(tonyyang-svail): - // - runs on a new local scope - // Scope& local_scope = scope->NewScope(); - + // Instantiate all the vars in the global scope for (auto& var : block.vars()) { scope->NewVar(var.name()); } + Scope& local_scope = scope->NewScope(); + std::vector should_run = Preprocess(pdesc); PADDLE_ENFORCE(should_run.size() == block.ops_size()); for (size_t i = 0; i < should_run.size(); ++i) { if (should_run[i]) { + for (auto var : block.ops(i).outputs()) { + for (auto argu : var.arguments()) { + if (local_scope.FindVar(argu) == nullptr) { + local_scope.NewVar(argu); + } + } + } auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i)); - op->Run(*scope, *device); + op->Run(local_scope, *device); } } - - // // print tensor value - // for (auto& var : block.vars()) { - // std::cout << var.name() << std::endl; - // auto v = scope->FindVar(var.name()); - // const LoDTensor& t = v->Get(); - // for (int i = 0; i < t.numel(); ++i) { - // std::cout << t.data()[i] << " "; - // } - // std::cout << std::endl; - // } } std::vector Executor::Preprocess(const ProgramDesc& pdesc) { @@ -125,7 +119,6 @@ std::vector Executor::Preprocess(const ProgramDesc& pdesc) { } } - // TODO(tonyyang-svail): add VLOG here for debugging if (op_desc.type() == "fetch" || found_dependent_vars) { // erase its output to the dependency graph for (auto& var : op_desc.outputs()) { @@ -141,13 +134,9 @@ std::vector Executor::Preprocess(const ProgramDesc& pdesc) { } } - // this op should be executed should_run.push_back(true); - LOG(INFO) << "Yes " << op_desc.type(); } else { - // this op should NOT be executed should_run.push_back(false); - LOG(INFO) << "No " << op_desc.type(); } } diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 7ce472ed2f..99f80d04e8 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/framework/attribute.h" #include "paddle/framework/backward.h" #include "paddle/framework/block_desc.h" -#include "paddle/framework/grad_op_builder.h" +// #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/op_desc.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" @@ -37,68 +37,27 @@ using namespace paddle::framework; typedef paddle::framework::BlockDesc proto_block; typedef paddle::framework::OpDesc proto_op; -struct SetAttrDescVisitor : public boost::static_visitor { - explicit SetAttrDescVisitor(OpDesc::Attr* attr) : attr_(attr) {} - mutable OpDesc::Attr* attr_; - void operator()(int v) const { attr_->set_i(v); } - void operator()(float v) const { attr_->set_f(v); } - void operator()(const std::string& v) const { attr_->set_s(v); } - void operator()(bool b) const { attr_->set_b(b); } - - void operator()(const std::vector& v) const { - VectorToRepeated(v, attr_->mutable_ints()); - } - void operator()(const std::vector& v) const { - VectorToRepeated(v, attr_->mutable_floats()); - } - void operator()(const std::vector& v) const { - VectorToRepeated(v, attr_->mutable_strings()); - } - void operator()(const std::vector& v) const { - VectorToRepeated(v, attr_->mutable_bools()); - } - void operator()(BlockDesc* desc) const { attr_->set_block_idx(desc->idx()); } - void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } -}; - void AddOp(const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, AttributeMap attrs, - proto_block* block) { + paddle::framework::BlockDescBind* block) { // insert output for (auto kv : outputs) { for (auto v : kv.second) { - auto var = block->add_vars(); - var->set_name(v); - auto var_lt = var->mutable_lod_tensor(); - var_lt->set_data_type(paddle::framework::DataType::FP32); + auto var = block->NewVar(v); + var->SetDataType(paddle::framework::DataType::FP32); } } // insert op - auto op = block->add_ops(); - op->set_type(type); + auto op = block->AppendOp(); + op->SetType(type); for (auto kv : inputs) { - auto X = op->add_inputs(); - X->set_parameter(kv.first); - for (auto argu : kv.second) { - X->add_arguments(argu); - } + op->SetInput(kv.first, kv.second); } for (auto kv : outputs) { - auto X = op->add_outputs(); - X->set_parameter(kv.first); - for (auto argu : kv.second) { - X->add_arguments(argu); - } - } - for (auto& attr : attrs) { - auto* attr_desc = op->add_attrs(); - attr_desc->set_name(attr.first); - attr_desc->set_type( - static_cast(attr.second.which() - 1)); - SetAttrDescVisitor visitor(attr_desc); - boost::apply_visitor(visitor, attr.second); + op->SetOutput(kv.first, kv.second); } + op->SetAttrMap(attrs); } std::once_flag set_variable_flag; @@ -146,10 +105,16 @@ class ExecutorTesterRandom : public ::testing::Test { virtual void SetUp() override { int input_dim = 5, batch_size = 2, embed_dim = 5; - // init pdesc - auto init_root_block = init_pdesc_.add_blocks(); - init_root_block->set_idx(0); - init_root_block->set_parent_idx(-1); + // init pdesc ----------------------------------------- + auto temp_init_root_block = init_pdesc_.add_blocks(); + temp_init_root_block->set_idx(0); + temp_init_root_block->set_parent_idx(-1); + + // wrap to BlockDescBind + paddle::framework::ProgramDescBind& init_program = + paddle::framework::ProgramDescBind::Instance(&init_pdesc_); + paddle::framework::BlockDescBind* init_root_block = init_program.Block(0); + AddOp("gaussian_random", {}, {{"Out", {"w1"}}}, {{"dims", std::vector{input_dim, embed_dim}}}, init_root_block); AddOp("gaussian_random", {}, {{"Out", {"w2"}}}, @@ -160,11 +125,18 @@ class ExecutorTesterRandom : public ::testing::Test { AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"dims", std::vector{embed_dim, input_dim}}, {"col", 1}}, init_root_block); + // flush + init_program.Proto(); + + // run pdesc ----------------------------------------- + auto temp_root_block = pdesc_.add_blocks(); + temp_root_block->set_idx(0); + temp_root_block->set_parent_idx(-1); - // run pdesc - auto root_block = pdesc_.add_blocks(); - root_block->set_idx(0); - root_block->set_parent_idx(-1); + // wrap to BlockDescBind + paddle::framework::ProgramDescBind& program = + paddle::framework::ProgramDescBind::Instance(&pdesc_); + paddle::framework::BlockDescBind* root_block = program.Block(0); AddOp("gaussian_random", {}, {{"Out", {"a"}}}, {{"dims", std::vector{batch_size, input_dim}}}, root_block); @@ -175,13 +147,16 @@ class ExecutorTesterRandom : public ::testing::Test { AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}}, {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {}, root_block); - - AppendBackward(pdesc_, {}); - // AddOp("fetch", {{"Input", {"sub_result"}}}, {}, - // {{"dims", std::vector{input_dim, batch_size}}, {"col", 0}}, - // root_block); AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, {{"dims", std::vector{batch_size}}, {"col", 1}}, root_block); + // flush + program.Proto(); + + // TODO(tonyyang-svail): + // - Test with Backward + // AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}}, + // {{"dims", std::vector{batch_size, 1}}}, root_block); + // AppendBackward(program, {}); } protected: @@ -192,9 +167,14 @@ class ExecutorTesterRandom : public ::testing::Test { class ExecutorTesterFeedAndFetch : public ::testing::Test { public: virtual void SetUp() override { - auto root_block = pdesc_.add_blocks(); - root_block->set_idx(0); - root_block->set_parent_idx(-1); + auto temp_root_block = pdesc_.add_blocks(); + temp_root_block->set_idx(0); + temp_root_block->set_parent_idx(-1); + + // wrap to BlockDescBind + paddle::framework::ProgramDescBind& program = + paddle::framework::ProgramDescBind::Instance(&pdesc_); + paddle::framework::BlockDescBind* root_block = program.Block(0); std::vector dim{6}; @@ -207,6 +187,9 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test { AddOp("fetch", {{"Input", {"b"}}}, {}, {{"dims", dim}, {"col", 1}}, root_block); + // flush + program.Proto(); + std::vector vec1 = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; std::vector vec2 = {4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; inputs_.push_back(vec1); @@ -235,12 +218,6 @@ TEST_F(ExecutorTesterRandom, CPU) { executor->Run(pdesc_, GetGlobalScope()); std::vector> result = get_fetch_variable(); - for (auto& vec : result) { - for (auto& num : vec) { - std::cout << num << " "; - } - std::cout << std::endl; - } delete executor; } @@ -290,18 +267,10 @@ TEST_F(ExecutorTesterRandom, GPU) { Executor* executor = new Executor(places); - LOG(INFO) << "Run Init"; executor->Run(init_pdesc_, GetGlobalScope()); - LOG(INFO) << "Run"; executor->Run(pdesc_, GetGlobalScope()); std::vector> result = get_fetch_variable(); - for (auto& vec : result) { - for (auto& num : vec) { - std::cout << num << " "; - } - std::cout << std::endl; - } delete executor; } -- GitLab