From e51557130e91383afb0e54dee00710664c9bf555 Mon Sep 17 00:00:00 2001
From: Yang Yang
Date: Mon, 9 Oct 2017 22:57:11 +0000
Subject: [PATCH] clean up for review

---
 paddle/framework/executor.cc      | 40 ++++++++++++++-------
 paddle/framework/executor.h       |  2 +-
 paddle/framework/executor_test.cc | 60 +++++++++++++------------------
 paddle/framework/scope.cc         |  1 +
 paddle/operators/feed_op.cc       |  1 +
 paddle/operators/fetch_op.cc      |  1 +
 paddle/platform/gpu_info.cc       |  2 +-
 7 files changed, 56 insertions(+), 51 deletions(-)

diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index c6c9d13469..3ac752388f 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/framework/executor.h"
+
 #include
 #include
 #include
 #include
 #include
+
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/scope.h"
@@ -27,7 +29,11 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+const std::string kFeedOpType = "feed";
+const std::string kFetchOpType = "fetch";
+
 Executor::Executor(const std::vector<platform::Place>& places) {
+  PADDLE_ENFORCE_GT(places.size(), 0);
   device_contexts_.resize(places.size());
   for (size_t i = 0; i < places.size(); i++) {
     if (platform::is_cpu_place(places[i])) {
@@ -46,9 +52,7 @@ Executor::Executor(const std::vector<platform::Place>& places) {

 Executor::~Executor() {
   for (auto& device_context : device_contexts_) {
-    if (device_context) {
-      delete device_context;
-    }
+    delete device_context;
   }
 }

@@ -56,6 +60,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
   // TODO(tonyyang-svail):
   //   - only runs the first block (i.e. no RNN support)
   //   - only runs on the first device (i.e. no interdevice communication)
+  //   - will change to use multiple blocks for RNN op and Cond Op
+  PADDLE_ENFORCE_GT(pdesc.blocks_size(), 0);
   auto& block = pdesc.blocks(0);
   auto& device = device_contexts_[0];

@@ -66,12 +72,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
   Scope& local_scope = scope->NewScope();

-  std::vector<bool> should_run = Preprocess(pdesc);
-  PADDLE_ENFORCE(should_run.size() == block.ops_size());
+  std::vector<bool> should_run = Prune(pdesc);
+  PADDLE_ENFORCE_EQ(should_run.size(), block.ops_size());
   for (size_t i = 0; i < should_run.size(); ++i) {
     if (should_run[i]) {
-      for (auto var : block.ops(i).outputs()) {
-        for (auto argu : var.arguments()) {
+      for (auto& var : block.ops(i).outputs()) {
+        for (auto& argu : var.arguments()) {
           if (local_scope.FindVar(argu) == nullptr) {
             local_scope.NewVar(argu);
           }
@@ -81,28 +87,32 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
       op->Run(local_scope, *device);
     }
   }
+
+  // TODO(tonyyang-svail):
+  //   - Destroy local_scope
 }

-std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) {
+std::vector<bool> Executor::Prune(const ProgramDesc& pdesc) {
   // TODO(tonyyang-svail):
   //   - only runs the first block
+  //   - will change to use multiple blocks for RNN op and Cond Op
   auto& block = pdesc.blocks(0);
   auto& ops = block.ops();

   bool expect_feed = true;
   for (auto& op_desc : ops) {
-    PADDLE_ENFORCE(op_desc.type() != "feed" || expect_feed,
+    PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
                    "All FeedOps are at the beginning of the ProgramDesc");
-    expect_feed = (op_desc.type() == "feed");
+    expect_feed = (op_desc.type() == kFeedOpType);
   }

   bool expect_fetch = true;
   for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
     auto& op_desc = *op_iter;
-    PADDLE_ENFORCE(op_desc.type() != "fetch" || expect_fetch,
+    PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
                    "All FetchOps must at the end of the ProgramDesc");
-    expect_fetch = (op_desc.type() == "fetch");
+    expect_fetch = (op_desc.type() == kFetchOpType);
   }

   std::set<std::string> dependent_vars;
@@ -119,7 +129,7 @@ std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) {
       }
     }

-    if (op_desc.type() == "fetch" || found_dependent_vars) {
+    if (op_desc.type() == kFetchOpType || found_dependent_vars) {
       // erase its output to the dependency graph
       for (auto& var : op_desc.outputs()) {
         for (auto& argu : var.arguments()) {
@@ -140,6 +150,10 @@ std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) {
     }
   }

+  // TODO(tonyyang-svail):
+  //   - check this after integration of Init
+  // PADDLE_ENFORCE(dependent_vars.empty());
+
   // since we are traversing the ProgramDesc in reverse order
   // we reverse the should_run vector
   std::reverse(should_run.begin(), should_run.end());
diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h
index 75cb5939ff..f832b0d7d6 100644
--- a/paddle/framework/executor.h
+++ b/paddle/framework/executor.h
@@ -46,7 +46,7 @@ class Executor {
    * @return
    *  vector<bool> Same size as ops. Indicates whether an op should be run.
    */
-  std::vector<bool> Preprocess(const ProgramDesc& pdesc);
+  std::vector<bool> Prune(const ProgramDesc& pdesc);

  private:
   std::vector<platform::DeviceContext*> device_contexts_;
diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc
index 99f80d04e8..f28651e809 100644
--- a/paddle/framework/executor_test.cc
+++ b/paddle/framework/executor_test.cc
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/framework/executor.h"
+
+#include
 #include
+
 #include "gtest/gtest.h"
 #include "paddle/framework/attribute.h"
 #include "paddle/framework/backward.h"
 #include "paddle/framework/block_desc.h"
-// #include "paddle/framework/grad_op_builder.h"
 #include "paddle/framework/op_desc.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
@@ -34,9 +36,6 @@ using std::string;
 using namespace paddle::platform;
 using namespace paddle::framework;

-typedef paddle::framework::BlockDesc proto_block;
-typedef paddle::framework::OpDesc proto_op;
-
 void AddOp(const std::string& type, const VariableNameMap& inputs,
            const VariableNameMap& outputs, AttributeMap attrs,
            paddle::framework::BlockDescBind* block) {
@@ -51,10 +50,10 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
   // insert op
   auto op = block->AppendOp();
   op->SetType(type);
-  for (auto kv : inputs) {
+  for (auto& kv : inputs) {
     op->SetInput(kv.first, kv.second);
   }
-  for (auto kv : outputs) {
+  for (auto& kv : outputs) {
     op->SetOutput(kv.first, kv.second);
   }
   op->SetAttrMap(attrs);
@@ -65,11 +64,11 @@ std::once_flag set_variable_flag;
 // Tensors in feed value variable will only be in CPUPlace
 // So we can memcpy the data from vector to feed_value
 template <typename T>
-void set_feed_variable(const std::vector<std::vector<T>>& inputs) {
+void SetFeedVariable(const std::vector<std::vector<T>>& inputs) {
   typedef std::vector<Tensor> FeedInputs;
   Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value");
   FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>());
-  auto size = inputs.size();
+  size_t size = inputs.size();
   feed_inputs.resize(size);
   for (size_t i = 0; i < size; i++) {
     T* dst = feed_inputs[i].mutable_data<T>(
@@ -81,12 +80,12 @@ void set_feed_variable(const std::vector<std::vector<T>>& inputs) {
 // Tensors in fetch value variable will only be in CPUPlace
 // So we can memcpy the data from fetch_value to vector
 template <typename T>
-std::vector<std::vector<T>> get_fetch_variable() {
+std::vector<std::vector<T>> GetFetchVariable() {
   typedef std::vector<Tensor> FetchOutputs;
   Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value");
   FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>());
-  auto size = fetch_outputs.size();
+  size_t size = fetch_outputs.size();

   std::vector<std::vector<T>> result;
   result.reserve(size);
   for (size_t i = 0; i < size; i++) {
@@ -105,7 +104,7 @@ class ExecutorTesterRandom : public ::testing::Test {
   virtual void SetUp() override {
     int input_dim = 5, batch_size = 2, embed_dim = 5;

-    // init pdesc -----------------------------------------
+    // init pdesc
     auto temp_init_root_block = init_pdesc_.add_blocks();
     temp_init_root_block->set_idx(0);
     temp_init_root_block->set_parent_idx(-1);
@@ -128,7 +127,7 @@
     // flush
     init_program.Proto();

-    // run pdesc -----------------------------------------
+    // run pdesc
     auto temp_root_block = pdesc_.add_blocks();
     temp_root_block->set_idx(0);
     temp_root_block->set_parent_idx(-1);
@@ -154,9 +153,6 @@

     // TODO(tonyyang-svail):
     //   - Test with Backward
-    // AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}},
-    //       {{"dims", std::vector<int>{batch_size, 1}}}, root_block);
-    // AppendBackward(program, {});
   }

 protected:
@@ -213,12 +209,11 @@ TEST_F(ExecutorTesterRandom, CPU) {
   // "pointer being freed was not allocated" error will appear.
   paddle::memory::Used(cpu_place);

-  Executor* executor = new Executor(places);
+  std::unique_ptr<Executor> executor(new Executor(places));

+  executor->Run(init_pdesc_, GetGlobalScope());
   executor->Run(pdesc_, GetGlobalScope());
-  std::vector<std::vector<float>> result = get_fetch_variable<float>();
-
-  delete executor;
+  std::vector<std::vector<float>> result = GetFetchVariable<float>();
 }

 TEST_F(ExecutorTesterFeedAndFetch, CPU) {
@@ -232,13 +227,12 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) {
   // "pointer being freed was not allocated" error will appear.
   paddle::memory::Used(cpu_place);

-  Executor* executor = new Executor(places);
+  std::unique_ptr<Executor> executor(new Executor(places));

-  // 3 mini-batch
-  for (int i = 0; i < 3; i++) {
-    set_feed_variable(inputs_);
+  for (int batch_id = 0; batch_id < 3; batch_id++) {
+    SetFeedVariable(inputs_);
     executor->Run(pdesc_, GetGlobalScope());
-    std::vector<std::vector<float>> result = get_fetch_variable<float>();
+    std::vector<std::vector<float>> result = GetFetchVariable<float>();
     PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
     for (size_t i = 0; i < result.size(); ++i) {
       PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size());
@@ -247,8 +241,6 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) {
       }
     }
   }
-
-  delete executor;
 }
 #else
 TEST_F(ExecutorTesterRandom, GPU) {
@@ -265,13 +257,11 @@ TEST_F(ExecutorTesterRandom, GPU) {
   paddle::memory::Used(CPUPlace());
   paddle::memory::Used(gpu_place);

-  Executor* executor = new Executor(places);
+  std::unique_ptr<Executor> executor(new Executor(places));

   executor->Run(init_pdesc_, GetGlobalScope());
   executor->Run(pdesc_, GetGlobalScope());
-  std::vector<std::vector<float>> result = get_fetch_variable<float>();
-
-  delete executor;
+  std::vector<std::vector<float>> result = GetFetchVariable<float>();
 }

 TEST_F(ExecutorTesterFeedAndFetch, GPU) {
@@ -287,13 +277,12 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
   paddle::memory::Used(CPUPlace());
   paddle::memory::Used(gpu_place);

-  Executor* executor = new Executor(places);
+  std::unique_ptr<Executor> executor(new Executor(places));

-  // 3 mini-batch
-  for (int i = 0; i < 3; i++) {
-    set_feed_variable(inputs_);
+  for (int batch_id = 0; batch_id < 3; batch_id++) {
+    SetFeedVariable(inputs_);
     executor->Run(pdesc_, GetGlobalScope());
-    std::vector<std::vector<float>> result = get_fetch_variable<float>();
+    std::vector<std::vector<float>> result = GetFetchVariable<float>();
     PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
     for (size_t i = 0; i < result.size(); ++i) {
       PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size());
@@ -302,6 +291,5 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
       }
     }
   }
-  delete executor;
 }
 #endif
diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc
index 2a0d9bbf33..c9e53a0d85 100644
--- a/paddle/framework/scope.cc
+++ b/paddle/framework/scope.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/framework/scope.h"
+
 #include <memory>  // for unique_ptr
 #include <mutex>   // for call_once
 #include "paddle/string/printf.h"
diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc
index b9e43be966..dcd5f7fb77 100644
--- a/paddle/operators/feed_op.cc
+++ b/paddle/operators/feed_op.cc
@@ -31,6 +31,7 @@ class FeedOp : public framework::OperatorWithKernel {
     const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();

+    PADDLE_ENFORCE_GT(tensors.size(), col);
     auto in_dim = tensors[col].dims();
     ctx->SetOutputDim("Out", in_dim);
     // TODO(qijun): need to handle LodTensor later
diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc
index 7bde4953cd..5adb83144a 100644
--- a/paddle/operators/fetch_op.cc
+++ b/paddle/operators/fetch_op.cc
@@ -35,6 +35,7 @@ class FetchOp : public framework::OperatorWithKernel {
     }
     auto input_dim = ctx->GetInputDim("Input");

+    PADDLE_ENFORCE_GT(tensors->size(), col);
     (*tensors)[col].Resize(input_dim);

     // TODO(qijun): need to handle LodTensor later
diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc
index aa76bb209d..0cab5ffc56 100644
--- a/paddle/platform/gpu_info.cc
+++ b/paddle/platform/gpu_info.cc
@@ -44,7 +44,7 @@ int GetCurrentDeviceId() {

 void SetDeviceId(int id) {
   // TODO(qijun): find a better way to cache the cuda device count
-  PADDLE_ENFORCE(id < GetCUDADeviceCount(), "id must less than GPU count");
+  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
   PADDLE_ENFORCE(cudaSetDevice(id),
                  "cudaSetDevice failed in paddle::platform::SetDeviceId");
 }
--
GitLab