提交 22ac2133 编写于 作者: B baojun-nervana

Rename class

test=develop
上级 bfde5e10
......@@ -91,11 +91,11 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
static void EnableFusedOp(ExecutorPrepareContext* ctx) {
#ifdef PADDLE_WITH_NGRAPH
VLOG(3) << "use_ngraph=True";
auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_);
auto intervals = NgraphOperator::NgraphOpIntervals(&ctx->ops_);
for (auto& interval : intervals) {
auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_,
interval.at(0), interval.at(1));
*interval[0] = std::unique_ptr<OperatorBase>(fused_op);
auto* ng_op = new NgraphOperator(ctx->prog_, ctx->block_id_, interval.at(0),
interval.at(1));
*interval[0] = std::unique_ptr<OperatorBase>(ng_op);
}
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
ctx->ops_.erase(it->at(0) + 1, it->at(1));
......
......@@ -57,16 +57,16 @@ typedef enum { /* nGraph support state on ops */
} op_state;
// perform graph build through bridge and execute computation
class NgraphOperator {
class NgraphEngine {
public:
explicit NgraphOperator(const Scope& scope, const platform::Place& place,
const std::vector<std::shared_ptr<OperatorBase>>& ops,
const std::unordered_map<
std::string, ngraph::element::Type>& var_type_map,
const std::unordered_set<std::string>& persist,
const std::unordered_set<std::string>& fetches,
const std::unordered_set<std::string>& post_op_inputs,
op_state ng_op_state)
explicit NgraphEngine(const Scope& scope, const platform::Place& place,
const std::vector<std::shared_ptr<OperatorBase>>& ops,
const std::unordered_map<
std::string, ngraph::element::Type>& var_type_map,
const std::unordered_set<std::string>& persist,
const std::unordered_set<std::string>& fetches,
const std::unordered_set<std::string>& post_op_inputs,
op_state ng_op_state)
: scope_(scope),
place_(place),
fused_ops_(ops),
......@@ -131,7 +131,7 @@ class NgraphOperator {
};
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
FusedOperator::FusedOpIntervals(
NgraphOperator::NgraphOpIntervals(
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops) {
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
intervals;
......@@ -184,7 +184,7 @@ FusedOperator::FusedOpIntervals(
return intervals;
}
FusedOperator::FusedOperator(
NgraphOperator::NgraphOperator(
const ProgramDesc& prog, size_t block_id,
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
......@@ -214,7 +214,7 @@ FusedOperator::FusedOperator(
Process();
}
void FusedOperator::Process() {
void NgraphOperator::Process() {
auto& bdesc = pdesc_.Block(block_);
for (auto& var : bdesc.AllVars()) {
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
......@@ -250,8 +250,8 @@ void FusedOperator::Process() {
}
}
void FusedOperator::RunImpl(const Scope& scope,
const platform::Place& place) const {
void NgraphOperator::RunImpl(const Scope& scope,
const platform::Place& place) const {
op_state ng_op_state = PARTIAL_TEST;
auto& bdesc = pdesc_.Block(block_);
for (auto* op : bdesc.AllOps()) {
......@@ -265,19 +265,19 @@ void FusedOperator::RunImpl(const Scope& scope,
ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
}
NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
persistables_, fetches_, post_op_inputs_,
ng_op_state);
ngraph_op.Run(scope, place);
NgraphEngine ngraph_engine(scope, place, fused_ops_, var_type_map_,
persistables_, fetches_, post_op_inputs_,
ng_op_state);
ngraph_engine.Run(scope, place);
}
std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
NgraphOperator::func_cache_ = {};
NgraphEngine::func_cache_ = {};
std::shared_ptr<ngraph::runtime::Backend> NgraphOperator::backend_ =
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
op->RuntimeInferShape(scope_, place_);
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
......@@ -300,7 +300,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
}
}
void NgraphOperator::BuildNgNodes() {
void NgraphEngine::BuildNgNodes() {
for (auto& var_name : var_out_) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name);
......@@ -322,7 +322,7 @@ void NgraphOperator::BuildNgNodes() {
}
}
void NgraphOperator::BuildNgIO() {
void NgraphEngine::BuildNgIO() {
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
......@@ -394,7 +394,7 @@ void NgraphOperator::BuildNgIO() {
}
}
void NgraphOperator::BuildNgFunction() {
void NgraphEngine::BuildNgFunction() {
BuildNgNodes();
ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs;
......@@ -415,7 +415,7 @@ void NgraphOperator::BuildNgFunction() {
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
}
std::shared_ptr<std::string> NgraphOperator::GetCacheKey() {
std::shared_ptr<std::string> NgraphEngine::GetCacheKey() {
auto cache_key = std::make_shared<std::string>("");
*cache_key += std::to_string(fused_ops_.size());
for (auto& op : fused_ops_) {
......@@ -443,7 +443,7 @@ std::shared_ptr<std::string> NgraphOperator::GetCacheKey() {
return cache_key;
}
void NgraphOperator::GetNgFunction() {
void NgraphEngine::GetNgFunction() {
bool cache_on = true;
if (cache_on) {
std::string cache_key_val = *GetCacheKey();
......@@ -458,8 +458,7 @@ void NgraphOperator::GetNgFunction() {
}
}
void NgraphOperator::Run(const Scope& scope,
const platform::Place& place) const {
void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out;
......@@ -544,6 +543,6 @@ void NgraphOperator::Run(const Scope& scope,
}
backend_->call(ngraph_function_, t_out, t_in);
} // NgraphOperator::RunImpl
} // NgraphEngine::RunImpl
} // namespace framework
} // namespace paddle
......@@ -32,14 +32,14 @@ limitations under the License. */
namespace paddle {
namespace framework {
class FusedOperator : public OperatorBase {
class NgraphOperator : public OperatorBase {
public:
static std::vector<
std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
FusedOpIntervals(
NgraphOpIntervals(
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops);
explicit FusedOperator(
explicit NgraphOperator(
const ProgramDesc& prog, size_t block_id,
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册