From 51a538e0554ff08ee4cd80e1ecc849f564e3416e Mon Sep 17 00:00:00 2001
From: baojun-nervana
Date: Tue, 13 Nov 2018 14:14:24 -0800
Subject: [PATCH] Fix style and use enum test=develop

---
 paddle/fluid/framework/ngraph_operator.cc | 80 ++++++++++++-----------
 paddle/fluid/framework/ngraph_operator.h  | 18 ++---
 2 files changed, 51 insertions(+), 47 deletions(-)

diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc
index 70e6f97b4c1..d967b2780c2 100644
--- a/paddle/fluid/framework/ngraph_operator.cc
+++ b/paddle/fluid/framework/ngraph_operator.cc
@@ -35,6 +35,13 @@ static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
     {proto::VarType::BOOL, ngraph::element::boolean},
 };
 
+typedef enum {   /* nGraph support state on ops          */
+  FULL_TRAIN,    /* Support full ops for train           */
+  PARTIAL_TRAIN, /* Support partial ops for train        */
+  FULL_TEST,     /* Support full list of ops for test    */
+  PARTIAL_TEST   /* Support partial list of ops for test */
+} op_state;
+
 class NgraphOperator {
  public:
   explicit NgraphOperator(const Scope& scope, const platform::Place& place,
@@ -44,33 +51,29 @@ class NgraphOperator {
                           const std::unordered_set<std::string>& persist,
                           const std::unordered_set<std::string>& fetches,
                           const std::unordered_set<std::string>& post_op_inputs,
-                          int is_test_or_train)
-      : scope(scope),
-        place(place),
-        fused_ops(ops),
-        var_type_map(var_type_map),
-        persistables(persist),
-        fetches(fetches),
-        post_op_inputs(post_op_inputs),
-        is_test_or_train(is_test_or_train) {}
+                          op_state ng_op_state)
+      : scope_(scope),
+        place_(place),
+        fused_ops_(ops),
+        var_type_map_(var_type_map),
+        persistables_(persist),
+        fetches_(fetches),
+        post_op_inputs_(post_op_inputs),
+        ng_op_state_(ng_op_state) {}
 
   void Run(const Scope& scope, const platform::Place& place) const;
 
  private:
   static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
       func_cache;
-  const Scope& scope;
-  const platform::Place& place;
-  std::vector<std::shared_ptr<OperatorBase>> fused_ops;
-  std::unordered_map<std::string, ngraph::element::Type> var_type_map;
-  std::unordered_set<std::string> persistables;
-  std::unordered_set<std::string> fetches;
-  std::unordered_set<std::string> post_op_inputs;
-  // 0 = default; 1 = (is_test && not is_complete)
-  // 2 = (is_test && is_complete)
-  // 3 = (is_training && not is_complete)
-  // 4 = (is_training && is_complete)
-  int is_test_or_train;
+  const Scope& scope_;
+  const platform::Place& place_;
+  std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
+  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
+  std::unordered_set<std::string> persistables_;
+  std::unordered_set<std::string> fetches_;
+  std::unordered_set<std::string> post_op_inputs_;
+  op_state ng_op_state_;
 };
 
 std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
@@ -131,19 +134,19 @@ FusedOperator::FusedOperator(
     const ProgramDesc& prog, size_t block_id,
     std::vector<std::unique_ptr<OperatorBase>>::iterator start,
     std::vector<std::unique_ptr<OperatorBase>>::iterator end,
-    const std::string& type = "fused_op", const VariableNameMap& inputs = {},
-    const VariableNameMap& outputs = {}, const AttributeMap& attrs = {})
-    : OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
+    const std::string& type, const VariableNameMap& inputs,
+    const VariableNameMap& outputs, const AttributeMap& attrs)
+    : OperatorBase(type, inputs, outputs, attrs),
+      pdesc_(prog),
+      block_(block_id) {
   for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
        it != end; ++it) {
-    fused_ops.push_back(std::move(*it));
+    fused_ops_.push_back(std::move(*it));
   }
 
   for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
        (*it)->Type() != kFetchOpType; ++it) {
     for (auto& var_name_item : (*it)->Inputs()) {
       for (auto& var_name : var_name_item.second) {
-        post_op_inputs.insert(var_name);
+        post_op_inputs_.insert(var_name);
       }
     }
   }
@@ -152,11 +155,11 @@ FusedOperator::FusedOperator(
-    is_complete = true;
+    is_full_ = true;
   }
 
-  process();
+  Process();
 }
 
-void FusedOperator::process() {
-  auto& bdesc = pdesc.Block(block);
+void FusedOperator::Process() {
+  auto& bdesc = pdesc_.Block(block_);
   for (auto& var : bdesc.AllVars()) {
     if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
           var->GetType() == proto::VarType::LOD_TENSOR ||
@@ -175,39 +178,40 @@ void FusedOperator::process() {
       PADDLE_THROW("Data type of var %s not found in pd2ng_type_map", var_name);
     }
-      var_type_map[var_name] = pd2ng_type_map[pd_type];
+      var_type_map_[var_name] = pd2ng_type_map[pd_type];
     }
 
     if (var->Persistable()) {
-      persistables.insert(var->Name());
+      persistables_.insert(var->Name());
     }
   }
 
   for (auto* op : bdesc.AllOps()) {
     if (op->Type() == kFetchOpType) {
       std::string fetch_target_name = op->Input("X")[0];
-      fetches.insert(fetch_target_name);
+      fetches_.insert(fetch_target_name);
     }
   }
 }
 
 void FusedOperator::RunImpl(const Scope& scope,
                             const platform::Place& place) const {
-  int is_test_or_train = 1;
-  auto& bdesc = pdesc.Block(block);
+  op_state ng_op_state = PARTIAL_TEST;
+  auto& bdesc = pdesc_.Block(block_);
   for (auto* op : bdesc.AllOps()) {
     if (op->Type().find("_grad") != std::string::npos) {
-      is_test_or_train = 3;
+      ng_op_state = PARTIAL_TRAIN;
       break;
     }
   }
 
-  if (is_complete) {
-    is_test_or_train = is_test_or_train == 1 ? 2 : 4;
+  if (is_full_) {
+    ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
   }
 
-  NgraphOperator ngraph_op(scope, place, fused_ops, var_type_map, persistables,
-                           fetches, post_op_inputs, is_test_or_train);
+  NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
+                           persistables_, fetches_, post_op_inputs_,
+                           ng_op_state);
   ngraph_op.Run(scope, place);
 }
diff --git a/paddle/fluid/framework/ngraph_operator.h b/paddle/fluid/framework/ngraph_operator.h
index eb77c781150..0f655cef1dd 100644
--- a/paddle/fluid/framework/ngraph_operator.h
+++ b/paddle/fluid/framework/ngraph_operator.h
@@ -56,16 +56,16 @@ class FusedOperator : public OperatorBase {
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
 
  private:
-  const ProgramDesc pdesc;
-  size_t block;
-  std::vector<std::shared_ptr<OperatorBase>> fused_ops;
-  std::unordered_map<std::string, ngraph::element::Type> var_type_map;
-  std::unordered_set<std::string> persistables;
-  std::unordered_set<std::string> fetches;
-  std::unordered_set<std::string> post_op_inputs;
-  bool is_complete = false;
+  const ProgramDesc pdesc_;
+  size_t block_;
+  std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
+  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
+  std::unordered_set<std::string> persistables_;
+  std::unordered_set<std::string> fetches_;
+  std::unordered_set<std::string> post_op_inputs_;
+  bool is_full_ = false;
 
-  void process();
+  void Process();
 };
 }  // namespace framework
 }  // namespace paddle
-- 
GitLab
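
Aside: the patch replaces the old magic numbers 1-4 (documented only in comments) with the four-value op_state enum, so each state is named at its use site and invalid values cannot be passed. Below is a minimal standalone sketch of the resulting state-selection logic, mirroring what FusedOperator::RunImpl does; the SelectState helper and its string-list input are hypothetical, introduced here only for illustration and not part of the patch:

#include <iostream>
#include <string>
#include <vector>

/* Same four support states the patch introduces. */
typedef enum {
  FULL_TRAIN,    /* Support full ops for train */
  PARTIAL_TRAIN, /* Support partial ops for train */
  FULL_TEST,     /* Support full list of ops for test */
  PARTIAL_TEST   /* Support partial list of ops for test */
} op_state;

/* Hypothetical helper mirroring RunImpl: any "*_grad" op marks the block
 * as training; full op coverage upgrades PARTIAL_* to FULL_*. */
op_state SelectState(const std::vector<std::string>& op_types, bool is_full) {
  op_state state = PARTIAL_TEST;
  for (const auto& type : op_types) {
    if (type.find("_grad") != std::string::npos) {
      state = PARTIAL_TRAIN;
      break;
    }
  }
  if (is_full) {
    state = state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
  }
  return state;
}

int main() {
  /* Block with a gradient op, fully covered: FULL_TRAIN (prints 0). */
  std::cout << SelectState({"mul", "mul_grad"}, true) << "\n";
  /* Forward-only block, partially covered: PARTIAL_TEST (prints 3). */
  std::cout << SelectState({"mul", "relu"}, false) << "\n";
  return 0;
}

A ternary on the enum keeps the upgrade rule in one expression, matching the patch's "ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN" shape rather than the old arithmetic on 1/2/3/4.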