“2d45b245384f0f6e71695329c384c33ae37fb44a”上不存在“doc/howto/usage/k8s/k8s_en.html”
提交 e3c37bd5 编写于 作者: B baojun 提交者: tensor-tang

remove const_cast and refactor ngraph engine code (#15925)

* remove concast_cast and refactor code test=develop

* reduce flag use test=develop
上级 09799566
...@@ -34,11 +34,11 @@ limitations under the License. */ ...@@ -34,11 +34,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
#include "paddle/fluid/operators/ngraph/ngraph_engine.h" #include "paddle/fluid/operators/ngraph/ngraph_engine.h"
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
#endif #endif
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -194,9 +194,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -194,9 +194,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool force_disable_gc) { bool force_disable_gc) {
platform::RecordBlock b(block_id); platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc); if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
#endif
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc); auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
} }
...@@ -372,6 +369,12 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( ...@@ -372,6 +369,12 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
} }
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
}
#endif
return ctx; return ctx;
} }
......
...@@ -29,7 +29,6 @@ limitations under the License. */ ...@@ -29,7 +29,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h" #include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_engine.h" #include "paddle/fluid/operators/ngraph/ngraph_engine.h"
...@@ -42,44 +41,75 @@ static ngraph::Shape Ddim2Shape(const framework::DDim& dims) { ...@@ -42,44 +41,75 @@ static ngraph::Shape Ddim2Shape(const framework::DDim& dims) {
for (int i = 0; i < dims.size(); ++i) { for (int i = 0; i < dims.size(); ++i) {
int k = dims[i]; int k = dims[i];
k = k == 0 ? 1 : k; k = k == 0 ? 1 : k;
sp.push_back(k); sp.emplace_back(k);
} }
return sp; return sp;
} }
static framework::DDim Shape2Ddim(const ngraph::Shape& shape) {
std::vector<int64_t> dims;
for (size_t i = 0; i < shape.size(); ++i) {
int64_t k = shape[i];
dims.emplace_back(k);
}
return framework::make_ddim(dims);
}
static std::map<framework::proto::VarType::Type, ngraph::element::Type> static std::map<framework::proto::VarType::Type, ngraph::element::Type>
pd2ng_type_map = { pd2ng_type_map = {
{framework::proto::VarType::FP32, ngraph::element::f32}, {framework::proto::VarType::FP32, ngraph::element::f32},
{framework::proto::VarType::FP64, ngraph::element::f64}, {framework::proto::VarType::FP64, ngraph::element::f64},
{framework::proto::VarType::INT32, ngraph::element::i32}, {framework::proto::VarType::INT32, ngraph::element::i32},
{framework::proto::VarType::INT64, ngraph::element::i64}, {framework::proto::VarType::INT64, ngraph::element::i64},
{framework::proto::VarType::BOOL, ngraph::element::boolean}, {framework::proto::VarType::BOOL, ngraph::element::boolean}};
};
static std::map<ngraph::element::Type, framework::proto::VarType::Type>
std::unordered_map<std::string, std::shared_ptr<ngraph::Function>> ng2pd_type_map = {
NgraphEngine::func_cache_ = {}; {ngraph::element::f32, framework::proto::VarType::FP32},
{ngraph::element::f64, framework::proto::VarType::FP64},
{ngraph::element::i32, framework::proto::VarType::INT32},
{ngraph::element::i64, framework::proto::VarType::INT64},
{ngraph::element::boolean, framework::proto::VarType::BOOL}};
std::vector<std::string> NgraphEngine::feed_vars = {};
std::vector<std::string> NgraphEngine::fetch_vars = {};
framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
std::unordered_map<std::string,
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
NgraphEngine::t_in_cache_ = {};
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ = std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU"); ngraph::runtime::Backend::create("CPU");
static std::vector<std::vector<int>> NgraphOpIntervals( static std::vector<std::vector<int>> NgraphOpIntervals(
framework::BlockDesc* block) { std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::feed_vars.clear();
NgraphEngine::fetch_vars.clear();
std::vector<std::vector<int>> intervals; std::vector<std::vector<int>> intervals;
auto ops = block->AllOps();
int size = ops.size(); int size = ops->size();
int left = 0; int left = 0;
while (left < size && ops.at(left)->Type() != framework::kFeedOpType) { while (left < size && ops->at(left)->Type() != framework::kFeedOpType) {
++left; ++left;
} }
if (left == size) { if (left == size) {
return intervals; return intervals;
} }
while (left < size && ops.at(left)->Type() == framework::kFeedOpType) {
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
for (auto& var_name_item : ops->at(left)->Outputs()) {
for (auto& var_name : var_name_item.second) {
NgraphEngine::feed_vars.emplace_back(var_name);
}
}
++left; ++left;
} }
int right = left; int right = left;
while (right < size && ops.at(right)->Type() != framework::kFetchOpType) { while (right < size && ops->at(right)->Type() != framework::kFetchOpType) {
++right; ++right;
} }
if (right == size) { if (right == size) {
...@@ -87,85 +117,124 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -87,85 +117,124 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
} }
if (left >= right) return intervals; if (left >= right) return intervals;
int index = right;
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
for (auto& var_name_item : ops->at(index)->Inputs()) {
for (auto& var_name : var_name_item.second) {
NgraphEngine::fetch_vars.emplace_back(var_name);
}
}
++index;
}
// (left, right - 1) represents indices between feed and fetch // (left, right - 1) represents indices between feed and fetch
int pivot = left; int pivot = left;
while (pivot < right) { while (pivot < right) {
auto op_type = ops.at(pivot)->Type(); auto op_type = ops->at(pivot)->Type();
if (NgraphBridge::isRegister(op_type)) { if (NgraphBridge::isRegister(op_type)) {
++pivot; ++pivot;
} else { } else {
int start = pivot, end = start; int start = pivot, end = start;
while (pivot < right && while (pivot < right &&
(!NgraphBridge::isRegister(ops.at(pivot)->Type()))) { (!NgraphBridge::isRegister(ops->at(pivot)->Type()))) {
++pivot; ++pivot;
++end; ++end;
} }
std::vector<int> interval = {start, end}; std::vector<int> interval = {start, end};
intervals.push_back(interval); intervals.emplace_back(interval);
} }
} // end while } // end while
return intervals; return intervals;
} }
static void SubstituteNgraphOp(framework::BlockDesc* block, static void SubstituteNgraphOp(
std::string block_str, std::vector<std::unique_ptr<framework::OperatorBase>>* ops,
std::vector<int> interval) { std::string engine_key, std::string block_str, std::vector<int> interval) {
framework::ProgramDesc program; framework::OpDesc ng_op_desc(nullptr);
block->RemoveOp(interval.at(0), interval.at(1)); ng_op_desc.SetType("ngraph_engine");
auto* ng_op = block->InsertOp(interval.at(0)); ng_op_desc.SetAttr("interval", interval);
ng_op->SetType("ngraph_engine"); ng_op_desc.SetAttr("engine_key", engine_key);
ng_op->SetAttr("interval", interval); ng_op_desc.SetAttr("graph", block_str);
ng_op->SetAttr("graph", block_str);
ops->erase(ops->begin() + interval[0], ops->begin() + interval[1]);
ops->insert(ops->begin() + interval[0],
framework::OpRegistry::CreateOp(ng_op_desc));
} }
// TODO(baojun-nervana): Move EnableNgraph to compile time per PR #15089 std::string SerializedBlock(const std::vector<framework::OpDesc*>& op_descs) {
void NgraphEngine::EnableNgraph(const framework::ProgramDesc& program) { framework::proto::BlockDesc block_proto;
#ifdef PADDLE_WITH_NGRAPH framework::BlockDesc block_desc(nullptr, &block_proto);
VLOG(4) << "use_ngraph=True"; block_desc.Proto()->set_parent_idx(-1);
for (size_t bid = 0; bid < program.Size(); ++bid) { block_desc.Proto()->set_idx(0);
// TODO(baojun-nervana): Remove the const_cast
auto* block = for (auto* op_desc : op_descs) {
const_cast<framework::ProgramDesc&>(program).MutableBlock(bid); auto* op = block_desc.AppendOp();
std::string block_str = block->Proto()->SerializeAsString(); *op->Proto() = *op_desc->Proto();
auto intervals = NgraphOpIntervals(block); }
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) { return block_desc.Proto()->SerializeAsString();
SubstituteNgraphOp(block, block_str, *it); }
}
std::string GenerateEngineKey(const framework::BlockDesc& bdesc) {
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
for (auto& op_desc : bdesc.AllOps()) {
auto* op = block_desc.AppendOp();
*op->Proto() = *op_desc->Proto();
}
auto engine_key = std::to_string(
std::hash<std::string>()(block_desc.Proto()->SerializeAsString()));
return engine_key;
}
std::string GenerateEngineKey(const std::vector<std::string>& engine_inputs,
const std::vector<std::string>& engine_outputs,
int size) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
engine_hash_key += name;
}
for (auto name : engine_outputs) {
engine_hash_key += name;
}
engine_hash_key += std::to_string(size);
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
return engine_key;
}
void NgraphEngine::FuseNgraphOps(
const framework::BlockDesc& block_desc,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::p_bdesc = &block_desc;
auto intervals = NgraphOpIntervals(ops);
std::string engine_key =
GenerateEngineKey(feed_vars, fetch_vars, ops->size());
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
SubstituteNgraphOp(ops, engine_key, "", *it);
} }
#else
LOG(WARNING)
<< "'NGRAPH' is not supported, Please re-compile with WITH_NGRAPH option";
#endif
} }
NgraphEngine::NgraphEngine(const framework::Scope& scope, NgraphEngine::NgraphEngine(const framework::Scope& scope,
const platform::Place& place, const platform::Place& place,
const std::string& serialized_graph, const framework::ExecutionContext& ctx)
const std::vector<int>& interval)
: scope_(scope), place_(place) { : scope_(scope), place_(place) {
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string engine_key = ctx.Attr<std::string>("engine_key");
var_in_node_map_ = std::make_shared< var_in_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>(); std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
var_node_map_ = std::make_shared< var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>(); std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
func_cache_key_ = std::to_string(interval[0]) + std::to_string(interval[1]) + GetNgFunction(engine_key, interval);
serialized_graph;
framework::proto::BlockDesc bdesc;
bdesc.ParseFromString(serialized_graph);
framework::BlockDesc block(nullptr, &bdesc);
Prepare(block, interval);
BuildNgIO();
GetNgFunction();
} }
void NgraphEngine::Prepare(const framework::BlockDesc& block, void NgraphEngine::Prepare(const std::vector<int>& interval) {
const std::vector<int>& interval) { for (auto& var : p_bdesc->AllVars()) {
for (auto& var : block.AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS || if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR || var->GetType() == framework::proto::VarType::LOD_TENSOR ||
var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY)) { var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY)) {
...@@ -192,108 +261,57 @@ void NgraphEngine::Prepare(const framework::BlockDesc& block, ...@@ -192,108 +261,57 @@ void NgraphEngine::Prepare(const framework::BlockDesc& block,
} }
} }
auto ops_desc = block.AllOps(); std::vector<paddle::framework::OpDesc*> ops_desc;
int idx = interval[0]; for (auto op_desc : p_bdesc->AllOps()) {
while (idx < interval[1]) { ops_desc.emplace_back(op_desc);
auto op_desc = ops_desc.at(idx);
auto op = framework::OpRegistry::CreateOp(*op_desc);
fused_ops_.push_back(std::move(op));
++idx;
}
while (ops_desc.at(idx)->Type() != framework::kFetchOpType) {
auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) {
post_op_inputs_.insert(var_name);
}
}
++idx;
}
while (idx < static_cast<int>(ops_desc.size()) &&
ops_desc.at(idx)->Type() == framework::kFetchOpType) {
std::string fetch_target_name = ops_desc.at(idx)->Input("X")[0];
fetches_.insert(fetch_target_name);
++idx;
}
if (ops_desc.at(interval.at(0) - 1)->Type() == framework::kFeedOpType &&
ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) {
ng_op_state_ = OpState::FULL;
} }
for (auto* op_desc : ops_desc) { for (auto op_desc : ops_desc) {
if (op_desc->Type().find("_grad") != std::string::npos) { if (op_desc->Type().find("_grad") != std::string::npos) {
ng_op_state_ = ng_op_state_ == OpState::FULL ? OpState::FULL_TRAIN this->is_test_ = false;
: OpState::PARTIAL_TRAIN;
break; break;
} }
} }
if (ng_op_state_ != OpState::FULL_TRAIN && if (interval[0] > 0 &&
ng_op_state_ != OpState::PARTIAL_TRAIN) { ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
ng_op_state_ = ng_op_state_ == OpState::FULL ? OpState::FULL_TEST interval[1] < static_cast<int>(ops_desc.size()) &&
: OpState::PARTIAL_TEST; ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) {
this->op_state_ = OpState::FULL;
} }
}
void NgraphEngine::GetNgInputShape( if (this->op_state_ == OpState::FULL) {
std::shared_ptr<framework::OperatorBase> op) { this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_); } else {
op->RuntimeInferShape(scope_, place_, ctx); this->op_state_ =
for (auto& var_name_item : op->Inputs()) { this->is_test_ ? OpState::PARTIAL_TEST : OpState::PARTIAL_TRAIN;
for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto sp = Ddim2Shape(tensor_pd->dims());
if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
var_in_.end()) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
// auto ng_type = pd2ng_type_map.at(GetDataTypeOfVar(var));
auto ng_type = var_type_map_.at(var_name);
auto prm =
std::make_shared<ngraph::op::Parameter>(ng_type, sp, true);
(*var_node_map_)[var_name] = prm;
(*var_in_node_map_)[var_name] = prm;
}
}
}
}
} }
}
void NgraphEngine::BuildNgNodes() { int idx = interval[0];
for (auto& op : fused_ops_) { while (idx < interval[1]) {
for (auto& var_name_item : op->Outputs()) { this->fused_ops_.emplace_back(
framework::OpRegistry::CreateOp(*(ops_desc[idx])));
++idx;
}
while (ops_desc.at(idx)->Type() != framework::kFetchOpType) {
auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
if (var_node_map_->find(var_name) == var_node_map_->end()) { this->post_op_inputs_.insert(var_name);
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim);
auto ng_type = var_type_map_.at(var_name);
auto prm = std::make_shared<ngraph::op::Parameter>(ng_type,
ng_shape, true);
(*var_node_map_)[var_name] = prm;
}
}
} }
} }
++idx;
} }
NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) { BuildNgIO(ops_desc, interval);
ngb.BuildNgNode(op);
}
} }
void NgraphEngine::BuildNgIO() { void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
const std::vector<int>& interval) {
std::unordered_set<std::string> inputs; std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs; std::unordered_set<std::string> outputs;
for (int i = interval[0]; i < interval[1]; ++i) {
for (auto& op : fused_ops_) { auto op = ops_desc[i];
for (auto& var_name_item : op->Inputs()) { for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
inputs.insert(var_name); inputs.insert(var_name);
...@@ -302,15 +320,11 @@ void NgraphEngine::BuildNgIO() { ...@@ -302,15 +320,11 @@ void NgraphEngine::BuildNgIO() {
std::find(var_in_.begin(), var_in_.end(), var_name) == std::find(var_in_.begin(), var_in_.end(), var_name) ==
var_in_.end()) { var_in_.end()) {
// fill var_in here to keep lhs and rhs order // fill var_in here to keep lhs and rhs order
var_in_.push_back(var_name); this->var_in_.emplace_back(var_name);
} }
} }
} }
if (op->Type() != "fill_constant") {
GetNgInputShape(op);
}
for (auto& var_name_item : op->Outputs()) { for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1, PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet", "op %s has more than 1 output - Not handling yet",
...@@ -322,172 +336,278 @@ void NgraphEngine::BuildNgIO() { ...@@ -322,172 +336,278 @@ void NgraphEngine::BuildNgIO() {
} }
// var_out.clear(); // var_out.clear();
for (auto& op : fused_ops_) { for (int i = interval[0]; i < interval[1]; ++i) {
auto op = ops_desc[i];
for (auto& var_name_item : op->Outputs()) { for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1, PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet", "op %s has more than 1 output - Not handling yet",
op->Type()); op->Type());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
switch (ng_op_state_) { switch (this->op_state_) {
case OpState::PARTIAL_TEST: case OpState::PARTIAL_TEST:
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() || if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
fetches_.find(var_name) != fetches_.end()) { find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
var_out_.push_back(var_name); fetch_vars.end()) {
this->var_out_.emplace_back(var_name);
} }
break; break;
case OpState::FULL_TEST: case OpState::FULL_TEST:
if (fetches_.find(var_name) != fetches_.end()) { if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
var_out_.push_back(var_name); fetch_vars.end()) {
this->var_out_.emplace_back(var_name);
} }
break; break;
case OpState::PARTIAL_TRAIN: case OpState::PARTIAL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() || if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() || post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) { persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name); this->var_out_.emplace_back(var_name);
} }
break; break;
case OpState::FULL_TRAIN: case OpState::FULL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() || if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
persistables_.find(var_name) != persistables_.end()) { persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name); this->var_out_.emplace_back(var_name);
} }
break; break;
default: default:
var_out_.push_back(var_name); this->var_out_.emplace_back(var_name);
} }
} }
} }
} }
for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
var_in_updates_.emplace_back(i);
}
}
} }
void NgraphEngine::BuildNgFunction() { void NgraphEngine::GetNgInputShape() {
for (auto& var_name : var_in_) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto sp = Ddim2Shape(tensor_pd->dims());
auto ng_type = var_type_map_[var_name];
auto prm = std::make_shared<ngraph::op::Parameter>(ng_type, sp, true);
(*var_node_map_)[var_name] = prm;
(*var_in_node_map_)[var_name] = prm;
}
}
}
void NgraphEngine::BuildNgNodes() {
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Outputs()) {
for (auto& var_name : var_name_item.second) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim);
auto ng_type = var_type_map_[var_name];
auto prm = std::make_shared<ngraph::op::Parameter>(ng_type,
ng_shape, true);
(*var_node_map_)[var_name] = prm;
}
}
}
}
}
NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
}
}
void NgraphEngine::RunInferShape() {
for (auto& op : fused_ops_) {
framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
op->RuntimeInferShape(scope_, place_, ctx);
}
}
void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) {
Prepare(interval);
RunInferShape();
GetNgInputShape();
BuildNgNodes(); BuildNgNodes();
ngraph_function_ = nullptr; ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs; ngraph::NodeVector func_outputs;
ngraph::ParameterVector func_inputs; ngraph::ParameterVector func_inputs;
for (auto& vo : var_out_) { for (auto& vo : var_out_) {
func_outputs.push_back(var_node_map_->at(vo)); func_outputs.emplace_back(var_node_map_->at(vo));
} }
for (auto& vi : var_in_) { for (auto& vi : var_in_) {
std::shared_ptr<ngraph::op::Parameter> prm = std::shared_ptr<ngraph::op::Parameter> prm =
std::dynamic_pointer_cast<ngraph::op::Parameter>( std::dynamic_pointer_cast<ngraph::op::Parameter>(
var_in_node_map_->at(vi)); var_in_node_map_->at(vi));
func_inputs.push_back(prm); func_inputs.emplace_back(prm);
} }
ngraph_function_ = ngraph_function_ =
std::make_shared<ngraph::Function>(func_outputs, func_inputs); std::make_shared<ngraph::Function>(func_outputs, func_inputs);
} }
void NgraphEngine::GetNgFunction() { void NgraphEngine::GetNgFunction(std::string engine_key,
bool cache_on = true; const std::vector<int>& interval) {
if (cache_on) { bool use_cache = true;
std::string input_shape_str; if (use_cache) {
for (auto& var_name : var_in_) { this->func_cache_key_ = "";
auto shape = var_node_map_->at(var_name)->get_shape(); for (int i = 0; i < std::min(static_cast<int>(feed_vars.size()), 10); ++i) {
for (size_t i = 0; i < shape.size(); ++i) { auto* var = scope_.FindVar(feed_vars[i]);
input_shape_str += std::to_string(shape.at(i)); if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto dims = tensor_pd->dims();
for (int j = 0; j < dims.size(); ++j) {
func_cache_key_ += std::to_string(dims[j]);
}
} }
} }
func_cache_key_ = input_shape_str + func_cache_key_; func_cache_key_ += std::to_string(interval[0]) + "_" +
if (func_cache_.find(func_cache_key_) != func_cache_.end()) { std::to_string(interval[1]) + engine_key;
ngraph_function_ = func_cache_.at(func_cache_key_); func_cache_key_ = std::to_string(std::hash<std::string>()(func_cache_key_));
} else {
BuildNgFunction(); if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
func_cache_[func_cache_key_] = ngraph_function_; if (engine_cache[func_cache_key_].persistables.size() == 0) {
engine_cache.clear();
t_in_cache_.clear();
} else {
auto var_name = engine_cache[func_cache_key_].persistables.begin();
framework::Variable* var = scope_.FindVar(*var_name);
if (var != pre_var_ptr) {
engine_cache.clear();
t_in_cache_.clear();
}
pre_var_ptr = var;
}
}
if (engine_cache.find(func_cache_key_) == engine_cache.end()) {
BuildNgFunction(interval);
engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_;
engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
engine_cache[func_cache_key_].var_in = this->var_in_;
engine_cache[func_cache_key_].var_out = this->var_out_;
engine_cache[func_cache_key_].is_test = this->is_test_;
} }
} else { } else {
BuildNgFunction(); BuildNgFunction(interval);
} }
} }
void NgraphEngine::Run(const framework::Scope& scope, void NgraphEngine::Run(const framework::Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in; std::shared_ptr<ngraph::Function> ng_func;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out; const std::set<std::string>* p_persistables;
const std::vector<size_t>* p_var_in_updates;
const std::vector<std::string>* p_var_in;
const std::vector<std::string>* p_var_out;
bool is_test;
bool use_cache = true;
if (use_cache) {
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
"Cannot find cached data to run ngraph function");
ng_func = engine_cache[func_cache_key_].ngraph_function;
p_persistables = &(engine_cache[func_cache_key_].persistables);
p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates);
p_var_in = &(engine_cache[func_cache_key_].var_in);
p_var_out = &(engine_cache[func_cache_key_].var_out);
is_test = engine_cache[func_cache_key_].is_test;
} else {
ng_func = ngraph_function_;
p_persistables = &this->persistables_;
p_var_in_updates = &this->var_in_updates_;
p_var_in = &this->var_in_;
p_var_out = &this->var_out_;
is_test = this->is_test_;
}
for (size_t i = 0; i < var_in_.size(); ++i) { std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in;
auto vi = var_in_.at(i); std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {};
auto sp = var_node_map_->at(vi)->get_shape();
std::shared_ptr<ngraph::runtime::Tensor> ti; auto m_parameters = ng_func->get_parameters();
auto* var = scope.FindVar(vi); auto m_results = ng_func->get_results();
if (var && var->IsType<framework::LoDTensor>()) { if (is_test && use_cache &&
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var); t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()), p_t_in = &(t_in_cache_[func_cache_key_]);
"Ensure ngraph tensor layout align with paddle tensor"); for (size_t i = 0; i < p_var_in_updates->size(); ++i) {
auto ng_type = var_type_map_.at(vi); int index = p_var_in_updates->at(i);
if (ng_type == ngraph::element::f32) { auto vi = p_var_in->at(index);
auto pd_arr = tensor_pd->mutable_data<float>(place); auto sp = m_parameters[index]->get_shape();
ti = backend_->create_tensor(ngraph::element::f32, sp, pd_arr); auto ng_type = m_parameters[index]->get_element_type();
} else if (ng_type == ngraph::element::i32) { std::shared_ptr<ngraph::runtime::Tensor> ti;
const int* arr = tensor_pd->data<int>(); auto* var = scope.FindVar(vi);
ti = backend_->create_tensor(ngraph::element::i32, sp, if (var && var->IsType<framework::LoDTensor>()) {
const_cast<int*>(arr)); auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
} else if (ng_type == ngraph::element::i64) { void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
auto pd_arr = tensor_pd->mutable_data<int64_t>(place); ti = backend_->create_tensor(ng_type, sp, pd_arr);
ti = backend_->create_tensor(ngraph::element::i64, sp, pd_arr); (*p_t_in)[index] = ti;
} else if (ng_type == ngraph::element::f64) {
auto pd_arr = tensor_pd->mutable_data<double>(place);
ti = backend_->create_tensor(ngraph::element::f64, sp, pd_arr);
} else if (ng_type == ngraph::element::boolean) {
auto pd_arr = tensor_pd->mutable_data<bool>(place);
ti = backend_->create_tensor(ngraph::element::boolean, sp, pd_arr);
} else { } else {
PADDLE_THROW("Data type not handling for var %s", vi); PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
} }
}
} else {
if (is_test && use_cache) {
p_t_in = &(t_in_cache_[func_cache_key_]);
} else { } else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi); p_t_in = &t_in;
} }
bool is_test = (ng_op_state_ == OpState::PARTIAL_TEST ||
ng_op_state_ == OpState::FULL_TEST) for (size_t i = 0; i < p_var_in->size(); ++i) {
? true auto vi = p_var_in->at(i);
: false; auto sp = m_parameters[i]->get_shape();
bool is_persistable = auto ng_type = m_parameters[i]->get_element_type();
(persistables_.find(vi) != persistables_.end()) ? true : false; std::shared_ptr<ngraph::runtime::Tensor> ti;
if (is_test && is_persistable) { auto* var = scope.FindVar(vi);
ti->set_stale(false); if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor");
ti = backend_->create_tensor(ng_type, sp, pd_arr);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
bool is_persistable =
(p_persistables->find(vi) != p_persistables->end()) ? true : false;
if (is_test && is_persistable) {
ti->set_stale(false);
}
(*p_t_in).emplace_back(ti);
} }
t_in.push_back(ti);
} }
for (size_t i = 0; i < var_out_.size(); ++i) { std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out = {};
auto vo = var_out_[i]; for (size_t i = 0; i < p_var_out->size(); ++i) {
auto vo = p_var_out->at(i);
auto* var = scope.FindVar(vo); auto* var = scope.FindVar(vo);
std::shared_ptr<ngraph::runtime::Tensor> to;
if (var && var->IsType<framework::LoDTensor>()) { if (var && var->IsType<framework::LoDTensor>()) {
auto sp = m_results[i]->get_shape();
var->GetMutable<framework::LoDTensor>()->Resize(Shape2Ddim(sp));
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var); auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
auto dd = tensor_pd->dims(); auto ng_type = m_results[i]->get_element_type();
ngraph::Shape sp = Ddim2Shape(dd); void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
auto ng_type = var_type_map_.at(vo); std::shared_ptr<ngraph::runtime::Tensor> to =
if (ng_type == ngraph::element::f32) { backend_->create_tensor(ng_type, sp, pd_arr);
auto pd_arr = tensor_pd->mutable_data<float>(place); t_out.emplace_back(to);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else if (ng_type == ngraph::element::i64) {
auto pd_arr = tensor_pd->mutable_data<int64_t>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else if (ng_type == ngraph::element::i32) {
auto pd_arr = tensor_pd->mutable_data<int>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else if (ng_type == ngraph::element::f64) {
auto pd_arr = tensor_pd->mutable_data<double>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else if (ng_type == ngraph::element::boolean) {
auto pd_arr = tensor_pd->mutable_data<bool>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else {
PADDLE_THROW("Data type not handled in for var %s", vo);
}
t_out.push_back(to);
} else { } else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vo); PADDLE_THROW("Cannot find var or tensor with var name %s", vo);
} }
} }
auto handle = backend_->compile(ngraph_function_); auto handle = backend_->compile(ng_func);
handle->call_with_validate(t_out, t_in); handle->call_with_validate(t_out, *p_t_in);
} // NgraphEngine::Run } // NgraphEngine::Run
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#include <memory>
#include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
...@@ -33,29 +39,47 @@ enum class OpState { /* nGraph support state on ops */ ...@@ -33,29 +39,47 @@ enum class OpState { /* nGraph support state on ops */
UNKNOWN /* Output all for debug purpose */ UNKNOWN /* Output all for debug purpose */
}; };
// cache engine repetitives
struct EngineCache {
std::shared_ptr<ngraph::Function> ngraph_function;
std::set<std::string> persistables;
std::vector<std::string> var_in;
std::vector<std::string> var_out;
std::vector<size_t> var_in_updates;
bool is_test = true;
};
// perform graph build through bridge and execute computation // perform graph build through bridge and execute computation
class NgraphEngine { class NgraphEngine {
public: public:
explicit NgraphEngine(const framework::Scope& scope, explicit NgraphEngine(const framework::Scope& scope,
const platform::Place& place, const platform::Place& place,
const std::string& serialized_graph, const framework::ExecutionContext& ctx);
const std::vector<int>& interval);
void Run(const framework::Scope& scope, const platform::Place& place) const; void Run(const framework::Scope& scope, const platform::Place& place) const;
static void EnableNgraph(const framework::ProgramDesc& program); static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars;
static void FuseNgraphOps(
const framework::BlockDesc& prog,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
private: private:
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>> static std::unordered_map<std::string, EngineCache> engine_cache;
func_cache_; static std::unordered_map<
std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
t_in_cache_;
static framework::Variable* pre_var_ptr;
const framework::Scope& scope_; const framework::Scope& scope_;
const platform::Place& place_; const platform::Place& place_;
std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_; std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_; std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::unordered_set<std::string> persistables_; std::set<std::string> persistables_;
std::unordered_set<std::string> fetches_;
std::unordered_set<std::string> post_op_inputs_; std::unordered_set<std::string> post_op_inputs_;
OpState ng_op_state_ = OpState::UNKNOWN; OpState op_state_ = OpState::UNKNOWN;
bool is_test_{true};
std::string func_cache_key_; std::string func_cache_key_;
// ngraph backend eg. CPU // ngraph backend eg. CPU
...@@ -66,6 +90,8 @@ class NgraphEngine { ...@@ -66,6 +90,8 @@ class NgraphEngine {
std::vector<std::string> var_in_; std::vector<std::string> var_in_;
// var_name of outputs from fetch in order // var_name of outputs from fetch in order
std::vector<std::string> var_out_; std::vector<std::string> var_out_;
// non-persitable var_in
std::vector<size_t> var_in_updates_;
// map input vars to nodes // map input vars to nodes
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
...@@ -74,20 +100,23 @@ class NgraphEngine { ...@@ -74,20 +100,23 @@ class NgraphEngine {
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_; var_node_map_;
// prepare info for nraph engine // prepare info for ngraph engine need
void Prepare(const framework::BlockDesc& block, void Prepare(const std::vector<int>& interval);
const std::vector<int>& interval); // get ngraph engine input and output list
void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
const std::vector<int>& interval);
// get ngraph input and define ngraph input parameters // get ngraph input and define ngraph input parameters
void GetNgInputShape(std::shared_ptr<framework::OperatorBase> op); void GetNgInputShape();
// Call ngraph bridge to map ops // Call ngraph bridge to map ops
void BuildNgNodes(); void BuildNgNodes();
// get the ngraph input and output var list // run paddle RuntimeInferShape to get the tensor shape
void BuildNgIO(); void RunInferShape();
// build ngraph function call // build ngraph function call
void BuildNgFunction(); void BuildNgFunction(const std::vector<int>& interval);
// Check cache for ngraph function or otherwise build the function // Check cache for ngraph function or otherwise build the function
void GetNgFunction(); void GetNgFunction(std::string engine_key, const std::vector<int>& interval);
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
...@@ -29,6 +29,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -29,6 +29,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Xs", "A list of inputs.").AsDispensable(); AddInput("Xs", "A list of inputs.").AsDispensable();
AddOutput("Ys", "A list of outputs").AsDispensable(); AddOutput("Ys", "A list of outputs").AsDispensable();
AddAttr<std::string>("graph", "the graph."); AddAttr<std::string>("graph", "the graph.");
AddAttr<std::string>("engine_key", "the engine hash key.");
AddAttr<std::vector<int>>("interval", "op interval supported by ngraph"); AddAttr<std::vector<int>>("interval", "op interval supported by ngraph");
AddComment("ngraph engine operator."); AddComment("ngraph engine operator.");
} }
......
...@@ -46,10 +46,8 @@ class NgraphEngineKernel : public framework::OpKernel<T> { ...@@ -46,10 +46,8 @@ class NgraphEngineKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& scope = ctx.scope(); auto& scope = ctx.scope();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto interval = ctx.Attr<std::vector<int>>("interval");
NgraphEngine ngraph_engine(scope, place, serialized_graph, interval); NgraphEngine ngraph_engine(scope, place, ctx);
ngraph_engine.Run(scope, place); ngraph_engine.Run(scope, place);
} }
}; };
......
...@@ -94,6 +94,14 @@ bool IsCompiledWithMKLDNN() { ...@@ -94,6 +94,14 @@ bool IsCompiledWithMKLDNN() {
#endif #endif
} }
bool IsCompiledWithNGRAPH() {
#ifndef PADDLE_WITH_NGRAPH
return false;
#else
return true;
#endif
}
bool IsCompiledWithBrpc() { bool IsCompiledWithBrpc() {
#ifndef PADDLE_WITH_DISTRIBUTE #ifndef PADDLE_WITH_DISTRIBUTE
return false; return false;
...@@ -874,6 +882,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -874,6 +882,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_devices", m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); }); [](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_ngraph", IsCompiledWithNGRAPH);
m.def("is_compiled_with_cuda", IsCompiledWithCUDA); m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc); m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
......
...@@ -125,7 +125,7 @@ def __bootstrap__(): ...@@ -125,7 +125,7 @@ def __bootstrap__():
os.environ['OMP_NUM_THREADS'] = str(num_threads) os.environ['OMP_NUM_THREADS'] = str(num_threads)
sysstr = platform.system() sysstr = platform.system()
read_env_flags = [ read_env_flags = [
'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph', 'check_nan_inf', 'benchmark', 'eager_delete_scope',
'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb', 'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb',
'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion', 'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion',
...@@ -143,6 +143,9 @@ def __bootstrap__(): ...@@ -143,6 +143,9 @@ def __bootstrap__():
if core.is_compiled_with_mkldnn(): if core.is_compiled_with_mkldnn():
read_env_flags.append('use_mkldnn') read_env_flags.append('use_mkldnn')
if core.is_compiled_with_ngraph():
read_env_flags.append('use_ngraph')
if core.is_compiled_with_dist(): if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline') read_env_flags.append('rpc_deadline')
read_env_flags.append('rpc_server_profile_path') read_env_flags.append('rpc_server_profile_path')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册