提交 2e309b11 编写于 作者: J JiabinYang

test=develop, merge develop

......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include <set>
#include <vector>
namespace paddle {
......@@ -79,7 +80,7 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
}
std::unordered_set<Node *> visited;
std::unordered_set<Node *> to_visit{source.begin(), source.end()};
std::set<Node *> to_visit{source.begin(), source.end()};
std::vector<Node *> inlink_visited;
while (!to_visit.empty()) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <map>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/ngraph_operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
#include "ngraph/ngraph.hpp"
namespace paddle {
namespace framework {
static ngraph::Shape Ddim2Shape(const DDim& dims) {
ngraph::Shape sp;
for (int i = 0; i < dims.size(); ++i) {
int k = dims[i];
k = k == 0 ? 1 : k;
sp.push_back(k);
}
return sp;
}
static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
{proto::VarType::FP32, ngraph::element::f32},
{proto::VarType::FP64, ngraph::element::f64},
{proto::VarType::INT32, ngraph::element::i32},
{proto::VarType::INT64, ngraph::element::i64},
{proto::VarType::BOOL, ngraph::element::boolean},
};
typedef enum { /* nGraph support state on ops */
FULL_TRAIN, /* Support full ops for train */
PARTIAL_TRAIN, /* Support partial ops for train */
FULL_TEST, /* Support full list of ops for test */
PARTIAL_TEST /* Support partial list of ops for test */
} op_state;
// perform graph build through bridge and execute computation
class NgraphEngine {
public:
explicit NgraphEngine(const Scope& scope, const platform::Place& place,
const std::vector<std::shared_ptr<OperatorBase>>& ops,
const std::unordered_map<
std::string, ngraph::element::Type>& var_type_map,
const std::unordered_set<std::string>& persist,
const std::unordered_set<std::string>& fetches,
const std::unordered_set<std::string>& post_op_inputs,
op_state ng_op_state)
: scope_(scope),
place_(place),
fused_ops_(ops),
var_type_map_(var_type_map),
persistables_(persist),
fetches_(fetches),
post_op_inputs_(post_op_inputs),
ng_op_state_(ng_op_state) {
var_in_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
BuildNgIO();
GetNgFunction();
}
void Run(const Scope& scope, const platform::Place& place) const;
private:
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
func_cache_;
const Scope& scope_;
const platform::Place& place_;
std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::unordered_set<std::string> persistables_;
std::unordered_set<std::string> fetches_;
std::unordered_set<std::string> post_op_inputs_;
op_state ng_op_state_;
// ngraph backend eg. CPU
static std::shared_ptr<ngraph::runtime::Backend> backend_;
// ngraph function to call and execute
std::shared_ptr<ngraph::Function> ngraph_function_;
// var_name of inputs
std::vector<std::string> var_in_;
// var_name of outputs from fetch in order
std::vector<std::string> var_out_;
// map input vars to nodes
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_in_node_map_;
// map each var name with a ngraph node
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_;
// cache key to check if function is cached
std::shared_ptr<std::string> GetCacheKey();
// get ngraph input and define ngraph input parameters
void GetNgInputShape(std::shared_ptr<OperatorBase> op);
// Call ngraph bridge to map ops
void BuildNgNodes();
// get the ngraph input and output var list
void BuildNgIO();
// build ngraph function call
void BuildNgFunction();
// Check cache for ngraph function or otherwise build the function
void GetNgFunction();
};
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
NgraphOperator::NgraphOpIntervals(
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops) {
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
intervals;
if (ops->empty()) {
return intervals;
}
size_t size = ops->size();
size_t left = 0;
while (left < size && ops->at(left)->Type() != kFeedOpType) {
++left;
}
if (left == size) {
return intervals;
}
while (left < size && ops->at(left)->Type() == kFeedOpType) {
++left;
}
size_t right = left;
while (right < size && ops->at(right)->Type() != kFetchOpType) {
++right;
}
if (right == size) {
return intervals;
}
if (left >= right) return intervals;
// (left, right - 1) represents indices between feed and fetch
size_t pivot = left;
while (pivot < right) {
auto op_type = ops->at(pivot)->Type();
if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) ==
paddle::framework::NgraphBridge::NG_NODE_MAP.end()) {
++pivot;
} else {
size_t start = pivot, end = start;
while (pivot < right &&
(paddle::framework::NgraphBridge::NG_NODE_MAP.find(
ops->at(pivot)->Type()) !=
paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
++pivot;
++end;
}
std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>
interval = {ops->begin() + start, ops->begin() + end};
intervals.push_back(interval);
}
} // end while
return intervals;
}
NgraphOperator::NgraphOperator(
const ProgramDesc& prog, size_t block_id,
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs),
pdesc_(prog),
block_(block_id) {
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
it != end; ++it) {
fused_ops_.push_back(std::move(*it));
}
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
(*it)->Type() != kFetchOpType; ++it) {
for (auto& var_name_item : (*it)->Inputs()) {
for (auto& var_name : var_name_item.second) {
post_op_inputs_.insert(var_name);
}
}
}
if ((*(start - 1))->Type() == kFeedOpType && (*end)->Type() == kFetchOpType) {
is_full_ = true;
}
Process();
}
void NgraphOperator::Process() {
auto& bdesc = pdesc_.Block(block_);
for (auto& var : bdesc.AllVars()) {
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
var->GetType() == proto::VarType::LOD_TENSOR ||
var->GetType() == proto::VarType::LOD_TENSOR_ARRAY)) {
continue;
}
auto var_name = var->Name();
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var_name != "fetch" && var_name != "feed") {
auto pd_type = var->GetDataType();
if (pd2ng_type_map.find(pd_type) == pd2ng_type_map.end()) {
PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
var_name);
}
var_type_map_[var_name] = pd2ng_type_map[pd_type];
}
if (var->Persistable()) {
persistables_.insert(var->Name());
}
}
for (auto* op : bdesc.AllOps()) {
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
fetches_.insert(fetch_target_name);
}
}
}
void NgraphOperator::RunImpl(const Scope& scope,
const platform::Place& place) const {
op_state ng_op_state = PARTIAL_TEST;
auto& bdesc = pdesc_.Block(block_);
for (auto* op : bdesc.AllOps()) {
if (op->Type().find("_grad") != std::string::npos) {
ng_op_state = PARTIAL_TRAIN;
break;
}
}
if (is_full_) {
ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
}
NgraphEngine ngraph_engine(scope, place, fused_ops_, var_type_map_,
persistables_, fetches_, post_op_inputs_,
ng_op_state);
ngraph_engine.Run(scope, place);
}
std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
NgraphEngine::func_cache_ = {};
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
op->RuntimeInferShape(scope_, place_, ctx);
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto sp = Ddim2Shape(tensor_pd->dims());
if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
var_in_.end()) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto ng_type = var_type_map_.at(var_name);
auto prm =
std::make_shared<ngraph::op::Parameter>(ng_type, sp, true);
(*var_node_map_)[var_name] = prm;
(*var_in_node_map_)[var_name] = prm;
}
}
}
}
}
}
void NgraphEngine::BuildNgNodes() {
for (auto& var_name : var_out_) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim);
auto ng_type = var_type_map_.at(var_name);
auto prm =
std::make_shared<ngraph::op::Parameter>(ng_type, ng_shape, true);
(*var_node_map_)[var_name] = prm;
}
}
}
paddle::framework::NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
}
}
void NgraphEngine::BuildNgIO() {
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
inputs.insert(var_name);
const bool is_output = outputs.find(var_name) != outputs.end();
if (!is_output &&
std::find(var_in_.begin(), var_in_.end(), var_name) ==
var_in_.end()) {
// fill var_in here to keep lhs and rhs order
var_in_.push_back(var_name);
}
}
}
if (op->Type() != "fill_constant") {
GetNgInputShape(op);
}
for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet",
op->Type());
for (auto& var_name : var_name_item.second) {
outputs.insert(var_name);
}
}
}
// var_out.clear();
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet",
op->Type());
for (auto& var_name : var_name_item.second) {
switch (ng_op_state_) {
case PARTIAL_TEST:
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
fetches_.find(var_name) != fetches_.end()) {
var_out_.push_back(var_name);
}
break;
case FULL_TEST:
if (fetches_.find(var_name) != fetches_.end()) {
var_out_.push_back(var_name);
}
break;
case PARTIAL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name);
}
break;
case FULL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() ||
persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name);
}
break;
default:
var_out_.push_back(var_name);
}
}
}
}
}
void NgraphEngine::BuildNgFunction() {
BuildNgNodes();
ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs;
ngraph::ParameterVector func_inputs;
for (auto& vo : var_out_) {
func_outputs.push_back(var_node_map_->at(vo));
}
for (auto& vi : var_in_) {
std::shared_ptr<ngraph::op::Parameter> prm =
std::dynamic_pointer_cast<ngraph::op::Parameter>(
var_in_node_map_->at(vi));
func_inputs.push_back(prm);
}
ngraph_function_ =
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
}
std::shared_ptr<std::string> NgraphEngine::GetCacheKey() {
auto cache_key = std::make_shared<std::string>("");
*cache_key += std::to_string(fused_ops_.size());
for (auto& op : fused_ops_) {
*cache_key += op->Type();
}
for (auto& var_name : var_in_) {
auto shape = var_node_map_->at(var_name)->get_shape();
*cache_key += var_name;
*cache_key += var_type_map_.at(var_name).c_type_string();
for (size_t i = 0; i < shape.size(); ++i) {
*cache_key += std::to_string(shape.at(i));
}
}
for (auto& var_name : var_out_) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
for (int i = 0; i < ddim.size(); ++i) {
*cache_key += std::to_string(ddim[i]);
}
}
}
return cache_key;
}
void NgraphEngine::GetNgFunction() {
bool cache_on = true;
if (cache_on) {
std::string cache_key_val = *GetCacheKey();
if (func_cache_.find(cache_key_val) != func_cache_.end()) {
ngraph_function_ = func_cache_.at(cache_key_val);
} else {
BuildNgFunction();
func_cache_[cache_key_val] = ngraph_function_;
}
} else {
BuildNgFunction();
}
}
void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out;
for (size_t i = 0; i < var_in_.size(); ++i) {
auto vi = var_in_.at(i);
auto sp = var_node_map_->at(vi)->get_shape();
std::shared_ptr<ngraph::runtime::Tensor> ti;
auto* var = scope.FindVar(vi);
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor");
if (tensor_pd->type() == proto::VarType::FP32) {
const float* arr = tensor_pd->data<float>();
ti = backend_->create_tensor(ngraph::element::f32, sp,
const_cast<float*>(arr));
} else if (tensor_pd->type() == proto::VarType::INT32) {
const int* arr = tensor_pd->data<int>();
ti = backend_->create_tensor(ngraph::element::i32, sp,
const_cast<int*>(arr));
} else if (tensor_pd->type() == proto::VarType::INT64) {
const int64_t* arr = tensor_pd->data<int64_t>();
ti = backend_->create_tensor(ngraph::element::i64, sp,
const_cast<int64_t*>(arr));
} else if (tensor_pd->type() == proto::VarType::FP64) {
const double* arr = tensor_pd->data<double>();
ti = backend_->create_tensor(ngraph::element::f64, sp,
const_cast<double*>(arr));
} else if (tensor_pd->type() == proto::VarType::BOOL) {
const bool* arr = tensor_pd->data<bool>();
ti = backend_->create_tensor(ngraph::element::boolean, sp,
const_cast<bool*>(arr));
} else {
PADDLE_THROW("Data type not handling for var %s", vi);
}
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
bool is_test = (ng_op_state_ == PARTIAL_TEST || ng_op_state_ == FULL_TEST)
? true
: false;
bool is_persistable =
(persistables_.find(vi) != persistables_.end()) ? true : false;
if (is_test && is_persistable) {
ti->set_stale(false);
}
t_in.push_back(ti);
}
for (size_t i = 0; i < var_out_.size(); ++i) {
auto var_name = var_out_[i];
auto* var = scope.FindVar(var_name);
std::shared_ptr<ngraph::runtime::Tensor> to;
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
auto dd = tensor_pd->dims();
ngraph::Shape sp = Ddim2Shape(dd);
auto ng_type = var_type_map_.at(var_name);
if (ng_type == ngraph::element::f32) {
auto pd_arr = tensor_pd->mutable_data<float>(place);
to = backend_->create_tensor(ngraph::element::f32, sp, pd_arr);
} else if (ng_type == ngraph::element::i64) {
auto pd_arr = tensor_pd->mutable_data<int64_t>(place);
to = backend_->create_tensor(ngraph::element::i64, sp, pd_arr);
} else if (ng_type == ngraph::element::f64) {
auto pd_arr = tensor_pd->mutable_data<double>(place);
to = backend_->create_tensor(ngraph::element::f64, sp, pd_arr);
} else if (ng_type == ngraph::element::boolean) {
auto pd_arr = tensor_pd->mutable_data<bool>(place);
to = backend_->create_tensor(ngraph::element::boolean, sp, pd_arr);
} else {
PADDLE_THROW("Data type not handled in for var %s", var_name);
}
t_out.push_back(to);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", var_name);
}
}
backend_->call(backend_->compile(ngraph_function_), t_out, t_in);
} // NgraphEngine::RunImpl
} // namespace framework
} // namespace paddle
......@@ -204,59 +204,68 @@ framework::LoDTensor& VarBase::GradValue() {
}
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (!grad_op_desc_ && backward_id_ <= 0) {
if (grad_op_descs_.empty() && backward_id_ <= 0) {
LOG(WARNING) << "op with no grad: " << op_desc_->Type();
return {};
}
std::map<std::string, std::vector<framework::Variable*>> grad_outputs;
std::vector<framework::VariableValueMap> grad_outputs;
if (backward_id_ > 0) {
VLOG(3) << "py_layer_grad";
grad_outputs[framework::GradVarName(PyLayer::kFwdOut)] = PyLayer::ApplyGrad(
backward_id_,
grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)]);
grad_outputs.resize(1);
grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
PyLayer::ApplyGrad(
backward_id_,
grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
} else {
VLOG(3) << "op grad " << grad_op_desc_->Type();
for (auto it : grad_output_vars_) {
auto& outputs = grad_outputs[it.first];
for (size_t i = 0; i < it.second.size(); ++i) {
// Allocate a new variable
Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>();
outputs.push_back(tmp_var);
grad_outputs.resize(grad_op_descs_.size());
for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
framework::OpDesc* grad_op_desc = grad_op_descs_[k];
VLOG(3) << "op grad " << grad_op_desc->Type();
for (auto it : grad_output_vars_[k]) {
auto& outputs = grad_outputs[k][it.first];
for (size_t i = 0; i < it.second.size(); ++i) {
// Allocate a new variable
Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>();
outputs.push_back(tmp_var);
}
}
}
framework::RuntimeContext ctx(grad_input_vars_, grad_outputs);
framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]);
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
grad_op_desc_->InferVarType(block_);
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
grad_op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc_);
framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc);
framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
}
}
for (auto it : grad_output_vars_) {
auto& outputs = grad_outputs[it.first];
auto& origin_outputs = it.second;
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* grad = outputs[i];
framework::Variable* orig_grad = origin_outputs[i];
AddTo(grad, orig_grad, place_);
delete grad;
for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
for (auto it : grad_output_vars_[k]) {
auto& outputs = grad_outputs[k][it.first];
auto& origin_outputs = it.second;
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* grad = outputs[i];
framework::Variable* orig_grad = origin_outputs[i];
AddTo(grad, orig_grad, place_);
delete grad;
}
}
}
return input_vars_;
}
......
......@@ -184,12 +184,13 @@ class OpBase {
OpBase()
: op_desc_(nullptr),
forward_id_(-1),
grad_op_desc_(nullptr),
backward_id_(-1),
place_(platform::CPUPlace()) {}
virtual ~OpBase() {
if (grad_op_desc_) delete grad_op_desc_;
for (framework::OpDesc* desc : grad_op_descs_) {
delete desc;
}
}
std::map<std::string, std::vector<VarBase*>> ApplyGrad();
......@@ -198,9 +199,11 @@ class OpBase {
// For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
framework::OpDesc* op_desc_;
int forward_id_;
// When has backward, one of `grad_op_desc_` or `backward_id_` is set,
// When has backward, one of `grad_op_descs_` or `backward_id_` is set,
// not both.
framework::OpDesc* grad_op_desc_;
// Note: each fwd op corresponds to a vector of bwd ops.
std::vector<framework::OpDesc*> grad_op_descs_;
int backward_id_;
platform::Place place_;
......@@ -210,8 +213,11 @@ class OpBase {
OpBasePtrMap pre_ops_;
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
framework::VariableValueMap grad_input_vars_;
framework::VariableValueMap grad_output_vars_;
// Inputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_input_vars_;
// Outputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_output_vars_;
framework::BlockDesc* block_;
};
......
......@@ -24,17 +24,17 @@ namespace imperative {
void CreateGradOp(const framework::OpDesc& op_desc,
const std::unordered_set<std::string>& no_grad_set,
const std::vector<framework::BlockDesc*>& grad_sub_block,
framework::OpDesc** grad_op_desc,
std::vector<framework::OpDesc*>* grad_op_descs,
std::unordered_map<std::string, std::string>* grad_to_var) {
std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
PADDLE_ENFORCE(grad_op_descs->empty());
std::vector<std::unique_ptr<framework::OpDesc>> descs =
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now.");
// TODO(panyx0718): Leak?
// TODO(marsyang1993): Change grad_op_desc pointer to
// vector<framework::OpDesc*> to allow multi grad_op
*grad_op_desc = grad_op_descs[0].release();
for (auto& desc : descs) {
grad_op_descs->emplace_back(desc.release());
}
}
void InitVar(framework::Variable* var, framework::Variable* grad_var,
......@@ -140,49 +140,52 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
if (!stop_gradient) {
framework::OpDesc* grad_op_desc;
// TODO(panyx): Is this leaked?
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
new std::unordered_map<std::string, std::string>());
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var.get());
op->grad_op_desc_ = grad_op_desc;
for (auto it : grad_op_desc->Inputs()) {
auto& grad_in_vars = op->grad_input_vars_[it.first];
for (const std::string& grad_invar : it.second) {
block->FindRecursiveOrCreateVar(grad_invar);
auto var_it = grad_to_var->find(grad_invar);
if (var_it == grad_to_var->end()) {
auto fwd_var_it = vars.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end());
// Forward inputs or outputs.
grad_in_vars.push_back(fwd_var_it->second->var_);
} else {
CreateGradOp(*op_desc, {}, {block}, &op->grad_op_descs_, grad_to_var.get());
op->grad_input_vars_.resize(op->grad_op_descs_.size());
op->grad_output_vars_.resize(op->grad_op_descs_.size());
for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
for (auto it : grad_op_desc->Inputs()) {
auto& grad_in_vars = op->grad_input_vars_[i][it.first];
for (const std::string& grad_invar : it.second) {
block->FindRecursiveOrCreateVar(grad_invar);
auto var_it = grad_to_var->find(grad_invar);
if (var_it == grad_to_var->end()) {
auto fwd_var_it = vars.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end());
// Forward inputs or outputs.
grad_in_vars.push_back(fwd_var_it->second->var_);
} else {
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_,
prepared_op.GetDeviceContext());
}
// Douts.
grad_in_vars.push_back(var->grads_->var_);
}
}
}
for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[i][it.first];
for (const std::string& grad_outvar : it.second) {
block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != grad_to_var->end(),
"Could not found the grad op output var, should this "
"operator %s's stop gradient be True",
op_desc->Type());
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_,
prepared_op.GetDeviceContext());
}
// Douts.
grad_in_vars.push_back(var->grads_->var_);
}
}
}
for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[it.first];
for (const std::string& grad_outvar : it.second) {
block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != grad_to_var->end(),
"Could not found the grad op output var, should this "
"operator %s's stop gradient be True",
op_desc->Type());
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
grad_out_vars.push_back(var->grads_->var_);
}
grad_out_vars.push_back(var->grads_->var_);
}
}
}
......@@ -211,10 +214,12 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
}
if (!stop_gradient) {
op->grad_input_vars_.resize(1);
op->grad_output_vars_.resize(1);
auto& grad_input_vars =
op->grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)];
op->grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)];
auto& grad_output_vars =
op->grad_output_vars_[framework::GradVarName(PyLayer::kFwdOut)];
op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
for (const VarBase* inp : inputs) {
grad_input_vars.push_back(inp->var_);
......
......@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
......@@ -130,6 +131,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
contrib::AnalysisConfig::Precision);
// Memory optimized related.
DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
......
......@@ -36,6 +36,14 @@ void SetAttr<int>(framework::proto::OpDesc *op, const std::string &name,
attr->set_i(data);
}
template <>
void SetAttr<bool>(framework::proto::OpDesc *op, const std::string &name,
const bool &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(data);
}
template <>
void SetAttr<int64_t>(framework::proto::OpDesc *op, const std::string &name,
const int64_t &data) {
auto *attr = op->add_attrs();
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <sys/stat.h>
#include <cstdio>
#include <fstream>
#include <set>
#include <string>
#include <typeindex>
#include <unordered_map>
......@@ -29,9 +30,14 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#define GCC_ATTRIBUTE(attr__) ;
#define MKDIR(path) _mkdir(path)
#else
#include <unistd.h>
#define GCC_ATTRIBUTE(attr__) __attribute__((attr__));
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif
#define __SHOULD_USE_RESULT__ GCC_ATTRIBUTE(warn_unused_result)
......@@ -163,6 +169,54 @@ static bool PathExists(const std::string &path) {
return false;
}
static std::string GetDirRoot(const std::string &path) {
char sep = '/';
#ifdef _WIN32
sep = '\\';
#endif
size_t i = path.rfind(sep, path.length());
if (i != std::string::npos) {
return (path.substr(0, i));
}
return path;
}
static std::string GetOrCreateModelOptCacheDir(const std::string &model_root) {
std::string opt_cache_dir = model_root + "/_opt_cache/";
if (!PathExists(opt_cache_dir)) {
PADDLE_ENFORCE(MKDIR(opt_cache_dir.c_str()) != -1,
"Can not create optimize cache directory: %s, Make sure you "
"have permission to write",
opt_cache_dir);
}
return opt_cache_dir;
}
static std::string GetTrtCalibPath(const std::string &model_root,
const std::string &engine_key) {
return model_root + "/trt_calib_" + engine_key;
}
// If there is no calib table data file in model_opt_cache_dir, return "".
static std::string GetTrtCalibTableData(const std::string &model_opt_cache_dir,
const std::string &engine_key,
bool enable_int8) {
std::string trt_calib_table_path =
GetTrtCalibPath(model_opt_cache_dir, engine_key);
if (enable_int8 && FileExists(trt_calib_table_path)) {
VLOG(3) << "Calibration table file: " << trt_calib_table_path
<< "is found here";
std::ifstream infile(trt_calib_table_path, std::ios::in);
std::stringstream buffer;
buffer << infile.rdbuf();
std::string calibration_data(buffer.str());
return calibration_data;
}
return "";
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......
......@@ -67,6 +67,20 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size()));
pass->Set("min_subgraph_size",
new int(argument->tensorrt_min_subgraph_size()));
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
bool enable_int8 = argument->tensorrt_precision_mode() ==
contrib::AnalysisConfig::Precision::kInt8;
pass->Set("enable_int8", new bool(enable_int8));
std::string model_opt_cache_dir =
argument->Has("model_dir")
? argument->model_dir()
: GetDirRoot(argument->model_program_path());
pass->Set(
"model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
}
// graph_ = pass->Apply(std::move(graph_));
......@@ -91,11 +105,14 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
}
framework::proto::ProgramDesc IRPassManager::AcquireProgram(
std::unique_ptr<Graph> *graph, const ProgramDesc &program) const {
std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
ProgramDesc desc(program);
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
ProgramDesc desc;
desc.CopyFrom(*program->Proto());
pass->SetNotOwned("program", &desc);
auto *the_graph = graph->release();
*graph = pass->Apply(std::unique_ptr<Graph>(the_graph));
......
......@@ -29,6 +29,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace paddle {
namespace inference {
......@@ -42,8 +43,8 @@ class IRPassManager final {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
framework::proto::ProgramDesc AcquireProgram(
std::unique_ptr<Graph> *graph, const ProgramDesc &program) const;
framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
ProgramDesc *program) const;
framework::ir::Graph &graph() const { return *graph_; }
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include <set>
#include <string>
#include <vector>
......@@ -67,12 +68,33 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
return graph;
}
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::set<std::string> &engine_outputs) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
engine_hash_key += name;
}
for (auto name : engine_outputs) {
engine_hash_key += name;
}
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
return engine_key;
}
void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
Graph *graph) const {
auto *op_desc = node->Op();
auto &subgraph = *Agent(node).subgraph();
PADDLE_ENFORCE(!subgraph.empty());
framework::ProgramDesc *program_desc =
Get<framework::ProgramDesc *>("program");
// Add new block for TensorRTEngineOP
const framework::BlockDesc &main_block =
program_desc->Block(framework::kRootBlockIndex);
// const framework::BlockDesc& main_block = program_desc->Block(0);
framework::BlockDesc *new_block = program_desc->AppendBlock(main_block);
// An fake block desc.
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
......@@ -82,13 +104,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
subgraph.size());
for (auto *node : subgraph) {
auto *new_block_op = new_block->AppendOp();
auto *op = block_desc.AppendOp();
*new_block_op->Proto() = *node->Op()->Proto();
*op->Proto() = *node->Op()->Proto();
}
// collect inputs
std::unordered_set<std::string> input_names;
std::unordered_set<std::string> input_names_with_id;
// Then, we will use the input_names_with_id and output_names_with_id to
// generate the eigine key.
// So, We use set instead of unordered_set here to ensure that the engine key
// is unique.
std::set<std::string> input_names;
std::set<std::string> input_names_with_id;
for (auto *x : node->inputs) {
input_names.insert(x->Name());
input_names_with_id.insert(x->Name() + std::to_string(x->id()));
......@@ -96,8 +123,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
std::unordered_set<std::string> output_names;
std::unordered_set<std::string> output_names_with_id;
std::set<std::string> output_names;
std::set<std::string> output_names_with_id;
for (auto *x : node->outputs) {
output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id()));
......@@ -182,7 +209,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
// to Tensor.
std::vector<std::string> output_mapping;
for (auto name : output_names) {
// LOG(INFO) << name << " " << output_name_map.size();
PADDLE_ENFORCE(output_name_map.count(name) != 0);
output_mapping.push_back(output_name_map[name]);
}
......@@ -193,16 +219,29 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
*vars->Add() = *node->Var()->Proto();
}
}
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
"the block has no var-desc");
PADDLE_ENFORCE(!output_mapping.empty());
// Set attrs
op_desc->SetBlockAttr("sub_block", new_block);
SetAttr(op_desc->Proto(), "subgraph",
block_desc.Proto()->SerializeAsString());
// Set attrs
SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
auto enable_int8 = Get<bool>("enable_int8");
auto engine_key =
GenerateEngineKey(input_names_with_id, output_names_with_id);
std::string calibration_data = GetTrtCalibTableData(
Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
SetAttr(op_desc->Proto(), "engine_key", engine_key);
}
std::vector<std::string> ExtractParameters(
......
......@@ -31,7 +31,11 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
}
std::unique_ptr<Graph> graph(argument->main_graph_ptr());
framework::ProgramDesc desc(argument->main_program());
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
framework::ProgramDesc desc;
desc.CopyFrom(*argument->main_program().Proto());
pass->SetNotOwned("program", &desc);
auto thegraph = pass->Apply(std::move(graph));
thegraph.release(); // the argument still own the graph.
......
......@@ -102,6 +102,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
CP_MEMBER(tensorrt_workspace_size_);
CP_MEMBER(tensorrt_max_batchsize_);
CP_MEMBER(tensorrt_min_subgraph_size_);
CP_MEMBER(tensorrt_precision_mode_);
// MKLDNN releated.
CP_MEMBER(use_mkldnn_);
CP_MEMBER(mkldnn_enabled_op_types_);
......@@ -141,9 +142,9 @@ void contrib::AnalysisConfig::EnableMKLDNN() {
Update();
}
void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
int max_batch_size,
int min_subgraph_size) {
void contrib::AnalysisConfig::EnableTensorRtEngine(
int workspace_size, int max_batch_size, int min_subgraph_size,
contrib::AnalysisConfig::Precision precision_mode) {
#ifdef PADDLE_WITH_CUDA
if (!use_gpu()) {
LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
......@@ -154,6 +155,7 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
tensorrt_workspace_size_ = workspace_size;
tensorrt_max_batchsize_ = max_batch_size;
tensorrt_min_subgraph_size_ = min_subgraph_size;
tensorrt_precision_mode_ = precision_mode;
Update();
#else
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <glog/logging.h>
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <vector>
......@@ -25,6 +26,7 @@
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
......@@ -37,6 +39,8 @@
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif
DECLARE_bool(profile);
......@@ -44,6 +48,12 @@ DECLARE_bool(profile);
namespace paddle {
using contrib::AnalysisConfig;
using inference::Singleton;
#if PADDLE_WITH_TENSORRT
using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;
#endif
namespace {
bool IsPersistable(const framework::VarDesc *var) {
......@@ -339,6 +349,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
!config_.params_file().empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file().empty());
std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
argument_.SetModelProgramPath(config_.prog_file());
argument_.SetModelParamsPath(config_.params_file());
}
......@@ -349,6 +361,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
}
if (config_.use_mkldnn_) {
......@@ -569,7 +582,67 @@ bool AnalysisPredictor::LoadParameters() {
return true;
}
#if PADDLE_WITH_TENSORRT
bool AnalysisPredictor::SaveTrtCalibToDisk() {
PADDLE_ENFORCE(config_.tensorrt_engine_enabled(),
"This func can be invoked only in trt mode");
auto &block = inference_program_->Block(0);
for (auto &op_desc : block.AllOps()) {
if (op_desc->Type() == "tensorrt_engine") {
std::string engine_name =
boost::get<std::string>(op_desc->GetAttr("engine_key"));
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_name)) {
LOG(ERROR) << "You should run the predictor(with trt) on the real data "
"to generate calibration info";
return false;
}
TRTCalibratorEngine *calib_engine =
Singleton<TRTCalibratorEngineManager>::Global().Get(engine_name);
LOG(INFO) << "Wait for calib threads done.";
calib_engine->calib_->waitAndSetDone();
LOG(INFO) << "Generating TRT Calibration table data, this may cost a lot "
"of time...";
calib_engine->thr_->join();
std::string calibration_table_data =
calib_engine->calib_->getCalibrationTableAsString();
if (calibration_table_data.empty()) {
LOG(ERROR) << "the calibration table is empty.";
return false;
}
std::string model_opt_cache_dir =
argument_.Has("model_dir")
? argument_.model_dir()
: inference::analysis::GetDirRoot(argument_.model_program_path());
std::string calibration_table_data_path =
inference::analysis::GetTrtCalibPath(
inference::analysis::GetOrCreateModelOptCacheDir(
model_opt_cache_dir),
engine_name);
std::ofstream ofile(calibration_table_data_path, std::ios::out);
LOG(INFO) << "Write Paddle-TRT INT8 calibration table data to file "
<< calibration_table_data_path;
ofile << calibration_table_data;
ofile.close();
}
}
// Free all calibrator resources.
Singleton<TRTCalibratorEngineManager>::Global().DeleteALL();
return true;
}
#endif
AnalysisPredictor::~AnalysisPredictor() {
#if PADDLE_WITH_TENSORRT
if (config_.tensorrt_engine_enabled() &&
config_.tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
Singleton<TRTCalibratorEngineManager>::Global().Has()) {
SaveTrtCalibToDisk();
}
#endif
if (FLAGS_profile) {
platform::DisableProfiler(platform::EventSortingKey::kTotal,
"./profile.log");
......
......@@ -97,6 +97,21 @@ class AnalysisPredictor : public PaddlePredictor {
void GetFetchOne(const framework::LoDTensor &fetchs,
PaddleTensor *output_data);
#if PADDLE_WITH_TENSORRT
// When we use Paddle-TRT INT8 engine, we need to generate calibration table
// data first,
// the calibration table contains the range for each op's input and output,
// this whole process can be divided into several steps:
//
// 1. Builds a 32-bit engine, runs it on the calibration set, and records a
// histogram for each
// tensor of the distribution of activation values.
// 2. Builds a calibration table from the histograms.
//
// After step 2, we need to store the calibration table on disk
bool SaveTrtCalibToDisk();
#endif
// Some more detailed tests, they are made the friends of the predictor, so that
// the all the details can be tested.
#if PADDLE_WITH_TESTING
......
......@@ -42,6 +42,10 @@ struct AnalysisConfig {
explicit AnalysisConfig(const std::string& model_dir);
explicit AnalysisConfig(const std::string& prog_file,
const std::string& params_file);
enum class Precision {
kFloat32 = 0,
kInt8,
};
/** Set model with a directory.
*/
......@@ -135,7 +139,8 @@ struct AnalysisConfig {
* subgraph is less than this, it will not transfer to TensorRT engine.
*/
void EnableTensorRtEngine(int workspace_size = 1 << 20,
int max_batch_size = 1, int min_subgraph_size = 3);
int max_batch_size = 1, int min_subgraph_size = 3,
Precision precision = Precision::kFloat32);
/** A boolean state telling whether the TensorRT engine is used.
*/
bool tensorrt_engine_enabled() const { return use_tensorrt_; }
......@@ -229,6 +234,7 @@ struct AnalysisConfig {
// We set this variable to control the minimum number of nodes in the
// subgraph, 3 as default value.
int tensorrt_min_subgraph_size_{3};
Precision tensorrt_precision_mode_;
// memory reuse related.
bool enable_memory_optim_{false};
......
nv_library(tensorrt_engine SRCS engine.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
nv_library(tensorrt_engine SRCS engine.cc trt_int8_calibrator.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
nv_library(tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
......
......@@ -69,6 +69,13 @@ void TensorRTEngine::FreezeNetwork() {
// build engine.
infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_);
if (enable_int8_) {
infer_builder_->setInt8Mode(true);
PADDLE_ENFORCE(
calibrator_ != nullptr,
"The precision mode is 'INT8', the calibrator should not be nullptr");
infer_builder_->setInt8Calibrator(calibrator_);
}
infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
......
......@@ -23,12 +23,14 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class TRTInt8Calibrator;
/*
* TensorRT Engine.
*
......@@ -55,13 +57,16 @@ class TensorRTEngine : public EngineBase {
};
TensorRTEngine(int max_batch, int max_workspace, cudaStream_t stream,
int device = 0,
int device = 0, bool enable_int8 = false,
TRTInt8Calibrator* calibrator = nullptr,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
stream_(stream),
logger_(logger),
device_(device) {}
device_(device),
enable_int8_(enable_int8),
calibrator_(calibrator),
logger_(logger) {}
virtual ~TensorRTEngine();
......@@ -139,8 +144,8 @@ class TensorRTEngine : public EngineBase {
// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the
// paddle-tensorrt will do the merging optimization, which fuse those conv
// into
// one conv, and then trigger bug. So, We should use strategy to avoid this
// into one conv, and then trigger bug. So, We should use strategy to avoid
// this
// optimization for the time being. This bug will be fixed in the future.
std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
itensor_quote_num;
......@@ -153,9 +158,14 @@ class TensorRTEngine : public EngineBase {
// the max memory size the engine uses
int max_workspace_;
cudaStream_t stream_;
// The specific GPU id that the TensorRTEngine bounded to.
int device_;
bool enable_int8_;
TRTInt8Calibrator* calibrator_;
// batch size of the current data, will be updated each Executation.
int batch_size_{-1};
cudaStream_t stream_;
nvinfer1::ILogger& logger_;
......@@ -165,8 +175,6 @@ class TensorRTEngine : public EngineBase {
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
itensor_map_;
// The specific GPU id that the TensorRTEngine bounded to.
int device_;
std::vector<std::unique_ptr<plugin::PluginTensorRT>> owned_plugin_;
// TensorRT related internal members
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "glog/logging.h"
namespace paddle {
namespace inference {
namespace tensorrt {
// set the batch size before constructing the thread to execute engine
int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }
TRTInt8Calibrator::TRTInt8Calibrator(
const std::unordered_map<std::string, size_t>& buffers, int batch_size,
std::string engine_name, const platform::Place place)
: batch_size_(batch_size), engine_name_(engine_name) {
int i = 0;
VLOG(4) << "Init a new calibrator: " << engine_name_;
for (const auto it : buffers) {
framework::Tensor temp_tensor;
std::string input_name = it.first;
int data_size = it.second;
int num_ele = data_size / sizeof(int16_t);
framework::DDim data_shape = framework::make_ddim({num_ele});
temp_tensor.Resize(data_shape);
data_tensors_.push_back(temp_tensor);
data_buffers_[input_name] = std::pair<void*, size_t>(
static_cast<void*>(temp_tensor.mutable_data<int16_t>(place)), num_ele);
i += 1;
}
}
TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data)
: batch_size_(0),
calib_running_(false),
data_is_set_(false),
done_(true),
calibration_table_(calib_data) {}
void TRTInt8Calibrator::waitAndSetDone() {
std::unique_lock<std::mutex> lk(mut_);
while ((calib_running_ || data_is_set_) && !done_) cond_.wait(lk);
if (!done_) {
done_ = true;
cond_.notify_all();
}
}
// There might be more than one input for trt subgraph,
// So, we use a map to store input information.
bool TRTInt8Calibrator::setBatch(
const std::unordered_map<std::string, void*>& data) {
VLOG(3) << "set batch: " << engine_name_;
std::unique_lock<std::mutex> lk(mut_);
// There is a producer and a consumer. The producer set the batch data and
// the consumer get the batch data. The size of the data pool is one.
// So, the producer has to wait for the consumer to finish processing before
// they can set the data.
while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
// The done_ is set to true using waitAndSetDone, When all calibration data
// are processed.
if (done_) return false;
// Sets the batch.
for (const auto& it : data) {
auto dataptr = data_buffers_.find(it.first);
if (dataptr == data_buffers_.end()) {
LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
<< "' does not match with the buffer names";
}
const auto& d = dataptr->second;
PADDLE_ENFORCE(
cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice),
"Fail to cudaMemcpy %s for %s", engine_name_, it.first);
}
data_is_set_ = true;
cond_.notify_all();
return true;
}
bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
int num_bindings) {
VLOG(4) << "get batch: " << engine_name_;
std::unique_lock<std::mutex> lk(mut_);
// The consumer has just finished processing a data.
// The producer can set the data again.
calib_running_ = false;
cond_.notify_all();
// As long as there is data in the pool, the consumer can get it.
while (!data_is_set_ && !done_) cond_.wait(lk);
if (done_) return false;
// Gets the batch
for (int i = 0; i < num_bindings; i++) {
auto it = data_buffers_.find(names[i]);
if (it == data_buffers_.end()) {
LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
<< names[i] << "' at position " << i;
}
bindings[i] = it->second.first;
}
data_is_set_ = false;
calib_running_ = true;
VLOG(4) << "get batch done: " << engine_name_;
return true;
}
void TRTInt8Calibrator::setDone() {
std::unique_lock<std::mutex> lk(mut_);
done_ = true;
cond_.notify_all();
}
const void* TRTInt8Calibrator::readCalibrationCache(size_t& length) {
if (calibration_table_.empty()) return nullptr;
length = calibration_table_.size();
return calibration_table_.data();
}
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
std::size_t length) {
calibration_table_ = std::string((const char*)ptr, length);
VLOG(4) << "Got calibration data for " << engine_name_ << " " << ptr
<< " length=" << length;
}
TRTInt8Calibrator::~TRTInt8Calibrator() {
VLOG(4) << "Destroying calibrator for " << engine_name_;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class TensorRTEngine;
struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
public:
TRTInt8Calibrator(const std::unordered_map<std::string, size_t>& buffers,
int batch_size, std::string engine_name,
const platform::Place place);
explicit TRTInt8Calibrator(const std::string& calibration_data);
~TRTInt8Calibrator();
int getBatchSize() const override;
bool getBatch(void* bindings[], const char* names[],
int num_bindings) override;
bool setBatch(const std::unordered_map<std::string, void*>& data);
void setDone();
void waitAndSetDone();
const void* readCalibrationCache(std::size_t& length) override;
void writeCalibrationCache(const void* ptr, std::size_t length) override;
const std::string& getCalibrationTableAsString() {
return calibration_table_;
}
private:
const int batch_size_;
bool calib_running_{true};
bool data_is_set_{false};
bool done_{false};
std::mutex mut_;
std::condition_variable cond_;
std::unordered_map<std::string, std::pair<void*, size_t>> data_buffers_;
std::vector<framework::Tensor> data_tensors_;
std::string engine_name_;
std::string calibration_table_;
};
class TRTCalibratorEngine {
public:
TRTCalibratorEngine() {}
std::unique_ptr<TRTInt8Calibrator> calib_;
std::unique_ptr<std::thread> thr_;
std::unique_ptr<TensorRTEngine> engine_;
};
/*
* Manager to control the TensorRT Int8 calibration creation and deltetion.
*/
class TRTCalibratorEngineManager {
public:
bool Has() const { return res_.size() > 0; }
bool Has(const std::string& name) const {
if (res_.count(name) == 0) return false;
return res_.at(name).get() != nullptr;
}
// Get Int8Calibrator via name
TRTCalibratorEngine* Get(const std::string& name) const {
return res_.at(name).get();
}
// Look up or create a calibrator.
TRTCalibratorEngine* LookupOrCreate(const std::string& engine_name) {
if (res_.count(engine_name) == 0) {
auto* p = new TRTCalibratorEngine;
res_[engine_name].reset(p);
}
return res_.at(engine_name).get();
}
// Create an Int8Calibrator
TRTCalibratorEngine* Create(const std::string& engine_name) {
auto* p = new TRTCalibratorEngine;
res_[engine_name].reset(p);
return p;
}
void DeleteALL() {
for (auto& item : res_) {
item.second.reset(nullptr);
}
}
private:
std::unordered_map<std::string, std::unique_ptr<TRTCalibratorEngine>> res_;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -54,6 +54,7 @@ else()
message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_seq_pool1")
endif()
# RNN2
set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2")
download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz")
......@@ -115,6 +116,10 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
endif()
inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc SERIAL)
# googlenet
inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
"${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" SERIAL)
# resnet50
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" SERIAL)
......
......@@ -67,7 +67,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mid->mutable_data<T>(ctx.GetPlace());
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha");
// MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
// Where sum of squares is divided by size of normalization window
// this is not the case for PaddlePaddle LRN.
// Hence we need to compensate for this diffrence by
// multipliing alpha by size of window(n)
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
const bool is_test = ctx.Attr<bool>("is_test");
......@@ -78,10 +84,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
auto dst_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
dims, mkldnn::memory::data_type::f32, x->format());
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels,
......@@ -92,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
k};
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
static_cast<void*>(output_data)};
if (!is_test) {
const std::string key = ctx.op().Output("Out");
......@@ -110,11 +111,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data)));
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
static_cast<void*>(output_data));
auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc());
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
} else {
auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
......@@ -122,8 +128,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()};
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
static_cast<void*>(output_data));
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
}
}
};
......@@ -151,7 +162,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha");
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
......
......@@ -29,8 +29,14 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Xs", "A list of inputs.").AsDuplicable();
AddOutput("Ys", "A list of outputs").AsDuplicable();
AddAttr<std::string>("subgraph", "the subgraph.");
AddAttr<std::string>("calibration_data", "the calibration data for int8");
AddAttr<std::string>(
"engine_key",
"The engine_key here is used to distinguish different TRT Engines");
AddAttr<int>("max_batch_size", "the maximum batch size.");
AddAttr<int>("workspace_size", "the workspace size.");
AddAttr<framework::BlockDesc *>("sub_block", "the trt block");
AddAttr<bool>("enable_int8", "whether swith to int8 mode");
AddComment("TensorRT engine operator.");
}
};
......@@ -47,6 +53,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference {
namespace ops = paddle::operators;
REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
ops::TensorRTEngineOpMaker);
ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker);
#endif // PADDLE_WITH_CUDA
......@@ -17,8 +17,10 @@
#ifdef PADDLE_WITH_CUDA
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
......@@ -62,6 +64,9 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
using inference::Singleton;
using inference::tensorrt::TensorRTEngine;
using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;
class TensorRTEngineOp : public framework::OperatorBase {
private:
......@@ -70,6 +75,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
mutable std::unique_ptr<TensorRTEngine> trt_engine_;
int max_batch_size_;
int workspace_size_;
std::unique_ptr<TRTInt8Calibrator> calibrator_;
bool enable_int8_;
std::string calibration_data_;
std::string engine_key_;
bool calibration_mode_;
public:
TensorRTEngineOp(const std::string &type,
......@@ -80,19 +90,96 @@ class TensorRTEngineOp : public framework::OperatorBase {
input_names_ = Inputs("Xs");
max_batch_size_ = Attr<int>("max_batch_size");
workspace_size_ = Attr<int>("workspace_size");
enable_int8_ = Attr<bool>("enable_int8");
calibration_data_ = Attr<std::string>("calibration_data");
engine_key_ = Attr<std::string>("engine_key");
auto params = Attr<std::vector<std::string>>("parameters");
for (const auto &param : params) {
param_names_.insert(param);
}
// calibration_mode is ture represents we need to
// generate the calibration table data.
calibration_mode_ = (enable_int8_ && calibration_data_.size() == 0);
VLOG(4) << "calibration_mode: " << calibration_mode_;
if (enable_int8_ && calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
}
}
protected:
void RunNativeImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block");
auto *program = block->Program();
auto &current_scope = scope.NewScope();
auto ctx = executor.Prepare(*program, block->ID());
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
if (calibration_mode_ == true) {
RunCalibration(scope, dev_place);
return;
}
RunTrt(scope, dev_place);
}
void RunCalibration(const framework::Scope &scope,
const platform::Place &dev_place) const {
// This process will builds a 32-bit trt engine, runs it on the calibration
// set, and records a histogram for each
// tensor of the distribution of activation values.
LOG_FIRST_N(INFO, 1) << "The TRT engine: " << engine_key_
<< " is running calibration trt int8... ";
int runtime_batch = 1;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_key_)) {
TRTCalibratorEngine *calib_res =
Singleton<TRTCalibratorEngineManager>::Global().Create(engine_key_);
std::unordered_map<std::string, size_t> calib_buffers;
for (auto &x : input_names_) {
if (param_names_.count(x)) continue;
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
calib_buffers[x] = t.memory_size();
auto t_shape = framework::vectorize(t.dims());
runtime_batch = t_shape[0];
}
calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, engine_key_, dev_place));
calib_res->thr_.reset(new std::thread([&]() {
calib_res->engine_.reset(new TensorRTEngine(
max_batch_size_, workspace_size_, stream,
boost::get<platform::CUDAPlace>(dev_place).device, enable_int8_,
calib_res->calib_.get()));
VLOG(3) << "start the calib trt engine thread";
Prepare(scope, dev_place, calib_res->engine_.get());
}));
}
TRTInt8Calibrator *temp_calibrator =
Singleton<TRTCalibratorEngineManager>::Global()
.Get(engine_key_)
->calib_.get();
std::unordered_map<std::string, void *> calib_data;
for (auto &x : Inputs("Xs")) {
if (param_names_.count(x)) continue;
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
calib_data.emplace(x, t.data<void>());
}
temp_calibrator->setBatch(calib_data);
RunNativeImpl(scope, dev_place);
}
void RunTrt(const framework::Scope &scope,
const platform::Place &dev_place) const {
int runtime_batch = 1;
......@@ -101,9 +188,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
if (trt_engine_.get() == nullptr) {
trt_engine_.reset(new TensorRTEngine(
max_batch_size_, workspace_size_, stream,
boost::get<platform::CUDAPlace>(dev_place).device));
trt_engine_.reset(
new TensorRTEngine(max_batch_size_, workspace_size_, stream,
boost::get<platform::CUDAPlace>(dev_place).device,
enable_int8_, calibrator_.get()));
Prepare(scope, dev_place, trt_engine_.get());
}
......@@ -173,7 +261,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
void Prepare(const framework::Scope &scope, const platform::Place &dev_place,
TensorRTEngine *engine) const {
VLOG(4) << "Prepare engine";
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
framework::proto::BlockDesc block_desc;
block_desc.ParseFromString(Attr<std::string>("subgraph"));
......
......@@ -96,19 +96,20 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 1 << 20);
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
std::vector<std::string>({}));
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(),
"output_name_mapping",
std::vector<std::string>({"z0"}));
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
engine_op_desc.SetAttr("max_batch_size", static_cast<int>(2));
engine_op_desc.SetAttr("workspace_size", static_cast<int>(1 << 20));
engine_op_desc.SetAttr("parameters", std::vector<std::string>({}));
engine_op_desc.SetAttr("engine_key", std::string("a_engine"));
engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
LOG(INFO) << "engine_op " << engine_op.get();
framework::Scope scope;
......@@ -190,20 +191,19 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 1 << 20);
SetAttr<std::vector<std::string>>(
engine_op_desc.Proto(), "parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "b_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(),
"output_name_mapping",
std::vector<std::string>({"z3"}));
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
engine_op_desc.SetAttr("max_batch_size", static_cast<int>(batch_size));
engine_op_desc.SetAttr("workspace_size", static_cast<int>(1 << 20));
engine_op_desc.SetAttr("parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
engine_op_desc.SetAttr("engine_key", std::string("b_engine"));
engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
// Execute them.
engine_op->Run(scope, place);
......
......@@ -180,8 +180,14 @@ void BindNativePredictor(py::module *m) {
}
void BindAnalysisConfig(py::module *m) {
py::class_<AnalysisConfig>(*m, "AnalysisConfig")
.def(py::init<const AnalysisConfig &>())
py::class_<AnalysisConfig> analysis_config(*m, "AnalysisConfig");
py::enum_<AnalysisConfig::Precision>(analysis_config, "Precision")
.value("Float32", AnalysisConfig::Precision::kFloat32)
.value("Int8", AnalysisConfig::Precision::kInt8)
.export_values();
analysis_config.def(py::init<const AnalysisConfig &>())
.def(py::init<const std::string &>())
.def(py::init<const std::string &, const std::string &>())
.def("set_model", (void (AnalysisConfig::*)(const std::string &)) &
......@@ -215,7 +221,8 @@ void BindAnalysisConfig(py::module *m) {
.def("specify_input_name", &AnalysisConfig::specify_input_name)
.def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine,
py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
py::arg("min_subgraph_size") = 3)
py::arg("min_subgraph_size") = 3,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
py::arg("x") = true)
......
......@@ -23,6 +23,7 @@ import argparse
import functools
import contextlib
import paddle.fluid.profiler as profiler
from paddle.dataset.common import download
from PIL import Image, ImageEnhance
import math
sys.path.append('..')
......@@ -116,27 +117,44 @@ def val(data_dir=DATA_DIR):
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
class TestCalibration(unittest.TestCase):
class TestCalibrationForResnet50(unittest.TestCase):
def setUp(self):
# TODO(guomingz): Put the download process in the cmake.
# Download and unzip test data set
imagenet_dl_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/calibration_test_data.tar.gz'
zip_file_name = imagenet_dl_url.split('/')[-1]
cmd = 'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data'.format(
zip_file_name, imagenet_dl_url, zip_file_name)
os.system(cmd)
# resnet50 fp32 data
resnet50_fp32_model_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/resnet50_int8_model.tar.gz'
resnet50_zip_name = resnet50_fp32_model_url.split('/')[-1]
resnet50_unzip_folder_name = 'resnet50_fp32'
cmd = 'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}'.format(
resnet50_unzip_folder_name, resnet50_zip_name,
resnet50_unzip_folder_name, resnet50_fp32_model_url,
resnet50_zip_name, resnet50_unzip_folder_name)
self.int8_download = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
self.int8_download)
data_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/calibration_test_data.tar.gz'
data_md5 = '1b6c1c434172cca1bf9ba1e4d7a3157d'
self.data_cache_folder = self.download_data(data_url, data_md5, "data")
# reader/decorator.py requires the relative path to the data folder
cmd = 'rm -rf {0} && ln -s {1} {0}'.format("data",
self.data_cache_folder)
os.system(cmd)
self.iterations = 50
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder,
zip_path)
os.system(cmd)
def download_data(self, data_url, data_md5, folder_name):
download(data_url, self.int8_download, data_md5)
data_cache_folder = os.path.join(self.cache_folder, folder_name)
file_name = data_url.split('/')[-1]
zip_path = os.path.join(self.cache_folder, file_name)
self.cache_unzipping(data_cache_folder, zip_path)
return data_cache_folder
def download_resnet50_model(self):
# resnet50 fp32 data
data_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/resnet50_int8_model.tar.gz'
data_md5 = '4a5194524823d9b76da6e738e1367881'
self.model_cache_folder = self.download_data(data_url, data_md5,
"resnet50_fp32")
def run_program(self, model_path, generate_int8=False, algo='direct'):
image_shape = [3, 224, 224]
os.environ['FLAGS_use_mkldnn'] = 'True'
......@@ -204,14 +222,32 @@ class TestCalibration(unittest.TestCase):
calibrator.save_int8_model()
print(
"Calibration is done and the corresponding files were generated at {}".
"Calibration is done and the corresponding files are generated at {}".
format(os.path.abspath("calibration_out")))
else:
return np.sum(test_info) / cnt
def test_calibration_for_resnet50(self):
fp32_acc1 = self.run_program("resnet50_fp32/model")
self.run_program("resnet50_fp32/model", True)
def test_calibration(self):
self.download_resnet50_model()
fp32_acc1 = self.run_program(self.model_cache_folder + "/model")
self.run_program(self.model_cache_folder + "/model", True)
int8_acc1 = self.run_program("calibration_out")
delta_value = np.abs(fp32_acc1 - int8_acc1)
self.assertLess(delta_value, 0.01)
class TestCalibrationForMobilenetv1(TestCalibrationForResnet50):
def download_mobilenetv1_model(self):
# mobilenetv1 fp32 data
data_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
data_md5 = '13892b0716d26443a8cdea15b3c6438b'
self.model_cache_folder = self.download_data(data_url, data_md5,
"mobilenetv1_fp32")
def test_calibration(self):
self.download_mobilenetv1_model()
fp32_acc1 = self.run_program(self.model_cache_folder + "/model")
self.run_program(self.model_cache_folder + "/model", True, algo='KL')
int8_acc1 = self.run_program("calibration_out")
delta_value = np.abs(fp32_acc1 - int8_acc1)
self.assertLess(delta_value, 0.01)
......
......@@ -189,6 +189,18 @@ class SimpleRNN(fluid.imperative.Layer):
class TestImperative(unittest.TestCase):
def test_sum_op(self):
x = np.ones([2, 2], np.float32)
with fluid.imperative.guard():
inputs = []
for _ in range(10):
inputs.append(fluid.imperative.base.to_variable(x))
ret = fluid.layers.sums(inputs)
loss = fluid.layers.reduce_sum(ret)
loss._backward()
self.assertTrue(np.allclose(ret._numpy(), x * 10))
self.assertTrue(np.allclose(inputs[0]._gradient(), x))
def test_layer(self):
with fluid.imperative.guard():
cl = core.Layer()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册