From 15400748ae6d21facb0b8e656b4298e1ae83df89 Mon Sep 17 00:00:00 2001
From: qijun
Date: Mon, 9 Oct 2017 20:42:29 -0700
Subject: [PATCH] follow comments and refine codes

---
 paddle/framework/backward.cc      |  2 +-
 paddle/framework/executor_test.cc | 44 +++++++++++++++----------------
 paddle/operators/feed_op.cc       |  6 ++---
 paddle/operators/feed_op.h        |  4 +--
 paddle/operators/fetch_op.cc      |  4 +--
 paddle/operators/fetch_op.h       |  4 +--
 6 files changed, 31 insertions(+), 33 deletions(-)

diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 9a5c4e9cf..774d8e491 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -378,7 +378,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
         backward_descs[dup_op[i]]->Rename(out_name, new_name);
         sum_op_inputs.emplace_back(new_name);
       }
-      LOG(INFO) << "fuck " << sum_op_inputs.size();
+      LOG(INFO) << "sum_op_inputs size " << sum_op_inputs.size();
       std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
           "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
       pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc
index 12be79d01..0515fb221 100644
--- a/paddle/framework/executor_test.cc
+++ b/paddle/framework/executor_test.cc
@@ -60,15 +60,13 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
   op->SetAttrMap(attrs);
 }
 
-std::once_flag set_variable_flag;
-
 // Tensors in feed value variable will only be in CPUPlace
 // So we can memcpy the data from vector to feed_value
 template <typename T>
 void SetFeedVariable(const std::vector<std::vector<T>>& inputs) {
-  typedef std::vector<paddle::framework::Tensor> FeedInputs;
   Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value");
-  FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>());
+  auto& feed_inputs =
+      *(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
   size_t size = inputs.size();
   feed_inputs.resize(size);
   for (size_t i = 0; i < size; i++) {
@@ -82,9 +80,9 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs) {
 // So we can memcpy the data from fetch_value to vector
 template <typename T>
 std::vector<std::vector<T>> GetFetchVariable() {
-  typedef std::vector<paddle::framework::Tensor> FetchOutputs;
   Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value");
-  FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>());
+  auto& fetch_outputs =
+      *(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>());
 
   size_t size = fetch_outputs.size();
   std::vector<std::vector<T>> result;
@@ -143,22 +141,22 @@ class ExecutorTesterRandom : public ::testing::Test {
           {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {},
           root_block);
 
-    AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}},
-          {{"dims", std::vector<int>{batch_size, 1}}}, root_block);
-    AppendBackward(program, {});
-
-    program.Proto();
-
-    for (auto& op : pdesc_.blocks(0).ops()) {
-      if (op.type() == "sum") {
-        LOG(INFO) << "Here";
-        for (auto& var : op.inputs()) {
-          for (auto& argu : var.arguments()) {
-            LOG(INFO) << var.parameter() << " " << argu;
-          }
-        }
-      }
-    }
+    // AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}},
+    //       {{"dims", std::vector<int>{batch_size, 1}}}, root_block);
+    // AppendBackward(program, {});
+
+    // program.Proto();
+
+    // for (auto& op : pdesc_.blocks(0).ops()) {
+    //   if (op.type() == "sum") {
+    //     LOG(INFO) << "Here";
+    //     for (auto& var : op.inputs()) {
+    //       for (auto& argu : var.arguments()) {
+    //         LOG(INFO) << var.parameter() << " " << argu;
+    //       }
+    //     }
+    //   }
+    // }
 
     AddOp("fetch", {{"Input", {"l2_distance"}}}, {},
           {{"dims", std::vector<int>{batch_size}}, {"col", 1}}, root_block);
diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc
index b15bc86ae..29e128ce7 100644
--- a/paddle/operators/feed_op.cc
+++ b/paddle/operators/feed_op.cc
@@ -23,15 +23,15 @@ class FeedOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    typedef std::vector<framework::Tensor> FeedInputs;
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
     int col = ctx->Attrs().Get<int>("col");
     framework::Variable* g_feed_variable =
         framework::GetGlobalScope()->FindVar("feed_value");
 
-    const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
+    const auto& tensors =
+        g_feed_variable->Get<std::vector<framework::Tensor>>();
 
-    PADDLE_ENFORCE_GT(tensors.size(), col);
+    PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
     auto in_dim = tensors[col].dims();
     ctx->SetOutputDim("Out", in_dim);
     // TODO(qijun): need to handle LodTensor later
diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h
index de8ec6ff6..96e3bf52b 100644
--- a/paddle/operators/feed_op.h
+++ b/paddle/operators/feed_op.h
@@ -23,13 +23,13 @@ template <typename T>
 class FeedKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    typedef std::vector<framework::Tensor> FeedInputs;
     framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
     framework::Variable* g_feed_variable =
         framework::GetGlobalScope()->FindVar("feed_value");
     int col = ctx.template Attr<int>("col");
-    const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
+    const auto& tensors =
+        g_feed_variable->Get<std::vector<framework::Tensor>>();
     out->CopyFrom<T>(tensors[col], ctx.GetPlace());
   }
 };
diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc
index 7ca3762c3..77e3450a7 100644
--- a/paddle/operators/fetch_op.cc
+++ b/paddle/operators/fetch_op.cc
@@ -23,13 +23,13 @@ class FetchOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    typedef std::vector<framework::Tensor> FetchOutputs;
     PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
     int col = ctx->Attrs().Get<int>("col");
     framework::Variable* g_fetch_variable =
         framework::GetGlobalScope()->FindVar("fetch_value");
 
-    FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
+    auto* tensors =
+        g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
     if (tensors->size() < static_cast<size_t>(col + 1)) {
       tensors->resize(col + 1);
     }
diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h
index 3bec9c997..fd9855205 100644
--- a/paddle/operators/fetch_op.h
+++ b/paddle/operators/fetch_op.h
@@ -23,12 +23,12 @@ template <typename T>
 class FetchKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    typedef std::vector<framework::Tensor> FetchOutputs;
     const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
     int col = ctx.template Attr<int>("col");
     framework::Variable* g_fetch_variable =
         framework::GetGlobalScope()->FindVar("fetch_value");
-    FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
+    auto* tensors =
+        g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
     (*tensors)[col].mutable_data<T>(platform::CPUPlace());
     (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
   }
-- 
GitLab