提交 5eb81fe5 编写于 作者: M mozga-intel 提交者: Tao Luo

Capi for a ngraph engine (#17037)

上级 5782ddda
......@@ -95,6 +95,14 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_squash_pass inference mkldnn)
endif()
if(WITH_NGRAPH)
cc_library(ngraph_subgraph_pass SRCS ngraph_subgraph_pass.cc DEPS ngraph_bridge
analysis_helper subgraph_detector graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
file(APPEND ${pass_file} "USE_PASS(ngraph_subgraph_pass);\n")
set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "")
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <set>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/ngraph_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
namespace ANAT = paddle::inference::analysis;
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::set<std::string> &engine_outputs,
const std::string &size) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
engine_hash_key += name;
}
for (auto name : engine_outputs) {
engine_hash_key += name;
}
engine_hash_key += size;
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
return engine_key;
}
void NgraphSubgraphPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("ngraph_subgraph_pass", graph);
std::unordered_set<Node *> nodes2delete;
auto teller = [](const Node *node) {
if (!node->IsOp() || !node->Op()) return false;
auto op_type = node->Op()->Type();
return !paddle::operators::NgraphBridge::isRegister(op_type);
};
ANAT::SubGraphFuser fuser(graph, teller, 0, "ngraph_engine");
fuser();
for (auto *node : graph->Nodes()) {
if (node->IsOp() && !ANAT::Agent(node).subgraph()->empty()) {
OpDesc *op_desc = node->Op();
op_desc->SetType("ngraph_engine");
for (auto it = ANAT::Agent(node).subgraph()->begin();
it != ANAT::Agent(node).subgraph()->end(); ++it) {
}
CreateNgraphEngineOp(node, graph);
std::unordered_set<const Node *> nodes2remove(
ANAT::Agent(node).subgraph()->begin(),
ANAT::Agent(node).subgraph()->end());
GraphSafeRemoveNodes(graph, nodes2remove);
}
}
std::unordered_set<const Node *> nodes2remove;
for (auto *node : graph->Nodes()) {
if (node->IsOp() && ANAT::Agent(node).deleted()) {
nodes2remove.insert(node);
}
}
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
std::vector<ir::Node *> nodes = ir::TopologySortOperations(*graph);
}
void NgraphSubgraphPass::CreateNgraphEngineOp(framework::ir::Node *node,
Graph *graph) const {
auto *op_desc = node->Op();
auto &subgraph = *ANAT::Agent(node).subgraph();
PADDLE_ENFORCE(!subgraph.empty());
framework::ProgramDesc *program_desc =
Get<framework::ProgramDesc *>("program");
const framework::BlockDesc &main_block =
program_desc->Block(framework::kRootBlockIndex);
framework::BlockDesc *new_block = program_desc->AppendBlock(main_block);
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
for (auto *node : subgraph) {
auto *new_block_op = new_block->AppendOp();
auto *op = block_desc.AppendOp();
*new_block_op->Proto() = *node->Op()->Proto();
*op->Proto() = *node->Op()->Proto();
}
std::set<std::string> input_names;
std::set<std::string> input_names_with_id;
for (auto *x : node->inputs) {
input_names.insert(x->Name());
input_names_with_id.insert(x->Name() + std::to_string(x->id()));
}
op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
std::set<std::string> output_names;
std::set<std::string> output_names_with_id;
for (auto *x : node->outputs) {
output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id()));
}
op_desc->SetOutput(
"Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
auto *vars = block_desc.Proto()->mutable_vars();
for (framework::ir::Node *node : graph->Nodes()) {
if (node->IsVar() && node->Var()) {
*vars->Add() = *node->Var()->Proto();
}
}
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
"the block has no var-desc");
op_desc->SetType("ngraph_engine");
int sgs = subgraph.size();
std::string engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(sgs));
std::vector<int> interval{0, sgs};
op_desc->SetAttr("interval", interval);
op_desc->SetAttr("graph", block_desc.Proto()->SerializeAsString());
op_desc->SetAttr("engine_key", engine_key);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(ngraph_subgraph_pass, paddle::framework::ir::NgraphSubgraphPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse supported ops to a NgraphEngineOp.
*/
class NgraphSubgraphPass : public FusePassBase {
public:
void ApplyImpl(ir::Graph *graph) const override;
virtual ~NgraphSubgraphPass() {}
private:
void CreateNgraphEngineOp(framework::ir::Node *x,
framework::ir::Graph *graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -63,6 +63,18 @@ void SetAttr<std::vector<std::string>>(framework::proto::OpDesc *op,
}
}
template <>
void SetAttr<std::vector<int>>(framework::proto::OpDesc *op,
const std::string &name,
const std::vector<int> &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::INTS);
for (const auto i : data) {
attr->add_ints(i);
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -112,7 +112,10 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("engine_opt_info", new std::map<std::string, std::string>(
argument->engine_opt_info()));
}
if (pass_name == "ngraph_subgraph_pass") {
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
}
if (pass_name == "anakin_subgraph_pass") {
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
......
......@@ -420,7 +420,7 @@ void SubGraphFuser::ReplaceNodesWithSubGraphs() {
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph
// as deleted. 3. Replace the deleted node with the new Block Node.
framework::OpDesc empty_desc;
empty_desc.SetType("anakin_engine");
empty_desc.SetType(name_);
auto *block_node = graph_->CreateOpNode(&empty_desc);
Agent(block_node).set_subgraph({});
auto io = ExtractInputAndOutputOfSubGraph(subgraph);
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
......@@ -74,10 +75,11 @@ class SubGraphFuser {
using NodeInsideSubgraphTeller = SubgraphDetector::NodeInsideSubgraphTeller;
SubGraphFuser(Graph *graph, const NodeInsideSubgraphTeller &teller,
int min_subgraph_size)
int min_subgraph_size, std::string name = "anakin_engine")
: graph_(graph),
node_inside_subgraph_teller_(teller),
min_subgraph_size_{min_subgraph_size} {}
min_subgraph_size_{min_subgraph_size},
name_{name} {}
// The main method which run all the logic.
void operator()();
......@@ -90,6 +92,7 @@ class SubGraphFuser {
Graph *graph_;
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
int min_subgraph_size_;
const std::string name_;
};
struct NodeWrapper {
......
......@@ -31,6 +31,10 @@ if (ANAKIN_FOUND)
set(inference_deps ${inference_deps} anakin_op_converter anakin_engine)
endif()
if(WITH_NGRAPH)
set(inference_deps ${inference_deps} ngraph)
endif()
add_subdirectory(details)
if(WITH_MKLDNN)
......@@ -40,7 +44,11 @@ if(WITH_MKLDNN)
endif()
cc_library(analysis_config SRCS analysis_config.cc DEPS ${mkldnn_quantizer_cfg} lod_tensor paddle_pass_builder)
cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
if(WITH_NGRAPH)
cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc DEPS ngraph)
else(WITH_NGRAPH)
cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
endif(WITH_NGRAPH)
cc_library(analysis_predictor SRCS analysis_predictor.cc ${mkldnn_quantizer_src} DEPS paddle_inference_api zero_copy_tensor
reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager ${inference_deps})
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS
......
......@@ -107,6 +107,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(tensorrt_precision_mode_);
CP_MEMBER(trt_use_static_engine_);
CP_MEMBER(trt_use_calib_mode_);
// NGRAPH related.
CP_MEMBER(use_ngraph_);
// MKLDNN related.
CP_MEMBER(use_mkldnn_);
CP_MEMBER(mkldnn_enabled_op_types_);
......@@ -170,6 +172,16 @@ void AnalysisConfig::EnableMkldnnQuantizer() {
Update();
}
void AnalysisConfig::EnableNgraph() {
#ifdef PADDLE_WITH_NGRAPH
pass_builder()->EnableNgraph();
use_ngraph_ = true;
#else
LOG(ERROR) << "Please compile with NGRAPH first to use NGRAPH";
use_ngraph_ = false;
#endif
}
std::shared_ptr<MkldnnQuantizerConfig> AnalysisConfig::mkldnn_quantizer_config()
const {
PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
......@@ -238,6 +250,20 @@ void AnalysisConfig::Update() {
}
}
if (use_ngraph_) {
if (!enable_ir_optim_) {
LOG(ERROR)
<< "EnableNgraph() only works when IR optimization is enabled.";
}
#ifdef PADDLE_WITH_NGRAPH
pass_builder()->EnableNgraph();
use_ngraph_ = true;
#else
LOG(ERROR) << "Please compile with NGRAPH first to use NGRAPH";
use_ngraph_ = false;
#endif
}
if (use_mkldnn_) {
#ifdef PADDLE_WITH_MKLDNN
if (!enable_ir_optim_) {
......@@ -312,6 +338,8 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << static_memory_optim_;
ss << static_memory_optim_force_update_;
ss << use_ngraph_;
ss << use_mkldnn_;
for (auto &item : mkldnn_enabled_op_types_) ss << item;
ss << ";";
......
......@@ -169,6 +169,13 @@ struct AnalysisConfig {
*/
void SwitchIrDebug(int x = true);
/** Turn on NGRAPH.
*/
void EnableNgraph();
/** A boolean state telling whether to use the NGRAPH.
*/
bool ngraph_enabled() const { return use_ngraph_; }
/** Turn on MKLDNN.
*/
void EnableMKLDNN();
......@@ -274,6 +281,7 @@ struct AnalysisConfig {
bool static_memory_optim_{false};
bool static_memory_optim_force_update_{false};
bool use_ngraph_{false};
bool use_mkldnn_{false};
std::unordered_set<std::string> mkldnn_enabled_op_types_;
......
......@@ -132,6 +132,10 @@ void GpuPassStrategy::EnableMkldnnQuantizer() {
LOG(ERROR) << "GPU not support MKL-DNN quantization";
}
void GpuPassStrategy::EnableNgraph() {
LOG(ERROR) << "GPU not support Ngraph yet";
}
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// NOTE the large fusions should be located in the front, so that they will
// not be damaged by smaller ones.
......@@ -198,4 +202,14 @@ void CpuPassStrategy::EnableMkldnnQuantizer() {
#endif
}
void CpuPassStrategy::EnableNgraph() {
#ifdef PADDLE_WITH_NGRAPH
if (!use_ngraph_) {
passes_.insert(passes_.begin(), "ngraph_subgraph_pass");
}
use_ngraph_ = true;
#else
use_ngraph_ = false;
#endif
}
} // namespace paddle
......@@ -90,6 +90,10 @@ class PassStrategy : public PaddlePassBuilder {
*/
virtual void EnableMKLDNN() {}
/** Enable NGRAPH optimization
*/
virtual void EnableNgraph() {}
/** Enable MKLDNN quantize optimization
*/
virtual void EnableMkldnnQuantizer() {}
......@@ -99,6 +103,7 @@ class PassStrategy : public PaddlePassBuilder {
virtual ~PassStrategy() = default;
protected:
bool use_ngraph_{false};
bool use_gpu_{false};
bool use_mkldnn_{false};
};
......@@ -112,16 +117,19 @@ class CpuPassStrategy : public PassStrategy {
explicit CpuPassStrategy(const CpuPassStrategy &other)
: PassStrategy(other.AllPasses()) {
use_gpu_ = other.use_gpu_;
use_ngraph_ = other.use_ngraph_;
use_mkldnn_ = other.use_mkldnn_;
use_mkldnn_quantizer_ = other.use_mkldnn_quantizer_;
}
virtual ~CpuPassStrategy() = default;
void EnableNgraph() override;
void EnableMKLDNN() override;
void EnableMkldnnQuantizer() override;
protected:
bool use_ngraph_{false};
bool use_mkldnn_quantizer_{false};
};
......@@ -136,6 +144,7 @@ class GpuPassStrategy : public PassStrategy {
use_gpu_ = true;
}
void EnableNgraph() override;
void EnableMKLDNN() override;
void EnableMkldnnQuantizer() override;
......
......@@ -78,6 +78,8 @@ std::ostream &operator<<(std::ostream &os, const AnalysisConfig &config) {
<< "use_tensorrt: " << config.tensorrt_engine_enabled() << "\n";
os << GenSpaces(num_spaces) << "use_mkldnn: " << config.mkldnn_enabled()
<< "\n";
os << GenSpaces(num_spaces) << "use_ngraph: " << config.ngraph_enabled()
<< "\n";
num_spaces--;
os << GenSpaces(num_spaces) << "}\n";
return os;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册