From 6a7b99573789ee8c85544cdb77af416f2ef97949 Mon Sep 17 00:00:00 2001
From: hjchen2
Date: Fri, 16 Nov 2018 04:37:36 +0000
Subject: [PATCH] Refine commit message to enable ci, test=develop

---
 .../inference/tensorrt/convert/prelu_op.cc    | 34 +++++++++----------
 .../inference/tensorrt/convert/split_op.cc    |  2 +-
 paddle/fluid/inference/tensorrt/engine.h      |  2 +-
 .../tensorrt/plugin/prelu_op_plugin.cu        | 15 +-------
 .../tensorrt/plugin/prelu_op_plugin.h         |  2 --
 5 files changed, 19 insertions(+), 36 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc
index bc7cf7d80..337885e6b 100644
--- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc
@@ -26,7 +26,7 @@ class PReluOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
-    VLOG(40) << "convert fluid prelu op to tensorrt prelu layer";
+    VLOG(4) << "convert fluid prelu op to tensorrt prelu layer";
 
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
@@ -43,33 +43,31 @@ class PReluOpConverter : public OpConverter {
     PADDLE_ENFORCE_NOT_NULL(alpha_var);
     auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
 
-    platform::CPUPlace place;
-    std::unique_ptr<framework::LoDTensor> alpha_tensor_host(
+    platform::CUDAPlace place;
+    std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
         new framework::LoDTensor());
-    alpha_tensor_host->Resize(alpha_tensor->dims());
-    TensorCopySync(*alpha_tensor, place, alpha_tensor_host.get());
-    float* alpha_data = alpha_tensor_host->mutable_data<float>(place);
+    alpha_tensor_device->Resize(alpha_tensor->dims());
+    TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
+    float* alpha_data = alpha_tensor_device->mutable_data<float>(place);
 
     // Transform alpha to TensorRTEngine::Weight
     TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
                                     static_cast<void*>(alpha_data),
-                                    alpha_tensor_host->numel());
-    engine_->weight_map[op_desc.Input("Alpha")[0]] =
-        std::move(alpha_tensor_host);
-    //
+                                    alpha_tensor_device->numel());
     PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
     nvinfer1::IPluginLayer* layer =
         engine_->AddPlugin(&input, input_num, plugin);
+    // keep the alpha tensor to avoid releasing its memory
+    engine_->weight_map[op_desc.Input("Alpha")[0]] =
+        std::move(alpha_tensor_device);
 
     std::string layer_name = "prelu (Output: ";
-    for (size_t i = 0; i < output_num; i++) {
-      auto output_name = op_desc.Output("Out")[i];
-      layer->getOutput(i)->setName(output_name.c_str());
-      engine_->SetITensor(output_name, layer->getOutput(i));
-      layer_name += output_name;
-      if (test_mode) {
-        engine_->DeclareOutput(output_name);
-      }
+    auto output_name = op_desc.Output("Out")[0];
+    layer->getOutput(0)->setName(output_name.c_str());
+    engine_->SetITensor(output_name, layer->getOutput(0));
+    layer_name += output_name;
+    if (test_mode) {
+      engine_->DeclareOutput(output_name);
     }
     layer->setName((layer_name + ")").c_str());
   }
diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc
index 12179cccc..159854ab5 100644
--- a/paddle/fluid/inference/tensorrt/convert/split_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc
@@ -26,7 +26,7 @@ class SplitOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
-    VLOG(40) << "convert a fluid split op to tensorrt split layer";
+    VLOG(4) << "convert a fluid split op to tensorrt split layer";
 
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 7a920ebd1..99420f19b 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -46,7 +46,7 @@ class TensorRTEngine : public EngineBase {
       w_.values = value;
       w_.count = num_elem;
     }
-    nvinfer1::Weights& get() { return w_; }
+    const nvinfer1::Weights& get() { return w_; }
 
     std::vector<int64_t> dims;
diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
index d1ae06377..0f1ca1129 100644
--- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
@@ -109,25 +109,12 @@ nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
   return output_dims;
 }
 
-int PReluPlugin::initialize() {
-  nvinfer1::Weights &alpha = cuda_alpha_.get();
-  alpha.type = alpha_.get().type;
-  alpha.count = alpha_.get().count;
-
-  CHECK_EQ(cudaMalloc(&alpha.values, alpha.count * sizeof(float)), cudaSuccess);
-  CHECK_EQ(cudaMemcpy(const_cast<void *>(alpha.values), alpha_.get().values,
-                      alpha.count * sizeof(float), cudaMemcpyHostToDevice),
-           cudaSuccess);
-  return 0;
-}
-
 int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
                          void **outputs, void *workspace, cudaStream_t stream) {
   // input dims is CHW.
   const auto &input_dims = this->getInputDims(0);
   const float *input = reinterpret_cast<const float *>(inputs[0]);
-  const float *alpha =
-      reinterpret_cast<const float *>(cuda_alpha_.get().values);
+  const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
   float *output = reinterpret_cast<float **>(outputs)[0];
   if (mode_ == "channel") {
     PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
index 7c12705fa..aa0f865c8 100644
--- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
@@ -24,7 +24,6 @@ namespace tensorrt {
 
 class PReluPlugin : public PluginTensorRT {
   TensorRTEngine::Weight alpha_;
-  TensorRTEngine::Weight cuda_alpha_;
   std::string mode_;
 
  protected:
@@ -60,7 +59,6 @@ class PReluPlugin : public PluginTensorRT {
   int getNbOutputs() const override { return 1; }
   nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                      int nbInputDims) override;
-  int initialize() override;
   int enqueue(int batchSize, const void *const *inputs, void **outputs,
               void *workspace, cudaStream_t stream) override;
 };
-- 
GitLab
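The substance of this patch: the alpha weights are now copied to the GPU once
at conversion time (TensorCopySync into a CUDAPlace tensor), and the owning
LoDTensor is parked in engine_->weight_map so the device buffer stays alive
for the lifetime of the engine. That lets PReluPlugin drop its initialize()
override, which previously cudaMalloc'ed and cudaMemcpy'ed a second copy of
alpha, and lets enqueue() read alpha_.get().values directly. The standalone
CUDA program below is a minimal sketch of that copy-once, read-on-device
pattern, not PaddlePaddle code; every name in it (prelu_channel_wise, kC,
kHW, and so on) is invented for the example.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Channel-wise PReLU over a CHW tensor: out = x > 0 ? x : alpha[c] * x.
__global__ void prelu_channel_wise(const float *x, const float *alpha,
                                   float *out, int channels, int hw) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= channels * hw) return;
  int c = idx / hw;  // channel this element belongs to
  float v = x[idx];
  out[idx] = v > 0.f ? v : alpha[c] * v;
}

int main() {
  const int kC = 2, kHW = 4, kN = kC * kHW;
  std::vector<float> h_x = {-1, 2, -3, 4, -5, 6, -7, 8};
  std::vector<float> h_alpha = {0.1f, 0.2f};  // one negative slope per channel

  float *d_x, *d_alpha, *d_out;
  cudaMalloc(&d_x, kN * sizeof(float));
  cudaMalloc(&d_alpha, kC * sizeof(float));
  cudaMalloc(&d_out, kN * sizeof(float));

  // One-time host-to-device copy of alpha at "build" time. The device
  // buffer must outlive every launch that reads it -- the patch ensures
  // this by keeping the owning tensor in engine_->weight_map.
  cudaMemcpy(d_alpha, h_alpha.data(), kC * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_x, h_x.data(), kN * sizeof(float), cudaMemcpyHostToDevice);

  // The "enqueue" step: the kernel consumes the device-resident alpha
  // directly; no per-plugin initialize()/cudaMalloc round trip remains.
  const int threads = 256;
  prelu_channel_wise<<<(kN + threads - 1) / threads, threads>>>(
      d_x, d_alpha, d_out, kC, kHW);

  std::vector<float> h_out(kN);
  cudaMemcpy(h_out.data(), d_out, kN * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%g ", v);  // expect -0.1 2 -0.3 4 -1 6 -1.4 8
  printf("\n");

  cudaFree(d_x);
  cudaFree(d_alpha);
  cudaFree(d_out);
  return 0;
}

Built with, e.g., nvcc prelu_sketch.cu -o prelu_sketch, this prints the
channel-wise PReLU of the two-channel input, exercising the same flow the
patch adopts for the TensorRT plugin.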