提交 06a088a1 编写于 作者: N nhzlx

fix comments and fix cpplint

test=develop
上级 0ed63b21
......@@ -25,7 +25,7 @@ namespace ir {
static const char kParamScopeAttr[] = "__param_scope__";
static const char kFuseStatisAttr[] = "__fuse_statis__";
// When we use trt or other third_party lib, the parameters are managered by
// When we use trt or other third_party lib, the parameters are managed by
// the lib, but not the fluid. So we need to record them to avoid duplicate
// allocation.
static const char kRepetitiveParamAttr[] = "__repetitive_param__";
......
......@@ -17,10 +17,12 @@ limitations under the License. */
#include <sys/stat.h>
#include <cstdio>
#include <fstream>
#include <memory>
#include <set>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
......
......@@ -22,7 +22,10 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
......
......@@ -235,7 +235,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::string trt_engine_serialized_data = GetTrtEngineSerializedData(
Get<std::string>("model_opt_cache_dir"), engine_key);
if (trt_engine_serialized_data.size() == 0) {
if (trt_engine_serialized_data.empty()) {
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
std::unique_ptr<tensorrt::TensorRTEngine> trt_engine(
......
......@@ -13,9 +13,12 @@
// limitations under the License.
#pragma once
#include <paddle/fluid/framework/ir/fuse_pass_base.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
......
......@@ -15,6 +15,7 @@
#pragma once
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/naive_executor.h"
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
......
......@@ -19,7 +19,9 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
......
......@@ -15,6 +15,7 @@
#pragma once
#include <thrust/device_vector.h>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
......
......@@ -17,6 +17,7 @@
#include <NvInfer.h>
#include <cstring>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
......
......@@ -17,6 +17,7 @@
#include <NvInfer.h>
#include <cstring>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
......
......@@ -35,7 +35,12 @@ class TensorRTEngineTest : public ::testing::Test {
engine_->InitNetwork();
}
void TearDown() override { delete engine_; }
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
}
void PrepareInputOutput(const std::vector<float> &input,
std::vector<int> output_shape) {
......
......@@ -16,8 +16,10 @@
#ifdef PADDLE_WITH_CUDA
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/executor.h"
......@@ -220,11 +222,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const {
if (trt_engine_.get() == nullptr) {
if (!trt_engine_) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
boost::get<platform::CUDAPlace>(dev_place).device));
if (engine_serialized_data_.size() > 0) {
if (!engine_serialized_data_.empty()) {
trt_engine_->Deserialize(engine_serialized_data_);
} else {
PrepareTRTEngine(scope, trt_engine_.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册