Commit 51a538e0 authored by baojun-nervana

Fix style and use enum

test=develop
Parent ea3538d8
@@ -35,6 +35,13 @@ static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
     {proto::VarType::BOOL, ngraph::element::boolean},
 };
 
+typedef enum {     /* nGraph support state on ops          */
+  FULL_TRAIN,      /* Support full ops for train           */
+  PARTIAL_TRAIN,   /* Support partial ops for train        */
+  FULL_TEST,       /* Support full list of ops for test    */
+  PARTIAL_TEST     /* Support partial list of ops for test */
+} op_state;
+
 class NgraphOperator {
  public:
   explicit NgraphOperator(const Scope& scope, const platform::Place& place,
@@ -44,33 +51,29 @@ class NgraphOperator {
                   const std::unordered_set<std::string>& persist,
                   const std::unordered_set<std::string>& fetches,
                   const std::unordered_set<std::string>& post_op_inputs,
-                  int is_test_or_train)
-      : scope(scope),
-        place(place),
-        fused_ops(ops),
-        var_type_map(var_type_map),
-        persistables(persist),
-        fetches(fetches),
-        post_op_inputs(post_op_inputs),
-        is_test_or_train(is_test_or_train) {}
+                  op_state ng_op_state)
+      : scope_(scope),
+        place_(place),
+        fused_ops_(ops),
+        var_type_map_(var_type_map),
+        persistables_(persist),
+        fetches_(fetches),
+        post_op_inputs_(post_op_inputs),
+        ng_op_state_(ng_op_state) {}
 
   void Run(const Scope& scope, const platform::Place& place) const;
 
  private:
   static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
       func_cache;
-  const Scope& scope;
-  const platform::Place& place;
-  std::vector<std::shared_ptr<OperatorBase>> fused_ops;
-  std::unordered_map<std::string, ngraph::element::Type> var_type_map;
-  std::unordered_set<std::string> persistables;
-  std::unordered_set<std::string> fetches;
-  std::unordered_set<std::string> post_op_inputs;
-  // 0 = default; 1 = (is_test && not is_complete)
-  // 2 = (is_test && is_complete)
-  // 3 = (is_training && not is_complete)
-  // 4 = (is_training && is_complete)
-  int is_test_or_train;
+  const Scope& scope_;
+  const platform::Place& place_;
+  std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
+  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
+  std::unordered_set<std::string> persistables_;
+  std::unordered_set<std::string> fetches_;
+  std::unordered_set<std::string> post_op_inputs_;
+  op_state ng_op_state_;
 };
 
 std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
@@ -131,19 +134,19 @@ FusedOperator::FusedOperator(
     const ProgramDesc& prog, size_t block_id,
     std::vector<std::unique_ptr<OperatorBase>>::iterator start,
     std::vector<std::unique_ptr<OperatorBase>>::iterator end,
-    const std::string& type = "fused_op", const VariableNameMap& inputs = {},
-    const VariableNameMap& outputs = {}, const AttributeMap& attrs = {})
-    : OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
+    const std::string& type, const VariableNameMap& inputs,
+    const VariableNameMap& outputs, const AttributeMap& attrs)
+    : OperatorBase(type, inputs, outputs, attrs), pdesc_(prog), block_(block_id) {
   for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
        it != end; ++it) {
-    fused_ops.push_back(std::move(*it));
+    fused_ops_.push_back(std::move(*it));
   }
   for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
        (*it)->Type() != kFetchOpType; ++it) {
     for (auto& var_name_item : (*it)->Inputs()) {
       for (auto& var_name : var_name_item.second) {
-        post_op_inputs.insert(var_name);
+        post_op_inputs_.insert(var_name);
       }
     }
   }
@@ -152,11 +155,11 @@ FusedOperator::FusedOperator(
-    is_complete = true;
+    is_full_ = true;
   }
-  process();
+  Process();
 }
 
-void FusedOperator::process() {
-  auto& bdesc = pdesc.Block(block);
+void FusedOperator::Process() {
+  auto& bdesc = pdesc_.Block(block_);
   for (auto& var : bdesc.AllVars()) {
     if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
           var->GetType() == proto::VarType::LOD_TENSOR ||
@@ -175,39 +178,40 @@ void FusedOperator::process() {
         PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
                      var_name);
       }
-      var_type_map[var_name] = pd2ng_type_map[pd_type];
+      var_type_map_[var_name] = pd2ng_type_map[pd_type];
     }
 
     if (var->Persistable()) {
-      persistables.insert(var->Name());
+      persistables_.insert(var->Name());
     }
   }
 
   for (auto* op : bdesc.AllOps()) {
     if (op->Type() == kFetchOpType) {
       std::string fetch_target_name = op->Input("X")[0];
-      fetches.insert(fetch_target_name);
+      fetches_.insert(fetch_target_name);
     }
   }
 }
 void FusedOperator::RunImpl(const Scope& scope,
                             const platform::Place& place) const {
-  int is_test_or_train = 1;
-  auto& bdesc = pdesc.Block(block);
+  op_state ng_op_state = PARTIAL_TEST;
+  auto& bdesc = pdesc_.Block(block_);
   for (auto* op : bdesc.AllOps()) {
     if (op->Type().find("_grad") != std::string::npos) {
-      is_test_or_train = 3;
+      ng_op_state = PARTIAL_TRAIN;
       break;
     }
   }
 
-  if (is_complete) {
-    is_test_or_train = is_test_or_train == 1 ? 2 : 4;
+  if (is_full_) {
+    ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
   }
 
-  NgraphOperator ngraph_op(scope, place, fused_ops, var_type_map, persistables,
-                           fetches, post_op_inputs, is_test_or_train);
+  NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
+                           persistables_, fetches_, post_op_inputs_,
+                           ng_op_state);
   ngraph_op.Run(scope, place);
 }
...
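For readers tracking the change above: the deleted integer codes and the new op_state values line up one-to-one. The sketch below spells out the mapping implied by the removed comments and the new RunImpl logic; FromLegacyCode is a hypothetical helper name for illustration, not part of the patch, and the enum is redeclared so the snippet stands alone.

```cpp
// Redeclared from the patch so this snippet is self-contained.
typedef enum { FULL_TRAIN, PARTIAL_TRAIN, FULL_TEST, PARTIAL_TEST } op_state;

// Hypothetical helper (not in the patch): translates the old integer codes,
// documented only in the removed comments, to the new enum values.
op_state FromLegacyCode(int code) {
  switch (code) {
    case 1: return PARTIAL_TEST;   // is_test && !is_complete
    case 2: return FULL_TEST;      // is_test && is_complete
    case 3: return PARTIAL_TRAIN;  // is_training && !is_complete
    case 4: return FULL_TRAIN;     // is_training && is_complete
    default: return PARTIAL_TEST;  // 0 = default; the patch also defaults to PARTIAL_TEST
  }
}
```

The header diff that follows renames the corresponding member fields to the trailing-underscore style.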
@@ -56,16 +56,16 @@ class FusedOperator : public OperatorBase {
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
 
  private:
-  const ProgramDesc pdesc;
-  size_t block;
-  std::vector<std::shared_ptr<OperatorBase>> fused_ops;
-  std::unordered_map<std::string, ngraph::element::Type> var_type_map;
-  std::unordered_set<std::string> persistables;
-  std::unordered_set<std::string> fetches;
-  std::unordered_set<std::string> post_op_inputs;
-  bool is_complete = false;
-  void process();
+  const ProgramDesc pdesc_;
+  size_t block_;
+  std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
+  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
+  std::unordered_set<std::string> persistables_;
+  std::unordered_set<std::string> fetches_;
+  std::unordered_set<std::string> post_op_inputs_;
+  bool is_full_ = false;
+  void Process();
 };
 }  // namespace framework
 }  // namespace paddle
...
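Taken together, the two files replace a magic-number state with an explicit enum. Below is a minimal, self-contained sketch of the selection logic that FusedOperator::RunImpl now performs; SelectState and its parameters are illustrative stand-ins for this note, not part of the PaddlePaddle API.

```cpp
#include <cstdio>
#include <string>
#include <vector>

typedef enum {
  FULL_TRAIN,    /* Support full ops for train           */
  PARTIAL_TRAIN, /* Support partial ops for train        */
  FULL_TEST,     /* Support full list of ops for test    */
  PARTIAL_TEST   /* Support partial list of ops for test */
} op_state;

// Mirrors the patched RunImpl: default to PARTIAL_TEST, switch to
// PARTIAL_TRAIN if any gradient op is present, and upgrade PARTIAL_* to
// FULL_* when the whole block was fused (is_full_ in the patch).
op_state SelectState(const std::vector<std::string>& op_types, bool is_full) {
  op_state state = PARTIAL_TEST;
  for (const auto& type : op_types) {
    if (type.find("_grad") != std::string::npos) {
      state = PARTIAL_TRAIN;
      break;
    }
  }
  if (is_full) {
    state = (state == PARTIAL_TEST) ? FULL_TEST : FULL_TRAIN;
  }
  return state;
}

int main() {
  // A fully fused block containing "mul_grad" selects FULL_TRAIN.
  op_state s = SelectState({"mul", "mul_grad"}, /*is_full=*/true);
  std::printf("%d\n", s == FULL_TRAIN);  // prints 1
  return 0;
}
```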