提交 1600ba86 编写于 作者: N nhzlx

1. change tensorrt op from cpu to gpu

上级 bd87f67f
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
namespace paddle { namespace paddle {
namespace inference { namespace inference {
DEFINE_int32(tensorrt_max_batchsize, 3, "TensorRT maximum batch size"); DEFINE_int32(tensorrt_max_batchsize, 1, "TensorRT maximum batch size");
DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size"); DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size");
namespace analysis { namespace analysis {
...@@ -52,7 +52,6 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { ...@@ -52,7 +52,6 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
bool DataFlowGraphToFluidPass::Finalize() { return true; } bool DataFlowGraphToFluidPass::Finalize() { return true; }
void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
FilterRedundantOutputOfSubGraph(graph);
LOG(INFO) << "graph.inputs " << graph->inputs.size(); LOG(INFO) << "graph.inputs " << graph->inputs.size();
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
if (node.deleted()) continue; if (node.deleted()) continue;
......
...@@ -153,6 +153,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() { ...@@ -153,6 +153,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
inlink_or_outlink_cleaner(o->inlinks); inlink_or_outlink_cleaner(o->inlinks);
} }
} }
FilterRedundantOutputOfSubGraph(graph_);
} }
} // namespace analysis } // namespace analysis
......
...@@ -35,12 +35,20 @@ class Conv2dOpConverter : public OpConverter { ...@@ -35,12 +35,20 @@ class Conv2dOpConverter : public OpConverter {
auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); auto* Y_v = scope.FindVar(op_desc.Input("Filter").front());
PADDLE_ENFORCE_NOT_NULL(Y_v); PADDLE_ENFORCE_NOT_NULL(Y_v);
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>(); auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL); platform::CPUPlace cpu_place;
const int n_output = Y_t->dims()[0]; framework::LoDTensor* weight_tensor = new framework::LoDTensor();
const int filter_h = Y_t->dims()[2]; weight_tensor->Resize(Y_t->dims());
const int filter_w = Y_t->dims()[3]; TensorCopySync((*Y_t), cpu_place, weight_tensor);
engine_->weight_map[op_desc.Input("Filter").front()] =
std::move(std::unique_ptr<framework::Tensor>(weight_tensor));
auto* weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
const int n_output = weight_tensor->dims()[0];
const int filter_h = weight_tensor->dims()[2];
const int filter_w = weight_tensor->dims()[3];
const int groups = boost::get<int>(op_desc.GetAttr("groups")); const int groups = boost::get<int>(op_desc.GetAttr("groups"));
const std::vector<int> dilations = const std::vector<int> dilations =
...@@ -57,7 +65,7 @@ class Conv2dOpConverter : public OpConverter { ...@@ -57,7 +65,7 @@ class Conv2dOpConverter : public OpConverter {
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data), static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)}; weight_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* layer = TRT_ENGINE_ADD_LAYER( auto* layer = TRT_ENGINE_ADD_LAYER(
......
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle { namespace paddle {
...@@ -40,10 +39,19 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -40,10 +39,19 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(Y_v); PADDLE_ENFORCE_NOT_NULL(Y_v);
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>(); auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace());
platform::CPUPlace cpu_place;
framework::LoDTensor* weight_tensor = new framework::LoDTensor();
weight_tensor->Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, weight_tensor);
engine_->weight_map[op_desc.Input("Y").front()] =
std::move(std::unique_ptr<framework::Tensor>(weight_tensor));
auto* weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
std::vector<int> dims_y = framework::vectorize2int(Y_t->dims()); std::vector<int> dims_y = framework::vectorize2int(weight_tensor->dims());
if (static_cast<int>(dims_y.size()) == dims_x.nbDims + 1) { if (static_cast<int>(dims_y.size()) == dims_x.nbDims + 1) {
if (dims_y[0] == 1) dims_y.erase(dims_y.begin()); if (dims_y[0] == 1) dims_y.erase(dims_y.begin());
} }
...@@ -70,9 +78,9 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -70,9 +78,9 @@ class ElementwiseWeightOpConverter : public OpConverter {
PADDLE_THROW("TensorRT unsupported weight Shape for Elementwise op!"); PADDLE_THROW("TensorRT unsupported weight Shape for Elementwise op!");
} }
TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight shift_weights{
static_cast<void*>(weight_data), nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)}; weight_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr, TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr,
0}; 0};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
......
...@@ -12,12 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -73,19 +68,28 @@ class FcOpConverter : public OpConverter { ...@@ -73,19 +68,28 @@ class FcOpConverter : public OpConverter {
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>(); auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
// This may trigger a GPU->CPU copy, because TRT's weight can only be // This may trigger a GPU->CPU copy, because TRT's weight can only be
// assigned from CPU memory, that can't be avoided. // assigned from CPU memory, that can't be avoided.
auto* weight_data = Y_t->mutable_data<float>(platform::CPUPlace()); platform::CPUPlace cpu_place;
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL); // a matrix framework::LoDTensor weight_tensor;
size_t n_output = Y_t->dims()[1]; weight_tensor.Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, &weight_tensor);
framework::LoDTensor tmp; auto* weight_data = weight_tensor.mutable_data<float>(platform::CPUPlace());
tmp.Resize(Y_t->dims());
memcpy(tmp.mutable_data<float>(platform::CPUPlace()), weight_data, PADDLE_ENFORCE_EQ(weight_tensor.dims().size(), 2UL); // a matrix
size_t n_output = weight_tensor.dims()[1];
framework::LoDTensor* tmp = new framework::LoDTensor();
tmp->Resize(weight_tensor.dims());
engine_->weight_map[op_desc.Input("Y").front()] =
std::move(std::unique_ptr<framework::Tensor>(tmp));
memcpy(tmp->mutable_data<float>(platform::CPUPlace()), weight_data,
Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float)); Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data), static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)}; Y_t->memory_size() / sizeof(float)};
TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
static_cast<void*>(tmp.data<float>()), static_cast<void*>(tmp->data<float>()),
Y_t->memory_size() / sizeof(float)); Y_t->memory_size() / sizeof(float));
weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]}); weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
tmp_weight.dims = weight.dims; tmp_weight.dims = weight.dims;
......
...@@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) { ...@@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) {
auto* x = scope.Var("conv2d-Y"); auto* x = scope.Var("conv2d-Y");
auto* x_tensor = x->GetMutable<framework::LoDTensor>(); auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec)); x_tensor->Resize(framework::make_ddim(dim_vec));
x_tensor->mutable_data<float>(platform::CUDAPlace(0));
OpConverter converter; OpConverter converter;
converter.ConvertBlock(*block->Proto(), {"conv2d-Y"}, scope, converter.ConvertBlock(*block->Proto(), {"conv2d-Y"}, scope,
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
...@@ -48,11 +49,17 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place, ...@@ -48,11 +49,17 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
auto dims = tensor->dims(); auto dims = tensor->dims();
size_t num_elements = analysis::AccuDims(dims, dims.size()); size_t num_elements = analysis::AccuDims(dims, dims.size());
PADDLE_ENFORCE_GT(num_elements, 0); PADDLE_ENFORCE_GT(num_elements, 0);
auto* data = tensor->mutable_data<float>(place);
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(dims);
auto* temp_data = temp_tensor.mutable_data<float>(cpu_place);
for (size_t i = 0; i < num_elements; i++) { for (size_t i = 0; i < num_elements; i++) {
*(data + i) = random(0., 1.); *(temp_data + i) = random(0., 1.);
} }
TensorCopySync(temp_tensor, place, tensor);
} }
/* /*
...@@ -101,8 +108,8 @@ class TRTConvertValidation { ...@@ -101,8 +108,8 @@ class TRTConvertValidation {
} }
void DeclVar(const std::string& name, const std::vector<int> dim_vec) { void DeclVar(const std::string& name, const std::vector<int> dim_vec) {
platform::CPUPlace place; platform::CUDAPlace place;
platform::CPUDeviceContext ctx(place); platform::CUDADeviceContext ctx(place);
auto* x = scope_.Var(name); auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>(); auto* x_tensor = x->GetMutable<framework::LoDTensor>();
...@@ -141,7 +148,7 @@ class TRTConvertValidation { ...@@ -141,7 +148,7 @@ class TRTConvertValidation {
PADDLE_ENFORCE(var); PADDLE_ENFORCE(var);
auto tensor = var->GetMutable<framework::LoDTensor>(); auto tensor = var->GetMutable<framework::LoDTensor>();
engine_->SetInputFromCPU( engine_->SetInputFromGPU(
input, static_cast<void*>(tensor->data<void>()), input, static_cast<void*>(tensor->data<void>()),
sizeof(float) * sizeof(float) *
analysis::AccuDims(tensor->dims(), tensor->dims().size())); analysis::AccuDims(tensor->dims(), tensor->dims().size()));
...@@ -151,8 +158,8 @@ class TRTConvertValidation { ...@@ -151,8 +158,8 @@ class TRTConvertValidation {
void Execute(int batch_size) { void Execute(int batch_size) {
// Execute Fluid Op // Execute Fluid Op
PADDLE_ENFORCE_LE(batch_size, max_batch_size_); PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
platform::CPUPlace place; platform::CUDAPlace place;
platform::CPUDeviceContext ctx(place); platform::CUDADeviceContext ctx(place);
op_->Run(scope_, place); op_->Run(scope_, place);
// Execute TRT. // Execute TRT.
engine_->Execute(batch_size); engine_->Execute(batch_size);
......
...@@ -33,6 +33,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) { ...@@ -33,6 +33,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
} }
void TensorRTEngine::Execute(int batch_size) { void TensorRTEngine::Execute(int batch_size) {
freshDeviceId();
batch_size_ = batch_size; batch_size_ = batch_size;
std::vector<void *> buffers; std::vector<void *> buffers;
for (auto &buf : buffers_) { for (auto &buf : buffers_) {
...@@ -60,6 +61,7 @@ TensorRTEngine::~TensorRTEngine() { ...@@ -60,6 +61,7 @@ TensorRTEngine::~TensorRTEngine() {
} }
void TensorRTEngine::FreezeNetwork() { void TensorRTEngine::FreezeNetwork() {
freshDeviceId();
PADDLE_ENFORCE(infer_builder_ != nullptr, PADDLE_ENFORCE(infer_builder_ != nullptr,
"Call InitNetwork first to initialize network."); "Call InitNetwork first to initialize network.");
PADDLE_ENFORCE(infer_network_ != nullptr, PADDLE_ENFORCE(infer_network_ != nullptr,
...@@ -241,6 +243,13 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { ...@@ -241,6 +243,13 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; } int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
void TensorRTEngine::freshDeviceId() {
int count;
cudaGetDeviceCount(&count);
PADDLE_ENFORCE_LT(device_, count);
cudaSetDevice(device_);
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
...@@ -52,13 +53,15 @@ class TensorRTEngine : public EngineBase { ...@@ -52,13 +53,15 @@ class TensorRTEngine : public EngineBase {
}; };
TensorRTEngine(int max_batch, int max_workspace, TensorRTEngine(int max_batch, int max_workspace,
cudaStream_t* stream = nullptr, cudaStream_t* stream = nullptr, int device = 0,
nvinfer1::ILogger& logger = NaiveLogger::Global()) nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch), : max_batch_(max_batch),
max_workspace_(max_workspace), max_workspace_(max_workspace),
stream_(stream ? stream : &default_stream_), stream_(stream ? stream : &default_stream_),
logger_(logger) { logger_(logger),
cudaStreamCreate(&default_stream_); device_(device) {
freshDeviceId();
cudaStreamCreate(stream_);
} }
virtual ~TensorRTEngine(); virtual ~TensorRTEngine();
...@@ -119,6 +122,15 @@ class TensorRTEngine : public EngineBase { ...@@ -119,6 +122,15 @@ class TensorRTEngine : public EngineBase {
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); } nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
void SetRuntimeBatch(size_t batch_size); void SetRuntimeBatch(size_t batch_size);
int GetRuntimeBatch(); int GetRuntimeBatch();
int GetDevice() { return device_; }
// A pointer to CPU memory is needed of the TRT weight.
// Before TRT runs, fluid loads weight into GPU storage.
// so we need to copy the weights from GPU to CPU in our op converter.
// We use a map to store these weights for the weight memory is not released
// in advance, which affecting the construction of TRT Op.
std::unordered_map<std::string /*name*/, std::unique_ptr<framework::Tensor>>
weight_map;
private: private:
// the max batch size // the max batch size
...@@ -140,6 +152,8 @@ class TensorRTEngine : public EngineBase { ...@@ -140,6 +152,8 @@ class TensorRTEngine : public EngineBase {
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_; std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/> std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
itensor_map_; itensor_map_;
// The specific GPU id that the TensorRTEngine bounded to.
int device_;
// TensorRT related internal members // TensorRT related internal members
template <typename T> template <typename T>
...@@ -156,6 +170,10 @@ class TensorRTEngine : public EngineBase { ...@@ -156,6 +170,10 @@ class TensorRTEngine : public EngineBase {
infer_ptr<nvinfer1::INetworkDefinition> infer_network_; infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
infer_ptr<nvinfer1::ICudaEngine> infer_engine_; infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
infer_ptr<nvinfer1::IExecutionContext> infer_context_; infer_ptr<nvinfer1::IExecutionContext> infer_context_;
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
// freshDeviceId().
void freshDeviceId();
}; // class TensorRTEngine }; // class TensorRTEngine
// Add an layer__ into engine__ with args ARGS. // Add an layer__ into engine__ with args ARGS.
...@@ -188,8 +206,8 @@ class TRT_EngineManager { ...@@ -188,8 +206,8 @@ class TRT_EngineManager {
// Create or get an engine called `name` // Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream, TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream,
const std::string& name) { const std::string& name, int gpu_device = 0) {
auto* p = new TensorRTEngine(max_batch, max_workspace, stream); auto* p = new TensorRTEngine(max_batch, max_workspace, stream, gpu_device);
engines_[name].reset(p); engines_[name].reset(p);
return p; return p;
} }
......
...@@ -27,7 +27,7 @@ namespace tensorrt { ...@@ -27,7 +27,7 @@ namespace tensorrt {
class TensorRTEngineTest : public ::testing::Test { class TensorRTEngineTest : public ::testing::Test {
protected: protected:
void SetUp() override { void SetUp() override {
ASSERT_EQ(0, cudaStreamCreate(&stream_)); // ASSERT_EQ(0, cudaStreamCreate(&stream_));
engine_ = new TensorRTEngine(10, 1 << 10, &stream_); engine_ = new TensorRTEngine(10, 1 << 10, &stream_);
engine_->InitNetwork(); engine_->InitNetwork();
} }
......
...@@ -100,7 +100,8 @@ function(op_library TARGET) ...@@ -100,7 +100,8 @@ function(op_library TARGET)
endif() endif()
# Define operators that don't need pybind here. # Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op") foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}") if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()
...@@ -245,6 +246,7 @@ op_library(softmax_op DEPS softmax) ...@@ -245,6 +246,7 @@ op_library(softmax_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax)
if (WITH_GPU AND TENSORRT_FOUND) if (WITH_GPU AND TENSORRT_FOUND)
op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter) op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n")
nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op DEPS tensorrt_engine_op
analysis) analysis)
......
...@@ -17,10 +17,6 @@ ...@@ -17,10 +17,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/operators/tensorrt_engine_op.h" #include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace paddle { namespace paddle {
...@@ -29,100 +25,6 @@ DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT"); ...@@ -29,100 +25,6 @@ DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT");
namespace operators { namespace operators {
using inference::Singleton;
using inference::tensorrt::TRT_EngineManager;
using FluidDT = framework::proto::VarType_Type;
using TRT_DT = nvinfer1::DataType;
namespace {
TRT_DT FluidDataType2TRT(FluidDT type) {
switch (type) {
case FluidDT::VarType_Type_FP32:
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
return TRT_DT::kINT32;
default:
return TRT_DT::kINT32;
}
PADDLE_THROW("unkown type");
return TRT_DT::kINT32;
}
nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
PADDLE_ENFORCE_GT(shape.size(), 1UL,
"TensorRT' tensor input requires at least 2 dimensions");
PADDLE_ENFORCE_LE(shape.size(), 4UL,
"TensorRT' tensor input requires at most 4 dimensions");
PADDLE_ENFORCE_EQ(shape.size(), 4UL);
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
}
} // namespace
template <typename DeviceContext, typename T>
void TensorRTEngineKernel<DeviceContext, T>::Prepare(
const framework::ExecutionContext &context) const {
VLOG(4) << "Prepare engine";
// Get the ProgramDesc and pass to convert.
framework::proto::BlockDesc block_desc;
block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
int max_batch = context.Attr<int>("max_batch");
auto max_workspace = context.Attr<int>("max_workspace");
auto params = context.Attr<std::vector<std::string>>("parameters");
std::unordered_set<std::string> parameters;
for (const auto &param : params) {
parameters.insert(param);
}
std::vector<std::string> output_maps =
context.Attr<std::vector<std::string>>("output_name_mapping");
// TODO(Superjomn) replace this with a different stream
auto *engine = Singleton<TRT_EngineManager>::Global().Create(
max_batch, max_workspace, nullptr /*engine hold its own stream*/,
context.Attr<std::string>("engine_uniq_key"));
engine->InitNetwork();
framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
VLOG(4) << "parsed var size " << block.AllVars().size();
// Add inputs
VLOG(4) << "declare inputs";
for (auto &input : context.Inputs("Xs")) {
if (parameters.count(input)) continue;
VLOG(4) << "declare input " << input;
auto *var = block.FindVar(input);
// TensorRT engine need to create parameters. The parameter's description
// should be set in
PADDLE_ENFORCE(var, "no variable called %s", input);
PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
"TensorRT engine only takes LoDTensor as input");
auto shape = var->GetShape();
// For the special batch_size placeholder -1, drop it and pass the real
// shape of data.
// TODO(Superjomn) fix this with batch broadcast, or it can't handle
// variational batch size.
if (shape[0] == -1) {
shape[0] = FLAGS_tensorrt_engine_batch_size;
}
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(shape));
}
inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
block_desc, parameters, context.scope(), engine);
// Add outputs
for (auto &output : output_maps) {
engine->DeclareOutput(output);
}
engine->FreezeNetwork();
}
class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -150,11 +52,4 @@ namespace ops = paddle::operators; ...@@ -150,11 +52,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp, REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker); ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker);
REGISTER_OP_CPU_KERNEL(
tensorrt_engine,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, float>,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, double>,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int>,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int64_t>);
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
tensorrt_engine,
ops::TensorRTEngineKernel<paddle::platform::CUDADeviceContext, float>,
ops::TensorRTEngineKernel<paddle::platform::CUDADeviceContext, double>,
ops::TensorRTEngineKernel<paddle::platform::CUDADeviceContext, int>,
ops::TensorRTEngineKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle { namespace paddle {
...@@ -29,6 +31,35 @@ DECLARE_int32(tensorrt_engine_batch_size); ...@@ -29,6 +31,35 @@ DECLARE_int32(tensorrt_engine_batch_size);
namespace operators { namespace operators {
using FluidDT = framework::proto::VarType_Type;
using TRT_DT = nvinfer1::DataType;
namespace {
TRT_DT FluidDataType2TRT(FluidDT type) {
switch (type) {
case FluidDT::VarType_Type_FP32:
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
return TRT_DT::kINT32;
default:
return TRT_DT::kINT32;
}
PADDLE_THROW("unkown type");
return TRT_DT::kINT32;
}
nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t>& shape) {
PADDLE_ENFORCE_GT(shape.size(), 1UL,
"TensorRT' tensor input requires at least 2 dimensions");
PADDLE_ENFORCE_LE(shape.size(), 4UL,
"TensorRT' tensor input requires at most 4 dimensions");
PADDLE_ENFORCE_EQ(shape.size(), 4UL);
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
}
} // namespace
using inference::Singleton; using inference::Singleton;
using inference::tensorrt::TRT_EngineManager; using inference::tensorrt::TRT_EngineManager;
...@@ -47,7 +78,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel { ...@@ -47,7 +78,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel {
.FindVar(input0) .FindVar(input0)
->GetMutable<framework::LoDTensor>() ->GetMutable<framework::LoDTensor>()
->type()), ->type()),
platform::CPUPlace()); ctx.GetPlace());
return kt; return kt;
} }
}; };
...@@ -94,7 +125,9 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -94,7 +125,9 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// Convert output tensor from engine to fluid // Convert output tensor from engine to fluid
int output_index = 0; int output_index = 0;
VLOG(4) << "TensorRT Engine Op Outputs:";
for (const auto& y : context.Outputs("Ys")) { for (const auto& y : context.Outputs("Ys")) {
VLOG(4) << y;
// convert output and copy to fluid. // convert output and copy to fluid.
nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]); nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]);
auto dims = trt_t->getDimensions(); auto dims = trt_t->getDimensions();
...@@ -113,9 +146,11 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -113,9 +146,11 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
// TODO(Superjomn) change this float to dtype size. // TODO(Superjomn) change this float to dtype size.
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) * auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) *
FLAGS_tensorrt_engine_batch_size; FLAGS_tensorrt_engine_batch_size;
engine->GetOutputInCPU(output_maps[output_index], engine->GetOutputInGPU(
fluid_t->mutable_data<float>(platform::CPUPlace()), output_maps[output_index],
size * sizeof(float)); fluid_t->mutable_data<float>(platform::CUDAPlace(
boost::get<platform::CUDAPlace>(context.GetPlace()).device)),
size * sizeof(float));
//} else { //} else {
// engine->GetOutputInGPU( // engine->GetOutputInGPU(
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()), // y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
...@@ -128,8 +163,67 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -128,8 +163,67 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
} }
protected: protected:
// Build the engine. void Prepare(const framework::ExecutionContext& context) const {
void Prepare(const framework::ExecutionContext& context) const; VLOG(4) << "Prepare engine";
// Get the ProgramDesc and pass to convert.
framework::proto::BlockDesc block_desc;
block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
int max_batch = context.Attr<int>("max_batch");
auto max_workspace = context.Attr<int>("max_workspace");
auto params = context.Attr<std::vector<std::string>>("parameters");
std::unordered_set<std::string> parameters;
for (const auto& param : params) {
parameters.insert(param);
}
std::vector<std::string> output_maps =
context.Attr<std::vector<std::string>>("output_name_mapping");
// TODO(Superjomn) replace this with a different stream
auto* engine = Singleton<TRT_EngineManager>::Global().Create(
max_batch, max_workspace, nullptr /*engine hold its own stream*/,
context.Attr<std::string>("engine_uniq_key"),
boost::get<platform::CUDAPlace>(context.GetPlace()).device);
engine->InitNetwork();
framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
VLOG(4) << "parsed var size " << block.AllVars().size();
// Add inputs
VLOG(4) << "declare inputs";
for (auto& input : context.Inputs("Xs")) {
if (parameters.count(input)) continue;
VLOG(4) << "declare input " << input;
auto* var = block.FindVar(input);
// TensorRT engine need to create parameters. The parameter's description
// should be set in
PADDLE_ENFORCE(var, "no variable called %s", input);
PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
"TensorRT engine only takes LoDTensor as input");
auto shape = var->GetShape();
// For the special batch_size placeholder -1, drop it and pass the real
// shape of data.
// TODO(Superjomn) fix this with batch broadcast, or it can't handle
// variational batch size.
if (shape[0] == -1) {
shape[0] = FLAGS_tensorrt_engine_batch_size;
}
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(shape));
}
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlock(block_desc, parameters, context.scope(), engine);
// Add outputs
for (auto& output : output_maps) {
engine->DeclareOutput(output);
}
engine->FreezeNetwork();
}
}; };
} // namespace operators } // namespace operators
......
...@@ -23,20 +23,20 @@ limitations under the License. */ ...@@ -23,20 +23,20 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_CPU_ONLY_OP(tensorrt_engine); USE_CUDA_ONLY_OP(tensorrt_engine);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace { namespace {
void CreateCPUTensor(framework::Scope* scope, const std::string& name, void CreateCUDATensor(framework::Scope* scope, const std::string& name,
const std::vector<int64_t>& shape) { const std::vector<int64_t>& shape) {
auto* var = scope->Var(name); auto* var = scope->Var(name);
auto* tensor = var->GetMutable<framework::LoDTensor>(); auto* tensor = var->GetMutable<framework::LoDTensor>();
auto dims = framework::make_ddim(shape); auto dims = framework::make_ddim(shape);
tensor->Resize(dims); tensor->Resize(dims);
platform::CPUPlace place; platform::CUDAPlace place;
platform::CPUDeviceContext ctx(place); platform::CUDADeviceContext ctx(place);
inference::tensorrt::RandomizeTensor(tensor, place, ctx); inference::tensorrt::RandomizeTensor(tensor, place, ctx);
} }
...@@ -112,15 +112,15 @@ TEST(TensorRTEngineOp, manual) { ...@@ -112,15 +112,15 @@ TEST(TensorRTEngineOp, manual) {
LOG(INFO) << "engine_op " << engine_op.get(); LOG(INFO) << "engine_op " << engine_op.get();
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CUDAPlace place;
platform::CPUDeviceContext ctx(place); platform::CUDADeviceContext ctx(place);
// Prepare variables. // Prepare variables.
CreateCPUTensor(&scope, "x", std::vector<int64_t>({2, 4})); CreateCUDATensor(&scope, "x", std::vector<int64_t>({2, 4}));
CreateCPUTensor(&scope, "y", std::vector<int64_t>({4, 6})); CreateCUDATensor(&scope, "y", std::vector<int64_t>({4, 6}));
CreateCPUTensor(&scope, "z", std::vector<int64_t>({2, 6})); CreateCUDATensor(&scope, "z", std::vector<int64_t>({2, 6}));
CreateCPUTensor(&scope, "y0", std::vector<int64_t>({6, 8})); CreateCUDATensor(&scope, "y0", std::vector<int64_t>({6, 8}));
CreateCPUTensor(&scope, "z0", std::vector<int64_t>({2, 8})); CreateCUDATensor(&scope, "z0", std::vector<int64_t>({2, 8}));
// Execute them. // Execute them.
LOG(INFO) << "engine_op run"; LOG(INFO) << "engine_op run";
...@@ -130,8 +130,8 @@ TEST(TensorRTEngineOp, manual) { ...@@ -130,8 +130,8 @@ TEST(TensorRTEngineOp, manual) {
void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
framework::ProgramDesc program; framework::ProgramDesc program;
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CUDAPlace place;
platform::CPUDeviceContext ctx(place); platform::CUDADeviceContext ctx(place);
auto* block_ = program.Proto()->add_blocks(); auto* block_ = program.Proto()->add_blocks();
block_->set_idx(0); block_->set_idx(0);
...@@ -165,10 +165,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -165,10 +165,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
// Prepare variables. // Prepare variables.
if (!x_created) { if (!x_created) {
CreateCPUTensor(&scope, x_name, std::vector<int64_t>(x_shape)); CreateCUDATensor(&scope, x_name, std::vector<int64_t>(x_shape));
} }
CreateCPUTensor(&scope, y_name, std::vector<int64_t>(y_shape)); CreateCUDATensor(&scope, y_name, std::vector<int64_t>(y_shape));
CreateCPUTensor(&scope, z_name, std::vector<int64_t>(z_shape)); CreateCUDATensor(&scope, z_name, std::vector<int64_t>(z_shape));
// It is wired, need to copy manually. // It is wired, need to copy manually.
*block_->add_ops() = *fc->Proto(); *block_->add_ops() = *fc->Proto();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册