Commit e3c37bd5 authored by baojun, committed by tensor-tang

remove const_cast and refactor ngraph engine code (#15925)

* remove const_cast and refactor code test=develop

* reduce flag use test=develop
Parent 09799566
@@ -34,11 +34,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
#endif
DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DEFINE_bool(use_ngraph, false, "Use NGRAPH to run");
namespace paddle {
namespace framework {
@@ -194,9 +194,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool force_disable_gc) {
platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
#endif
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}
@@ -372,6 +369,12 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
}
#endif
return ctx;
}
......
@@ -29,7 +29,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
@@ -42,44 +41,75 @@ static ngraph::Shape Ddim2Shape(const framework::DDim& dims) {
for (int i = 0; i < dims.size(); ++i) {
int k = dims[i];
k = k == 0 ? 1 : k;
sp.push_back(k);
sp.emplace_back(k);
}
return sp;
}
static framework::DDim Shape2Ddim(const ngraph::Shape& shape) {
std::vector<int64_t> dims;
for (size_t i = 0; i < shape.size(); ++i) {
int64_t k = shape[i];
dims.emplace_back(k);
}
return framework::make_ddim(dims);
}
static std::map<framework::proto::VarType::Type, ngraph::element::Type>
pd2ng_type_map = {
{framework::proto::VarType::FP32, ngraph::element::f32},
{framework::proto::VarType::FP64, ngraph::element::f64},
{framework::proto::VarType::INT32, ngraph::element::i32},
{framework::proto::VarType::INT64, ngraph::element::i64},
{framework::proto::VarType::BOOL, ngraph::element::boolean},
};
std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
NgraphEngine::func_cache_ = {};
{framework::proto::VarType::BOOL, ngraph::element::boolean}};
static std::map<ngraph::element::Type, framework::proto::VarType::Type>
ng2pd_type_map = {
{ngraph::element::f32, framework::proto::VarType::FP32},
{ngraph::element::f64, framework::proto::VarType::FP64},
{ngraph::element::i32, framework::proto::VarType::INT32},
{ngraph::element::i64, framework::proto::VarType::INT64},
{ngraph::element::boolean, framework::proto::VarType::BOOL}};
std::vector<std::string> NgraphEngine::feed_vars = {};
std::vector<std::string> NgraphEngine::fetch_vars = {};
framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
std::unordered_map<std::string,
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
NgraphEngine::t_in_cache_ = {};
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
static std::vector<std::vector<int>> NgraphOpIntervals(
framework::BlockDesc* block) {
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::feed_vars.clear();
NgraphEngine::fetch_vars.clear();
std::vector<std::vector<int>> intervals;
auto ops = block->AllOps();
int size = ops.size();
int size = ops->size();
int left = 0;
while (left < size && ops.at(left)->Type() != framework::kFeedOpType) {
while (left < size && ops->at(left)->Type() != framework::kFeedOpType) {
++left;
}
if (left == size) {
return intervals;
}
while (left < size && ops.at(left)->Type() == framework::kFeedOpType) {
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
for (auto& var_name_item : ops->at(left)->Outputs()) {
for (auto& var_name : var_name_item.second) {
NgraphEngine::feed_vars.emplace_back(var_name);
}
}
++left;
}
int right = left;
while (right < size && ops.at(right)->Type() != framework::kFetchOpType) {
while (right < size && ops->at(right)->Type() != framework::kFetchOpType) {
++right;
}
if (right == size) {
@@ -87,85 +117,124 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
}
if (left >= right) return intervals;
int index = right;
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
for (auto& var_name_item : ops->at(index)->Inputs()) {
for (auto& var_name : var_name_item.second) {
NgraphEngine::fetch_vars.emplace_back(var_name);
}
}
++index;
}
// (left, right - 1) represents indices between feed and fetch
int pivot = left;
while (pivot < right) {
auto op_type = ops.at(pivot)->Type();
auto op_type = ops->at(pivot)->Type();
if (NgraphBridge::isRegister(op_type)) {
++pivot;
} else {
int start = pivot, end = start;
while (pivot < right &&
(!NgraphBridge::isRegister(ops.at(pivot)->Type()))) {
(!NgraphBridge::isRegister(ops->at(pivot)->Type()))) {
++pivot;
++end;
}
std::vector<int> interval = {start, end};
intervals.push_back(interval);
intervals.emplace_back(interval);
}
} // end while
return intervals;
}
static void SubstituteNgraphOp(framework::BlockDesc* block,
std::string block_str,
std::vector<int> interval) {
framework::ProgramDesc program;
block->RemoveOp(interval.at(0), interval.at(1));
auto* ng_op = block->InsertOp(interval.at(0));
ng_op->SetType("ngraph_engine");
ng_op->SetAttr("interval", interval);
ng_op->SetAttr("graph", block_str);
static void SubstituteNgraphOp(
std::vector<std::unique_ptr<framework::OperatorBase>>* ops,
std::string engine_key, std::string block_str, std::vector<int> interval) {
framework::OpDesc ng_op_desc(nullptr);
ng_op_desc.SetType("ngraph_engine");
ng_op_desc.SetAttr("interval", interval);
ng_op_desc.SetAttr("engine_key", engine_key);
ng_op_desc.SetAttr("graph", block_str);
ops->erase(ops->begin() + interval[0], ops->begin() + interval[1]);
ops->insert(ops->begin() + interval[0],
framework::OpRegistry::CreateOp(ng_op_desc));
}
// TODO(baojun-nervana): Move EnableNgraph to compile time per PR #15089
void NgraphEngine::EnableNgraph(const framework::ProgramDesc& program) {
#ifdef PADDLE_WITH_NGRAPH
VLOG(4) << "use_ngraph=True";
for (size_t bid = 0; bid < program.Size(); ++bid) {
// TODO(baojun-nervana): Remove the const_cast
auto* block =
const_cast<framework::ProgramDesc&>(program).MutableBlock(bid);
std::string block_str = block->Proto()->SerializeAsString();
auto intervals = NgraphOpIntervals(block);
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
SubstituteNgraphOp(block, block_str, *it);
}
std::string SerializedBlock(const std::vector<framework::OpDesc*>& op_descs) {
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
for (auto* op_desc : op_descs) {
auto* op = block_desc.AppendOp();
*op->Proto() = *op_desc->Proto();
}
return block_desc.Proto()->SerializeAsString();
}
std::string GenerateEngineKey(const framework::BlockDesc& bdesc) {
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
for (auto& op_desc : bdesc.AllOps()) {
auto* op = block_desc.AppendOp();
*op->Proto() = *op_desc->Proto();
}
auto engine_key = std::to_string(
std::hash<std::string>()(block_desc.Proto()->SerializeAsString()));
return engine_key;
}
std::string GenerateEngineKey(const std::vector<std::string>& engine_inputs,
const std::vector<std::string>& engine_outputs,
int size) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
engine_hash_key += name;
}
for (auto name : engine_outputs) {
engine_hash_key += name;
}
engine_hash_key += std::to_string(size);
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
return engine_key;
}
void NgraphEngine::FuseNgraphOps(
const framework::BlockDesc& block_desc,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::p_bdesc = &block_desc;
auto intervals = NgraphOpIntervals(ops);
std::string engine_key =
GenerateEngineKey(feed_vars, fetch_vars, ops->size());
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
SubstituteNgraphOp(ops, engine_key, "", *it);
}
#else
LOG(WARNING)
<< "'NGRAPH' is not supported, Please re-compile with WITH_NGRAPH option";
#endif
}
NgraphEngine::NgraphEngine(const framework::Scope& scope,
const platform::Place& place,
const std::string& serialized_graph,
const std::vector<int>& interval)
const framework::ExecutionContext& ctx)
: scope_(scope), place_(place) {
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string engine_key = ctx.Attr<std::string>("engine_key");
var_in_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
func_cache_key_ = std::to_string(interval[0]) + std::to_string(interval[1]) +
serialized_graph;
framework::proto::BlockDesc bdesc;
bdesc.ParseFromString(serialized_graph);
framework::BlockDesc block(nullptr, &bdesc);
Prepare(block, interval);
BuildNgIO();
GetNgFunction();
GetNgFunction(engine_key, interval);
}
void NgraphEngine::Prepare(const framework::BlockDesc& block,
const std::vector<int>& interval) {
for (auto& var : block.AllVars()) {
void NgraphEngine::Prepare(const std::vector<int>& interval) {
for (auto& var : p_bdesc->AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR ||
var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY)) {
@@ -192,108 +261,57 @@ void NgraphEngine::Prepare(const framework::BlockDesc& block,
}
}
auto ops_desc = block.AllOps();
int idx = interval[0];
while (idx < interval[1]) {
auto op_desc = ops_desc.at(idx);
auto op = framework::OpRegistry::CreateOp(*op_desc);
fused_ops_.push_back(std::move(op));
++idx;
}
while (ops_desc.at(idx)->Type() != framework::kFetchOpType) {
auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) {
post_op_inputs_.insert(var_name);
}
}
++idx;
}
while (idx < static_cast<int>(ops_desc.size()) &&
ops_desc.at(idx)->Type() == framework::kFetchOpType) {
std::string fetch_target_name = ops_desc.at(idx)->Input("X")[0];
fetches_.insert(fetch_target_name);
++idx;
}
if (ops_desc.at(interval.at(0) - 1)->Type() == framework::kFeedOpType &&
ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) {
ng_op_state_ = OpState::FULL;
std::vector<paddle::framework::OpDesc*> ops_desc;
for (auto op_desc : p_bdesc->AllOps()) {
ops_desc.emplace_back(op_desc);
}
for (auto* op_desc : ops_desc) {
for (auto op_desc : ops_desc) {
if (op_desc->Type().find("_grad") != std::string::npos) {
ng_op_state_ = ng_op_state_ == OpState::FULL ? OpState::FULL_TRAIN
: OpState::PARTIAL_TRAIN;
this->is_test_ = false;
break;
}
}
if (ng_op_state_ != OpState::FULL_TRAIN &&
ng_op_state_ != OpState::PARTIAL_TRAIN) {
ng_op_state_ = ng_op_state_ == OpState::FULL ? OpState::FULL_TEST
: OpState::PARTIAL_TEST;
if (interval[0] > 0 &&
ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
interval[1] < static_cast<int>(ops_desc.size()) &&
ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) {
this->op_state_ = OpState::FULL;
}
}
void NgraphEngine::GetNgInputShape(
std::shared_ptr<framework::OperatorBase> op) {
framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
op->RuntimeInferShape(scope_, place_, ctx);
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto sp = Ddim2Shape(tensor_pd->dims());
if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
var_in_.end()) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
// auto ng_type = pd2ng_type_map.at(GetDataTypeOfVar(var));
auto ng_type = var_type_map_.at(var_name);
auto prm =
std::make_shared<ngraph::op::Parameter>(ng_type, sp, true);
(*var_node_map_)[var_name] = prm;
(*var_in_node_map_)[var_name] = prm;
}
}
}
}
if (this->op_state_ == OpState::FULL) {
this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
} else {
this->op_state_ =
this->is_test_ ? OpState::PARTIAL_TEST : OpState::PARTIAL_TRAIN;
}
}
void NgraphEngine::BuildNgNodes() {
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Outputs()) {
int idx = interval[0];
while (idx < interval[1]) {
this->fused_ops_.emplace_back(
framework::OpRegistry::CreateOp(*(ops_desc[idx])));
++idx;
}
while (ops_desc.at(idx)->Type() != framework::kFetchOpType) {
auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim);
auto ng_type = var_type_map_.at(var_name);
auto prm = std::make_shared<ngraph::op::Parameter>(ng_type,
ng_shape, true);
(*var_node_map_)[var_name] = prm;
}
}
this->post_op_inputs_.insert(var_name);
}
}
++idx;
}
NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
}
BuildNgIO(ops_desc, interval);
}
void NgraphEngine::BuildNgIO() {
void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
const std::vector<int>& interval) {
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
for (auto& op : fused_ops_) {
for (int i = interval[0]; i < interval[1]; ++i) {
auto op = ops_desc[i];
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
inputs.insert(var_name);
@@ -302,15 +320,11 @@ void NgraphEngine::BuildNgIO() {
std::find(var_in_.begin(), var_in_.end(), var_name) ==
var_in_.end()) {
// fill var_in here to keep lhs and rhs order
var_in_.push_back(var_name);
this->var_in_.emplace_back(var_name);
}
}
}
if (op->Type() != "fill_constant") {
GetNgInputShape(op);
}
for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet",
@@ -322,172 +336,278 @@ void NgraphEngine::BuildNgIO() {
}
// var_out.clear();
for (auto& op : fused_ops_) {
for (int i = interval[0]; i < interval[1]; ++i) {
auto op = ops_desc[i];
for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet",
op->Type());
for (auto& var_name : var_name_item.second) {
switch (ng_op_state_) {
switch (this->op_state_) {
case OpState::PARTIAL_TEST:
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
fetches_.find(var_name) != fetches_.end()) {
var_out_.push_back(var_name);
find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end()) {
this->var_out_.emplace_back(var_name);
}
break;
case OpState::FULL_TEST:
if (fetches_.find(var_name) != fetches_.end()) {
var_out_.push_back(var_name);
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end()) {
this->var_out_.emplace_back(var_name);
}
break;
case OpState::PARTIAL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() ||
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name);
this->var_out_.emplace_back(var_name);
}
break;
case OpState::FULL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() ||
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name);
this->var_out_.emplace_back(var_name);
}
break;
default:
var_out_.push_back(var_name);
this->var_out_.emplace_back(var_name);
}
}
}
}
for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
var_in_updates_.emplace_back(i);
}
}
}
void NgraphEngine::BuildNgFunction() {
void NgraphEngine::GetNgInputShape() {
for (auto& var_name : var_in_) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto sp = Ddim2Shape(tensor_pd->dims());
auto ng_type = var_type_map_[var_name];
auto prm = std::make_shared<ngraph::op::Parameter>(ng_type, sp, true);
(*var_node_map_)[var_name] = prm;
(*var_in_node_map_)[var_name] = prm;
}
}
}
void NgraphEngine::BuildNgNodes() {
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Outputs()) {
for (auto& var_name : var_name_item.second) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim);
auto ng_type = var_type_map_[var_name];
auto prm = std::make_shared<ngraph::op::Parameter>(ng_type,
ng_shape, true);
(*var_node_map_)[var_name] = prm;
}
}
}
}
}
NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
}
}
void NgraphEngine::RunInferShape() {
for (auto& op : fused_ops_) {
framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
op->RuntimeInferShape(scope_, place_, ctx);
}
}
void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) {
Prepare(interval);
RunInferShape();
GetNgInputShape();
BuildNgNodes();
ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs;
ngraph::ParameterVector func_inputs;
for (auto& vo : var_out_) {
func_outputs.push_back(var_node_map_->at(vo));
func_outputs.emplace_back(var_node_map_->at(vo));
}
for (auto& vi : var_in_) {
std::shared_ptr<ngraph::op::Parameter> prm =
std::dynamic_pointer_cast<ngraph::op::Parameter>(
var_in_node_map_->at(vi));
func_inputs.push_back(prm);
func_inputs.emplace_back(prm);
}
ngraph_function_ =
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
}
void NgraphEngine::GetNgFunction() {
bool cache_on = true;
if (cache_on) {
std::string input_shape_str;
for (auto& var_name : var_in_) {
auto shape = var_node_map_->at(var_name)->get_shape();
for (size_t i = 0; i < shape.size(); ++i) {
input_shape_str += std::to_string(shape.at(i));
void NgraphEngine::GetNgFunction(std::string engine_key,
const std::vector<int>& interval) {
bool use_cache = true;
if (use_cache) {
this->func_cache_key_ = "";
for (int i = 0; i < std::min(static_cast<int>(feed_vars.size()), 10); ++i) {
auto* var = scope_.FindVar(feed_vars[i]);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto dims = tensor_pd->dims();
for (int j = 0; j < dims.size(); ++j) {
func_cache_key_ += std::to_string(dims[j]);
}
}
}
func_cache_key_ = input_shape_str + func_cache_key_;
if (func_cache_.find(func_cache_key_) != func_cache_.end()) {
ngraph_function_ = func_cache_.at(func_cache_key_);
} else {
BuildNgFunction();
func_cache_[func_cache_key_] = ngraph_function_;
func_cache_key_ += std::to_string(interval[0]) + "_" +
std::to_string(interval[1]) + engine_key;
func_cache_key_ = std::to_string(std::hash<std::string>()(func_cache_key_));
if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) {
engine_cache.clear();
t_in_cache_.clear();
} else {
auto var_name = engine_cache[func_cache_key_].persistables.begin();
framework::Variable* var = scope_.FindVar(*var_name);
if (var != pre_var_ptr) {
engine_cache.clear();
t_in_cache_.clear();
}
pre_var_ptr = var;
}
}
if (engine_cache.find(func_cache_key_) == engine_cache.end()) {
BuildNgFunction(interval);
engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_;
engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
engine_cache[func_cache_key_].var_in = this->var_in_;
engine_cache[func_cache_key_].var_out = this->var_out_;
engine_cache[func_cache_key_].is_test = this->is_test_;
}
} else {
BuildNgFunction();
BuildNgFunction(interval);
}
}
void NgraphEngine::Run(const framework::Scope& scope,
const platform::Place& place) const {
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out;
std::shared_ptr<ngraph::Function> ng_func;
const std::set<std::string>* p_persistables;
const std::vector<size_t>* p_var_in_updates;
const std::vector<std::string>* p_var_in;
const std::vector<std::string>* p_var_out;
bool is_test;
bool use_cache = true;
if (use_cache) {
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
"Cannot find cached data to run ngraph function");
ng_func = engine_cache[func_cache_key_].ngraph_function;
p_persistables = &(engine_cache[func_cache_key_].persistables);
p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates);
p_var_in = &(engine_cache[func_cache_key_].var_in);
p_var_out = &(engine_cache[func_cache_key_].var_out);
is_test = engine_cache[func_cache_key_].is_test;
} else {
ng_func = ngraph_function_;
p_persistables = &this->persistables_;
p_var_in_updates = &this->var_in_updates_;
p_var_in = &this->var_in_;
p_var_out = &this->var_out_;
is_test = this->is_test_;
}
for (size_t i = 0; i < var_in_.size(); ++i) {
auto vi = var_in_.at(i);
auto sp = var_node_map_->at(vi)->get_shape();
std::shared_ptr<ngraph::runtime::Tensor> ti;
auto* var = scope.FindVar(vi);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor");
auto ng_type = var_type_map_.at(vi);
if (ng_type == ngraph::element::f32) {
auto pd_arr = tensor_pd->mutable_data<float>(place);
ti = backend_->create_tensor(ngraph::element::f32, sp, pd_arr);
} else if (ng_type == ngraph::element::i32) {
const int* arr = tensor_pd->data<int>();
ti = backend_->create_tensor(ngraph::element::i32, sp,
const_cast<int*>(arr));
} else if (ng_type == ngraph::element::i64) {
auto pd_arr = tensor_pd->mutable_data<int64_t>(place);
ti = backend_->create_tensor(ngraph::element::i64, sp, pd_arr);
} else if (ng_type == ngraph::element::f64) {
auto pd_arr = tensor_pd->mutable_data<double>(place);
ti = backend_->create_tensor(ngraph::element::f64, sp, pd_arr);
} else if (ng_type == ngraph::element::boolean) {
auto pd_arr = tensor_pd->mutable_data<bool>(place);
ti = backend_->create_tensor(ngraph::element::boolean, sp, pd_arr);
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {};
auto m_parameters = ng_func->get_parameters();
auto m_results = ng_func->get_results();
if (is_test && use_cache &&
t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
p_t_in = &(t_in_cache_[func_cache_key_]);
for (size_t i = 0; i < p_var_in_updates->size(); ++i) {
int index = p_var_in_updates->at(i);
auto vi = p_var_in->at(index);
auto sp = m_parameters[index]->get_shape();
auto ng_type = m_parameters[index]->get_element_type();
std::shared_ptr<ngraph::runtime::Tensor> ti;
auto* var = scope.FindVar(vi);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
ti = backend_->create_tensor(ng_type, sp, pd_arr);
(*p_t_in)[index] = ti;
} else {
PADDLE_THROW("Data type not handling for var %s", vi);
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
}
} else {
if (is_test && use_cache) {
p_t_in = &(t_in_cache_[func_cache_key_]);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
p_t_in = &t_in;
}
bool is_test = (ng_op_state_ == OpState::PARTIAL_TEST ||
ng_op_state_ == OpState::FULL_TEST)
? true
: false;
bool is_persistable =
(persistables_.find(vi) != persistables_.end()) ? true : false;
if (is_test && is_persistable) {
ti->set_stale(false);
for (size_t i = 0; i < p_var_in->size(); ++i) {
auto vi = p_var_in->at(i);
auto sp = m_parameters[i]->get_shape();
auto ng_type = m_parameters[i]->get_element_type();
std::shared_ptr<ngraph::runtime::Tensor> ti;
auto* var = scope.FindVar(vi);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor");
ti = backend_->create_tensor(ng_type, sp, pd_arr);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
bool is_persistable =
(p_persistables->find(vi) != p_persistables->end()) ? true : false;
if (is_test && is_persistable) {
ti->set_stale(false);
}
(*p_t_in).emplace_back(ti);
}
t_in.push_back(ti);
}
for (size_t i = 0; i < var_out_.size(); ++i) {
auto vo = var_out_[i];
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out = {};
for (size_t i = 0; i < p_var_out->size(); ++i) {
auto vo = p_var_out->at(i);
auto* var = scope.FindVar(vo);
std::shared_ptr<ngraph::runtime::Tensor> to;
if (var && var->IsType<framework::LoDTensor>()) {
auto sp = m_results[i]->get_shape();
var->GetMutable<framework::LoDTensor>()->Resize(Shape2Ddim(sp));
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
auto dd = tensor_pd->dims();
ngraph::Shape sp = Ddim2Shape(dd);
auto ng_type = var_type_map_.at(vo);
if (ng_type == ngraph::element::f32) {
auto pd_arr = tensor_pd->mutable_data<float>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else if (ng_type == ngraph::element::i64) {
auto pd_arr = tensor_pd->mutable_data<int64_t>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else if (ng_type == ngraph::element::i32) {
auto pd_arr = tensor_pd->mutable_data<int>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else if (ng_type == ngraph::element::f64) {
auto pd_arr = tensor_pd->mutable_data<double>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else if (ng_type == ngraph::element::boolean) {
auto pd_arr = tensor_pd->mutable_data<bool>(place);
to = backend_->create_tensor(ng_type, sp, pd_arr);
} else {
PADDLE_THROW("Data type not handled in for var %s", vo);
}
t_out.push_back(to);
auto ng_type = m_results[i]->get_element_type();
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
std::shared_ptr<ngraph::runtime::Tensor> to =
backend_->create_tensor(ng_type, sp, pd_arr);
t_out.emplace_back(to);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vo);
}
}
auto handle = backend_->compile(ngraph_function_);
handle->call_with_validate(t_out, t_in);
auto handle = backend_->compile(ng_func);
handle->call_with_validate(t_out, *p_t_in);
} // NgraphEngine::Run
} // namespace operators
} // namespace paddle
@@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "ngraph/ngraph.hpp"
@@ -33,29 +39,47 @@ enum class OpState { /* nGraph support state on ops */
UNKNOWN /* Output all for debug purpose */
};
// cache engine data reused across runs
struct EngineCache {
std::shared_ptr<ngraph::Function> ngraph_function;
std::set<std::string> persistables;
std::vector<std::string> var_in;
std::vector<std::string> var_out;
std::vector<size_t> var_in_updates;
bool is_test = true;
};
// perform graph build through bridge and execute computation
class NgraphEngine {
public:
explicit NgraphEngine(const framework::Scope& scope,
const platform::Place& place,
const std::string& serialized_graph,
const std::vector<int>& interval);
const framework::ExecutionContext& ctx);
void Run(const framework::Scope& scope, const platform::Place& place) const;
static void EnableNgraph(const framework::ProgramDesc& program);
static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars;
static void FuseNgraphOps(
const framework::BlockDesc& prog,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
private:
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
func_cache_;
static std::unordered_map<std::string, EngineCache> engine_cache;
static std::unordered_map<
std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
t_in_cache_;
static framework::Variable* pre_var_ptr;
const framework::Scope& scope_;
const platform::Place& place_;
std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::unordered_set<std::string> persistables_;
std::unordered_set<std::string> fetches_;
std::set<std::string> persistables_;
std::unordered_set<std::string> post_op_inputs_;
OpState ng_op_state_ = OpState::UNKNOWN;
OpState op_state_ = OpState::UNKNOWN;
bool is_test_{true};
std::string func_cache_key_;
// ngraph backend eg. CPU
@@ -66,6 +90,8 @@ class NgraphEngine {
std::vector<std::string> var_in_;
// var_name of outputs from fetch in order
std::vector<std::string> var_out_;
// non-persistable var_in
std::vector<size_t> var_in_updates_;
// map input vars to nodes
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
@@ -74,20 +100,23 @@ class NgraphEngine {
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_;
// prepare info for ngraph engine
void Prepare(const framework::BlockDesc& block,
const std::vector<int>& interval);
// prepare info needed by the ngraph engine
void Prepare(const std::vector<int>& interval);
// get ngraph engine input and output list
void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
const std::vector<int>& interval);
// get ngraph input and define ngraph input parameters
void GetNgInputShape(std::shared_ptr<framework::OperatorBase> op);
void GetNgInputShape();
// Call ngraph bridge to map ops
void BuildNgNodes();
// get the ngraph input and output var list
void BuildNgIO();
// run paddle RuntimeInferShape to get the tensor shape
void RunInferShape();
// build ngraph function call
void BuildNgFunction();
void BuildNgFunction(const std::vector<int>& interval);
// Check cache for ngraph function or otherwise build the function
void GetNgFunction();
void GetNgFunction(std::string engine_key, const std::vector<int>& interval);
};
} // namespace operators
} // namespace paddle
#endif // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
@@ -29,6 +29,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Xs", "A list of inputs.").AsDispensable();
AddOutput("Ys", "A list of outputs").AsDispensable();
AddAttr<std::string>("graph", "the graph.");
AddAttr<std::string>("engine_key", "the engine hash key.");
AddAttr<std::vector<int>>("interval", "op interval supported by ngraph");
AddComment("ngraph engine operator.");
}
......
@@ -46,10 +46,8 @@ class NgraphEngineKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& scope = ctx.scope();
auto place = ctx.GetPlace();
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto interval = ctx.Attr<std::vector<int>>("interval");
NgraphEngine ngraph_engine(scope, place, serialized_graph, interval);
NgraphEngine ngraph_engine(scope, place, ctx);
ngraph_engine.Run(scope, place);
}
};
......
@@ -94,6 +94,14 @@ bool IsCompiledWithMKLDNN() {
#endif
}
bool IsCompiledWithNGRAPH() {
#ifndef PADDLE_WITH_NGRAPH
return false;
#else
return true;
#endif
}
bool IsCompiledWithBrpc() {
#ifndef PADDLE_WITH_DISTRIBUTE
return false;
@@ -874,6 +882,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_ngraph", IsCompiledWithNGRAPH);
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
......
@@ -125,7 +125,7 @@ def __bootstrap__():
os.environ['OMP_NUM_THREADS'] = str(num_threads)
sysstr = platform.system()
read_env_flags = [
'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph',
'check_nan_inf', 'benchmark', 'eager_delete_scope',
'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb',
'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion',
@@ -143,6 +143,9 @@ def __bootstrap__():
if core.is_compiled_with_mkldnn():
read_env_flags.append('use_mkldnn')
if core.is_compiled_with_ngraph():
read_env_flags.append('use_ngraph')
if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline')
read_env_flags.append('rpc_server_profile_path')
......