提交 6a7b9957 编写于 作者: H hjchen2

Refine commit message to enable ci, test=develop

上级 413f5948
......@@ -26,7 +26,7 @@ class PReluOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(40) << "convert fluid prelu op to tensorrt prelu layer";
VLOG(4) << "convert fluid prelu op to tensorrt prelu layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
......@@ -43,33 +43,31 @@ class PReluOpConverter : public OpConverter {
PADDLE_ENFORCE_NOT_NULL(alpha_var);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
platform::CPUPlace place;
std::unique_ptr<framework::LoDTensor> alpha_tensor_host(
platform::CUDAPlace place;
std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
new framework::LoDTensor());
alpha_tensor_host->Resize(alpha_tensor->dims());
TensorCopySync(*alpha_tensor, place, alpha_tensor_host.get());
float* alpha_data = alpha_tensor_host->mutable_data<float>(place);
alpha_tensor_device->Resize(alpha_tensor->dims());
TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
float* alpha_data = alpha_tensor_device->mutable_data<float>(place);
// Transform alpha to TensorRTEngine::Weight
TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
static_cast<void*>(alpha_data),
alpha_tensor_host->numel());
engine_->weight_map[op_desc.Input("Alpha")[0]] =
std::move(alpha_tensor_host);
//
alpha_tensor_device->numel());
PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin);
// keep alpha tensor to avoid release it's memory
engine_->weight_map[op_desc.Input("Alpha")[0]] =
std::move(alpha_tensor_device);
std::string layer_name = "prelu (Output: ";
for (size_t i = 0; i < output_num; i++) {
auto output_name = op_desc.Output("Out")[i];
layer->getOutput(i)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(i));
layer_name += output_name;
if (test_mode) {
engine_->DeclareOutput(output_name);
}
auto output_name = op_desc.Output("Out")[0];
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer_name += output_name;
if (test_mode) {
engine_->DeclareOutput(output_name);
}
layer->setName((layer_name + ")").c_str());
}
......
......@@ -26,7 +26,7 @@ class SplitOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(40) << "convert a fluid split op to tensorrt split layer";
VLOG(4) << "convert a fluid split op to tensorrt split layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
......
......@@ -46,7 +46,7 @@ class TensorRTEngine : public EngineBase {
w_.values = value;
w_.count = num_elem;
}
nvinfer1::Weights& get() { return w_; }
const nvinfer1::Weights& get() { return w_; }
std::vector<int64_t> dims;
......
......@@ -109,25 +109,12 @@ nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
return output_dims;
}
int PReluPlugin::initialize() {
nvinfer1::Weights &alpha = cuda_alpha_.get();
alpha.type = alpha_.get().type;
alpha.count = alpha_.get().count;
CHECK_EQ(cudaMalloc(&alpha.values, alpha.count * sizeof(float)), cudaSuccess);
CHECK_EQ(cudaMemcpy(const_cast<void *>(alpha.values), alpha_.get().values,
alpha.count * sizeof(float), cudaMemcpyHostToDevice),
cudaSuccess);
return 0;
}
int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
void **outputs, void *workspace, cudaStream_t stream) {
// input dims is CHW.
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
const float *alpha =
reinterpret_cast<const float *>(cuda_alpha_.get().values);
const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
float *output = reinterpret_cast<float **>(outputs)[0];
if (mode_ == "channel") {
PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
......
......@@ -24,7 +24,6 @@ namespace tensorrt {
class PReluPlugin : public PluginTensorRT {
TensorRTEngine::Weight alpha_;
TensorRTEngine::Weight cuda_alpha_;
std::string mode_;
protected:
......@@ -60,7 +59,6 @@ class PReluPlugin : public PluginTensorRT {
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int initialize() override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册