Commit 51a538e0 authored by baojun-nervana

Fix style and use enum

test=develop
Parent ea3538d8
@@ -35,6 +35,13 @@ static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
     {proto::VarType::BOOL, ngraph::element::boolean},
 };
 
+typedef enum {    /* nGraph support state on ops          */
+  FULL_TRAIN,     /* Support full ops for train           */
+  PARTIAL_TRAIN,  /* Support partial ops for train        */
+  FULL_TEST,      /* Support full list of ops for test    */
+  PARTIAL_TEST    /* Support partial list of ops for test */
+} op_state;
+
 class NgraphOperator {
  public:
   explicit NgraphOperator(const Scope& scope, const platform::Place& place,
@@ -44,33 +51,29 @@ class NgraphOperator {
                           const std::unordered_set<std::string>& persist,
                           const std::unordered_set<std::string>& fetches,
                           const std::unordered_set<std::string>& post_op_inputs,
-                          int is_test_or_train)
-      : scope(scope),
-        place(place),
-        fused_ops(ops),
-        var_type_map(var_type_map),
-        persistables(persist),
-        fetches(fetches),
-        post_op_inputs(post_op_inputs),
-        is_test_or_train(is_test_or_train) {}
+                          op_state ng_op_state)
+      : scope_(scope),
+        place_(place),
+        fused_ops_(ops),
+        var_type_map_(var_type_map),
+        persistables_(persist),
+        fetches_(fetches),
+        post_op_inputs_(post_op_inputs),
+        ng_op_state_(ng_op_state) {}
 
   void Run(const Scope& scope, const platform::Place& place) const;
 
  private:
   static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
       func_cache;
-  const Scope& scope;
-  const platform::Place& place;
-  std::vector<std::shared_ptr<OperatorBase>> fused_ops;
-  std::unordered_map<std::string, ngraph::element::Type> var_type_map;
-  std::unordered_set<std::string> persistables;
-  std::unordered_set<std::string> fetches;
-  std::unordered_set<std::string> post_op_inputs;
-  // 0 = default; 1 = (is_test && not is_complete)
-  // 2 = (is_test && is_complete)
-  // 3 = (is_training && not is_complete)
-  // 4 = (is_training && is_complete)
-  int is_test_or_train;
+  const Scope& scope_;
+  const platform::Place& place_;
+  std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
+  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
+  std::unordered_set<std::string> persistables_;
+  std::unordered_set<std::string> fetches_;
+  std::unordered_set<std::string> post_op_inputs_;
+  op_state ng_op_state_;
 };
 
 std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
@@ -131,19 +134,19 @@ FusedOperator::FusedOperator(
     const ProgramDesc& prog, size_t block_id,
     std::vector<std::unique_ptr<OperatorBase>>::iterator start,
     std::vector<std::unique_ptr<OperatorBase>>::iterator end,
-    const std::string& type = "fused_op", const VariableNameMap& inputs = {},
-    const VariableNameMap& outputs = {}, const AttributeMap& attrs = {})
-    : OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
+    const std::string& type, const VariableNameMap& inputs,
+    const VariableNameMap& outputs, const AttributeMap& attrs)
+    : OperatorBase(type, inputs, outputs, attrs), pdesc_(prog), block_(block_id) {
   for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
        it != end; ++it) {
-    fused_ops.push_back(std::move(*it));
+    fused_ops_.push_back(std::move(*it));
   }
 
   for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
        (*it)->Type() != kFetchOpType; ++it) {
     for (auto& var_name_item : (*it)->Inputs()) {
       for (auto& var_name : var_name_item.second) {
-        post_op_inputs.insert(var_name);
+        post_op_inputs_.insert(var_name);
       }
     }
   }
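
Worth flagging in the hunk above: the out-of-line constructor definition no longer carries default arguments ("fused_op", {}, {}). A default argument may not be repeated on a later declaration of the same function, and one supplied only inside the .cc is invisible to every other translation unit, so the in-class declaration is its proper home (presumably where these defaults now live). A minimal sketch of the rule, with a hypothetical Widget class:

    #include <string>

    class Widget {
     public:
      // the in-class declaration carries the default argument
      explicit Widget(const std::string& type = "fused_op");

     private:
      std::string type_;
    };

    // the out-of-line definition must not repeat the "fused_op" default
    Widget::Widget(const std::string& type) : type_(type) {}

    int main() {
      Widget w;  // picks up the default from the declaration
    }
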
@@ -152,11 +155,11 @@ FusedOperator::FusedOperator(
-    is_complete = true;
+    is_full_ = true;
   }
 
-  process();
+  Process();
 }
 
-void FusedOperator::process() {
-  auto& bdesc = pdesc.Block(block);
+void FusedOperator::Process() {
+  auto& bdesc = pdesc_.Block(block_);
   for (auto& var : bdesc.AllVars()) {
     if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
           var->GetType() == proto::VarType::LOD_TENSOR ||
@@ -175,39 +178,40 @@ void FusedOperator::process() {
         PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
                      var_name);
       }
-      var_type_map[var_name] = pd2ng_type_map[pd_type];
+      var_type_map_[var_name] = pd2ng_type_map[pd_type];
     }
 
     if (var->Persistable()) {
-      persistables.insert(var->Name());
+      persistables_.insert(var->Name());
     }
   }
 
   for (auto* op : bdesc.AllOps()) {
     if (op->Type() == kFetchOpType) {
       std::string fetch_target_name = op->Input("X")[0];
-      fetches.insert(fetch_target_name);
+      fetches_.insert(fetch_target_name);
     }
   }
 }
 
 void FusedOperator::RunImpl(const Scope& scope,
                             const platform::Place& place) const {
-  int is_test_or_train = 1;
-  auto& bdesc = pdesc.Block(block);
+  op_state ng_op_state = PARTIAL_TEST;
+  auto& bdesc = pdesc_.Block(block_);
   for (auto* op : bdesc.AllOps()) {
     if (op->Type().find("_grad") != std::string::npos) {
-      is_test_or_train = 3;
+      ng_op_state = PARTIAL_TRAIN;
       break;
     }
   }
 
-  if (is_complete) {
-    is_test_or_train = is_test_or_train == 1 ? 2 : 4;
+  if (is_full_) {
+    ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
   }
 
-  NgraphOperator ngraph_op(scope, place, fused_ops, var_type_map, persistables,
-                           fetches, post_op_inputs, is_test_or_train);
+  NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
+                           persistables_, fetches_, post_op_inputs_,
+                           ng_op_state);
   ngraph_op.Run(scope, place);
 }
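
For reference, RunImpl's state selection distills to the self-contained sketch below; under the old scheme the four states were the integer codes 1 through 4 decoded by the now-deleted comment (1 = partial test, 2 = full test, 3 = partial train, 4 = full train). SelectState and its parameters are hypothetical stand-ins, not Paddle API:

    #include <string>
    #include <vector>

    enum op_state { FULL_TRAIN, PARTIAL_TRAIN, FULL_TEST, PARTIAL_TEST };

    op_state SelectState(const std::vector<std::string>& op_types, bool is_full) {
      op_state state = PARTIAL_TEST;  // default, as with the old code 1
      for (const auto& type : op_types) {
        if (type.find("_grad") != std::string::npos) {
          state = PARTIAL_TRAIN;  // any backward op implies training
          break;
        }
      }
      if (is_full) {  // the fused section spans feed to fetch
        state = (state == PARTIAL_TEST) ? FULL_TEST : FULL_TRAIN;
      }
      return state;
    }

The named states make the upgrade line self-documenting where the old is_test_or_train == 1 ? 2 : 4 needed the comment block to decode.
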
......
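
Process(), in the hunk above, derives each variable's nGraph element type through pd2ng_type_map and hard-fails on anything unmapped (the PADDLE_THROW branch). The lookup pattern, sketched with stand-in types so it compiles without Paddle or nGraph headers:

    #include <map>
    #include <stdexcept>
    #include <string>

    // Stand-ins for proto::VarType::Type and ngraph::element::Type.
    enum class PdType { FP32, FP64, BOOL };
    using NgType = std::string;

    NgType ToNgraphType(PdType pd_type, const std::string& var_name,
                        const std::map<PdType, NgType>& pd2ng_type_map) {
      auto it = pd2ng_type_map.find(pd_type);
      if (it == pd2ng_type_map.end()) {
        // mirrors PADDLE_THROW("Data type of var %s not found in pd2ng_type_map")
        throw std::runtime_error("Data type of var " + var_name +
                                 " not found in pd2ng_type_map");
      }
      return it->second;
    }
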
@@ -56,16 +56,16 @@ class FusedOperator : public OperatorBase {
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
 
  private:
-  const ProgramDesc pdesc;
-  size_t block;
-  std::vector<std::shared_ptr<OperatorBase>> fused_ops;
-  std::unordered_map<std::string, ngraph::element::Type> var_type_map;
-  std::unordered_set<std::string> persistables;
-  std::unordered_set<std::string> fetches;
-  std::unordered_set<std::string> post_op_inputs;
-  bool is_complete = false;
+  const ProgramDesc pdesc_;
+  size_t block_;
+  std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
+  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
+  std::unordered_set<std::string> persistables_;
+  std::unordered_set<std::string> fetches_;
+  std::unordered_set<std::string> post_op_inputs_;
+  bool is_full_ = false;
 
-  void process();
+  void Process();
 };
 
 }  // namespace framework
 }  // namespace paddle
......
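
The "Fix style" half of the commit gives private data members a trailing underscore (scope -> scope_, pdesc -> pdesc_) and capitalizes the private method process() to Process(), per the Google C++ style guide that Paddle follows. Besides consistency, the suffix keeps members from being spelled identically to the constructor parameters that initialize them, as in the old scope(scope) initializer. A minimal illustration with a hypothetical class:

    #include <cstddef>

    class Sketch {
     public:
      // parameter `block` and member `block_` can no longer collide
      explicit Sketch(std::size_t block) : block_(block) {}
      std::size_t block() const { return block_; }

     private:
      std::size_t block_;  // trailing underscore marks the data member
    };
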