Unverified commit b7b68f2a authored by Zhaolong Xing, committed by GitHub

Merge pull request #15461 from NHZlX/fix_trt_stream_bug

fix trt stream bug.
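
This change makes the caller responsible for the CUDA stream: TensorRTEngine now takes an already-created cudaStream_t by value instead of an optional cudaStream_t* (it no longer creates or holds a default stream), TensorRTEngineOp runs the engine on the stream of the CUDADeviceContext for the target place, and the TRT_EngineManager helper is removed. A minimal sketch of the new calling convention, assuming only the constructor and accessors shown in the diff below (error checking and the required Paddle/CUDA headers elided):

    // Sketch only, not part of the diff: the caller creates and owns the stream,
    // then hands it to TensorRTEngine by value.
    cudaStream_t stream;
    cudaStreamCreate(&stream);  // create the stream before constructing the engine
    std::unique_ptr<TensorRTEngine> engine(
        new TensorRTEngine(/*max_batch=*/5, /*max_workspace=*/1 << 15, stream));
    engine->InitNetwork();
    // ... declare inputs/outputs, FreezeNetwork(), Execute(batch_size) ...
    cudaStreamSynchronize(engine->stream());  // stream() now returns cudaStream_t by value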
@@ -29,9 +29,9 @@ TEST(OpConverter, ConvertBlock) {
   // init trt engine
   cudaStream_t stream_;
   std::unique_ptr<TensorRTEngine> engine_;
-  engine_.reset(new TensorRTEngine(5, 1 << 15, &stream_));
-  engine_->InitNetwork();
   PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
+  engine_.reset(new TensorRTEngine(5, 1 << 15, stream_));
+  engine_->InitNetwork();
   engine_->DeclareInput("conv2d-X", nvinfer1::DataType::kFLOAT,
                         nvinfer1::Dims3(2, 5, 5));
......
@@ -78,11 +78,9 @@ class TRTConvertValidation {
         scope_(scope),
         if_add_batch_(if_add_batch),
         max_batch_size_(max_batch_size) {
-    // create engine.
-    engine_.reset(new TensorRTEngine(max_batch_size, workspace_size, &stream_));
-    engine_->InitNetwork();
     PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
+    engine_.reset(new TensorRTEngine(max_batch_size, workspace_size, stream_));
+    engine_->InitNetwork();
   }

   // Declare a Variable as input with random initialization.

@@ -175,7 +173,7 @@ class TRTConvertValidation {
     op_->Run(scope_, place);
     // Execute TRT.
     engine_->Execute(batch_size);
-    cudaStreamSynchronize(*engine_->stream());
+    cudaStreamSynchronize(engine_->stream());
     ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
     const size_t output_space_size = 3000;

@@ -184,7 +182,7 @@ class TRTConvertValidation {
      std::vector<float> fluid_out;
      std::vector<float> trt_out(output_space_size);
      engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);
-     cudaStreamSynchronize(*engine_->stream());
+     cudaStreamSynchronize(engine_->stream());
      auto* var = scope_.FindVar(output);
      auto tensor = var->GetMutable<framework::LoDTensor>();
......
@@ -42,14 +42,13 @@ void TensorRTEngine::Execute(int batch_size) {
     PADDLE_ENFORCE(buf.device == DeviceType::GPU);
     buffers.push_back(buf.buffer);
   }
-  PADDLE_ENFORCE_NOT_NULL(stream_);
-  infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
-  cudaStreamSynchronize(*stream_);
+  infer_context_->enqueue(batch_size, buffers.data(), stream_, nullptr);
+  cudaStreamSynchronize(stream_);
   SetRuntimeBatch(batch_size);
 }

 TensorRTEngine::~TensorRTEngine() {
-  cudaStreamSynchronize(*stream_);
+  cudaStreamSynchronize(stream_);
   // clean buffer
   for (auto &buf : buffers_) {
     if (buf.device == DeviceType::GPU && buf.buffer != nullptr) {

@@ -173,7 +172,7 @@ void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
   auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
   PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
-                                    cudaMemcpyDeviceToDevice, *stream_),
+                                    cudaMemcpyDeviceToDevice, stream_),
                     0);
 }

@@ -194,7 +193,7 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
   auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
   PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
-                                       cudaMemcpyDeviceToHost, *stream_));
+                                       cudaMemcpyDeviceToHost, stream_));
 }

 Buffer &TensorRTEngine::buffer(const std::string &name) {

@@ -211,12 +210,11 @@ void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data,
   auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer);
   PADDLE_ENFORCE_NOT_NULL(data);
-  PADDLE_ENFORCE_NOT_NULL(stream_);
   PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
   PADDLE_ENFORCE(buf.device == DeviceType::GPU);
   buf.size = size;
   PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
-                                       cudaMemcpyHostToDevice, *stream_));
+                                       cudaMemcpyHostToDevice, stream_));
 }

 void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,

@@ -227,7 +225,7 @@ void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,
   PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
   PADDLE_ENFORCE(buf.device == DeviceType::GPU);
   PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
-                                       cudaMemcpyDeviceToDevice, *stream_));
+                                       cudaMemcpyDeviceToDevice, stream_));
 }

 void TensorRTEngine::SetITensor(const std::string &name,
......
@@ -54,17 +54,14 @@ class TensorRTEngine : public EngineBase {
     nvinfer1::Weights w_;
   };

-  TensorRTEngine(int max_batch, int max_workspace,
-                 cudaStream_t* stream = nullptr, int device = 0,
+  TensorRTEngine(int max_batch, int max_workspace, cudaStream_t stream,
+                 int device = 0,
                  nvinfer1::ILogger& logger = NaiveLogger::Global())
       : max_batch_(max_batch),
         max_workspace_(max_workspace),
-        stream_(stream ? stream : &default_stream_),
+        stream_(stream),
         logger_(logger),
-        device_(device) {
-    freshDeviceId();
-    cudaStreamCreate(stream_);
-  }
+        device_(device) {}

   virtual ~TensorRTEngine();

@@ -102,7 +99,7 @@ class TensorRTEngine : public EngineBase {
   // NOTE this should be used after calling `FreezeNetwork`.
   Buffer& buffer(const std::string& name) override;

-  cudaStream_t* stream() { return stream_; }
+  cudaStream_t stream() { return stream_; }

   // Fill an input from CPU memory with name and size.
   void SetInputFromCPU(const std::string& name, const void* data, size_t size);

@@ -158,9 +155,8 @@ class TensorRTEngine : public EngineBase {
   // batch size of the current data, will be updated each Executation.
   int batch_size_{-1};

-  cudaStream_t* stream_;
-  // If stream_ is not set from outside, hold its own stream.
-  cudaStream_t default_stream_;
+  cudaStream_t stream_;

   nvinfer1::ILogger& logger_;
   std::vector<Buffer> buffers_;

@@ -208,38 +204,6 @@ class TensorRTEngine : public EngineBase {
 #define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
   engine__->network()->add##layer__(ARGS);

-/*
- * Helper to control the TensorRT engine's creation and deletion.
- */
-class TRT_EngineManager {
- public:
-  bool HasEngine(const std::string& name) const {
-    return engines_.count(name) != 0;
-  }
-
-  // Get an engine called `name`.
-  TensorRTEngine* Get(const std::string& name) const {
-    return engines_.at(name).get();
-  }
-
-  // Create or get an engine called `name`
-  TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream,
-                         const std::string& name, int gpu_device = 0) {
-    auto* p = new TensorRTEngine(max_batch, max_workspace, stream, gpu_device);
-    engines_[name].reset(p);
-    return p;
-  }
-
-  void DeleteALl() {
-    for (auto& item : engines_) {
-      item.second.reset(nullptr);
-    }
-  }
-
- private:
-  std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
-};
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -27,8 +27,8 @@ namespace tensorrt {
 class TensorRTEngineTest : public ::testing::Test {
  protected:
   void SetUp() override {
-    // ASSERT_EQ(0, cudaStreamCreate(&stream_));
-    engine_ = new TensorRTEngine(10, 1 << 10, &stream_);
+    ASSERT_EQ(0, cudaStreamCreate(&stream_));
+    engine_ = new TensorRTEngine(10, 1 << 10, stream_);
     engine_->InitNetwork();
   }
......
@@ -56,6 +56,13 @@ DECLARE_int32(paddle_num_threads);
 namespace paddle {
 namespace inference {

+float Random(float low, float high) {
+  static std::random_device rd;
+  static std::mt19937 mt(rd());
+  std::uniform_real_distribution<double> dist(low, high);
+  return dist(mt);
+}
+
 void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
   const auto *analysis_config =
       reinterpret_cast<const contrib::AnalysisConfig *>(config);

@@ -176,7 +183,7 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
     float *input_data = static_cast<float *>(input.data.data());
     // fill input data, for profile easily, do not use random data here.
     for (size_t j = 0; j < len; ++j) {
-      *(input_data + j) = static_cast<float>(j) / len;
+      *(input_data + j) = Random(0.0, 1.0) / 10.;
     }
   }
   (*inputs).emplace_back(input_slots);

@@ -344,6 +351,16 @@ void CompareNativeAndAnalysis(
   CompareResult(analysis_outputs, native_outputs);
 }

+void CompareNativeAndAnalysis(
+    PaddlePredictor *native_pred, PaddlePredictor *analysis_pred,
+    const std::vector<std::vector<PaddleTensor>> &inputs) {
+  int batch_size = FLAGS_batch_size;
+  std::vector<PaddleTensor> native_outputs, analysis_outputs;
+  native_pred->Run(inputs[0], &native_outputs, batch_size);
+  analysis_pred->Run(inputs[0], &analysis_outputs, batch_size);
+  CompareResult(analysis_outputs, native_outputs);
+}
+
 template <typename T>
 std::string LoDTensorSummary(const framework::LoDTensor &tensor) {
   std::stringstream ss;
......
@@ -107,6 +107,27 @@ void compare(std::string model_dir, bool use_tensorrt) {
                            inputs_all);
 }

+void compare_continuous_input(std::string model_dir, bool use_tensorrt) {
+  contrib::AnalysisConfig analysis_config;
+  SetConfig<contrib::AnalysisConfig>(&analysis_config, model_dir, true,
+                                     use_tensorrt, FLAGS_batch_size);
+  auto config =
+      reinterpret_cast<const PaddlePredictor::Config*>(&analysis_config);
+  auto native_pred = CreateTestPredictor(config, false);
+  auto analysis_pred = CreateTestPredictor(config, true);
+  for (int i = 0; i < 100; i++) {
+    std::vector<std::vector<PaddleTensor>> inputs_all;
+    if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
+      SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename,
+                        FLAGS_param_filename);
+    } else {
+      SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
+    }
+    CompareNativeAndAnalysis(native_pred.get(), analysis_pred.get(),
+                             inputs_all);
+  }
+}
+
 TEST(TensorRT_mobilenet, compare) {
   std::string model_dir = FLAGS_infer_model + "/mobilenet";
   compare(model_dir, /* use_tensorrt */ true);

@@ -162,5 +183,15 @@ TEST(TensorRT_mobilenet, profile) {
   profile(model_dir, true, false);
 }

+TEST(resnet50, compare_continuous_input) {
+  std::string model_dir = FLAGS_infer_model + "/resnet50";
+  compare_continuous_input(model_dir, true);
+}
+
+TEST(resnet50, compare_continuous_input_native) {
+  std::string model_dir = FLAGS_infer_model + "/resnet50";
+  compare_continuous_input(model_dir, false);
+}
+
 }  // namespace inference
 }  // namespace paddle
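
Note: the new compare_continuous_input test builds one native predictor and one analysis (TensorRT) predictor, then feeds them 100 consecutive fake-image batches through the same predictor objects, comparing the outputs on every iteration, so repeated engine executions on a reused stream are exercised rather than a single run.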
@@ -96,9 +96,13 @@ class TensorRTEngineOp : public framework::OperatorBase {
   void RunTrt(const framework::Scope &scope,
               const platform::Place &dev_place) const {
     int runtime_batch = 1;
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(dev_place);
+    auto stream =
+        reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
     if (trt_engine_.get() == nullptr) {
       trt_engine_.reset(new TensorRTEngine(
-          max_batch_size_, workspace_size_, nullptr,
+          max_batch_size_, workspace_size_, stream,
           boost::get<platform::CUDAPlace>(dev_place).device));
       Prepare(scope, dev_place, trt_engine_.get());
     }

@@ -126,6 +130,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
       }
     }

+    cudaStreamSynchronize(stream);
     PADDLE_ENFORCE_LE(runtime_batch, max_batch_size_);
     // Execute the engine.
     engine->Execute(runtime_batch);

@@ -163,7 +168,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
       output_index += 1;
     }

-    cudaStreamSynchronize(*engine->stream());
+    cudaStreamSynchronize(stream);
   }

   void Prepare(const framework::Scope &scope, const platform::Place &dev_place,
......
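
Note: RunTrt no longer lets the engine own its own stream. It fetches the CUDA stream of the CUDADeviceContext for dev_place, passes that stream to the lazily-created TensorRTEngine, and synchronizes the same stream both before Execute and after the outputs are copied back, so the engine's work is ordered against the rest of the framework's work on that device.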
@@ -99,7 +99,7 @@ TEST(TensorRTEngineOp, manual) {
   SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
                        block_->SerializeAsString());
   SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2);
-  SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
+  SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 1 << 20);
   SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
   SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
                                     std::vector<std::string>({}));

@@ -193,7 +193,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
   SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
                        block_->SerializeAsString());
   SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size);
-  SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
+  SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 1 << 20);
   SetAttr<std::vector<std::string>>(
       engine_op_desc.Proto(), "parameters",
       std::vector<std::string>({"y0", "y1", "y2", "y3"}));
......
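
Note: both engine-op tests raise the TensorRT workspace size from 2 << 10 (2 KiB) to 1 << 20 (1 MiB).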