diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc
index bb603efaf30bb72d74b5583abc45d01a16c076a3..409efac6799b6fb8d27a1343a55e7a508760868f 100644
--- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc
@@ -32,11 +32,11 @@ void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
   for (int h = 0; h < shape.h(); ++h) {
     for (int w = 0; w < shape.w(); ++w) {
       odata[h * ostrides.h() + w * ostrides.w()] =
-          idata[h * ostrides.h() + w * ostrides.w()];
+          idata[h * istrides.h() + w * istrides.w()];
     }
   }
 }
-
+// indata c * k
 // Reorder the data layout from CK to KC.
 void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
                    TensorRTEngine::Weight* oweights) {
@@ -79,9 +79,8 @@ class FcOpConverter : public OpConverter {
 
     framework::LoDTensor tmp;
     tmp.Resize(Y_t->dims());
-    memcpy(tmp.mutable_data<float>(platform::CPUPlace()), Y_t->data<float>(),
-           Y_t->dims()[0] * Y_t->dims()[1]);
-
+    memcpy(tmp.mutable_data<float>(platform::CPUPlace()), weight_data,
+           Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
     TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                   static_cast<void*>(weight_data),
                                   Y_t->memory_size() / sizeof(float)};
@@ -93,7 +92,7 @@ class FcOpConverter : public OpConverter {
 
     // The data layout of TRT FC layer's weight is different from fluid's FC,
     // need to reorder the elements.
-    ReorderCKtoKC(tmp_weight, &weight);
+    ReorderCKtoKC(weight, &tmp_weight);
 
     // Currently, the framework can only handle one fluid op -> one TRT layer,
     // but fc fuses `mul` and `bias` (2 fluid ops), so here is a trick, just
@@ -103,7 +102,7 @@ class FcOpConverter : public OpConverter {
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected,
                                        *const_cast<nvinfer1::ITensor*>(X),
-                                       n_output, weight.get(), bias.get());
+                                       n_output, tmp_weight.get(), bias.get());
 
     auto output_name = op_desc.Output("Out").front();
     engine_->SetITensor(output_name, layer->getOutput(0));
@@ -118,4 +117,3 @@ class FcOpConverter : public OpConverter {
 }  // namespace paddle
 
 REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
-USE_OP(mul);
diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
index 0a02a7bebf9efbd0555707e6cfa701ef1e7d9659..7dabfd9f6a9a8cfbdd1d9a66541180d3499b7bdc 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
@@ -37,7 +37,7 @@ TEST(ReluOpConverter, main) {
   validator.SetOp(*desc.Proto());
 
   LOG(INFO) << "execute";
-  validator.Execute(10);
+  validator.Execute(1);
 }
 
 }  // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
index a30253072ac581ceca85ca10151a176f87a7cb39..081f4d605975f1408d4d8a8ed3108c04d837a4de 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
@@ -23,11 +23,11 @@ namespace tensorrt {
 TEST(fc_op, test) {
   std::unordered_set<std::string> parameters({"mul-Y"});
   framework::Scope scope;
-  TRTConvertValidation validator(20, parameters, scope, 1000);
-
-  validator.DeclInputVar("mul-X", nvinfer1::Dims4(8, 3, 1, 1));
-  validator.DeclParamVar("mul-Y", nvinfer1::Dims2(3, 2));
-  validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(8, 2));
+  TRTConvertValidation validator(10, parameters, scope, 1000);
+  validator.DeclInputVar("mul-X", nvinfer1::Dims4(1, 10, 1, 1));
+  validator.DeclParamVar("mul-Y", nvinfer1::Dims2(10, 2));
+  // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
+  validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(1, 2));
 
   // Prepare Op description
   framework::OpDesc desc;
@@ -38,9 +38,10 @@ TEST(fc_op, test) {
 
   validator.SetOp(*desc.Proto());
 
-  validator.Execute(10);
+  validator.Execute(1);
 }
 
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+USE_OP(mul);
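Note: the first hunk fixes a genuine indexing bug, Reorder2 previously read the input through the output strides (ostrides) instead of istrides. ReorderCKtoKC exists because fluid's mul stores the fc weight as a C x K matrix while TensorRT's FullyConnected layer expects K x C. A minimal standalone sketch of that transpose over a flat row-major buffer (the function name and std::vector layout are illustrative, not the patch's actual Weight type):

    #include <vector>

    // Transpose a row-major C x K matrix into the K x C layout TRT FC expects.
    std::vector<float> ReorderCKtoKCSketch(const std::vector<float>& ck, int c, int k) {
      std::vector<float> kc(ck.size());
      for (int i = 0; i < c; ++i) {
        for (int j = 0; j < k; ++j) {
          kc[j * c + i] = ck[i * k + j];  // element (i, j) of CK lands at (j, i) of KC
        }
      }
      return kc;
    }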
diff --git a/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc b/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
index 1ce1130e5d660d717a1262a1fbdb4b620462c0b3..674f37f2fdddf013a8f6f4671debbc19c3322423 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
@@ -39,7 +39,7 @@ TEST(MulOpConverter, main) {
   validator.SetOp(*desc.Proto());
 
   LOG(INFO) << "execute";
-  validator.Execute(10);
+  validator.Execute(1);
 }
 
 }  // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
index 3b1f531adc5d756259df1c350f7f44bf71ee1f93..f14885b238134cdf38a278cd8a0734947bcacfe0 100644
--- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h
+++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
@@ -39,7 +39,7 @@ namespace tensorrt {
 float random(float low, float high) {
   static std::random_device rd;
   static std::mt19937 mt(rd());
-  std::uniform_real_distribution<double> dist(1.0, 10.0);
+  std::uniform_real_distribution<double> dist(low, high);
   return dist(mt);
 }
 
@@ -49,6 +49,7 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
   size_t num_elements = analysis::AccuDims(dims, dims.size());
   PADDLE_ENFORCE_GT(num_elements, 0);
   auto* data = tensor->mutable_data<float>(place);
+
   for (size_t i = 0; i < num_elements; i++) {
     *(data + i) = random(0., 1.);
   }
@@ -68,7 +69,7 @@ class TRTConvertValidation {
                       int workspace_size = 1 << 10)
       : parameters_(parameters), scope_(scope) {
     // create engine.
-    engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
+    engine_.reset(new TensorRTEngine(batch_size, workspace_size, &stream_));
     engine_->InitNetwork();
 
     PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
@@ -138,12 +139,11 @@ class TRTConvertValidation {
     cudaStreamSynchronize(*engine_->stream());
 
     ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
-    const size_t output_space_size = 200;
+    const size_t output_space_size = 2000;
    for (const auto& output : op_desc_->OutputArgumentNames()) {
       std::vector<float> fluid_out;
       std::vector<float> trt_out(output_space_size);
-      engine_->GetOutputInCPU(output, &trt_out[0],
-                              output_space_size * sizeof(float));
+      engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);
       cudaStreamSynchronize(*engine_->stream());
 
       auto* var = scope_.FindVar(output);
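Note: TRTConvertValidation now forwards its batch_size and workspace_size arguments to the engine instead of the hard-coded TensorRTEngine(10, 1 << 10, ...), and the host-side output scratch buffer grows from 200 to 2000 elements. The analysis::AccuDims helper used here and in the engine is just the product of the leading dimension extents; a minimal equivalent, assuming plain int64 extents (the sketch name is mine):

    #include <cstddef>
    #include <cstdint>

    // Product of the first n extents: the element count of a shape.
    int64_t AccuDimsSketch(const int64_t* dims, size_t n) {
      int64_t prod = 1;
      for (size_t i = 0; i < n; ++i) prod *= dims[i];
      return prod;
    }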
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index fefec0df6d03669a294ce9643b666d7416593708..b821c3d0bf425c46fae634fbf53f7ee63100ca5c 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -1,7 +1,7 @@
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use
+this file except in compliance with the License.
 You may obtain a copy of the License at
 
 http://www.apache.org/licenses/LICENSE-2.0
@@ -26,6 +26,8 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {
 
+int TensorRTEngine::runtime_batch_ = 1;
+
 void TensorRTEngine::Build(const DescType &paddle_model) {
   PADDLE_ENFORCE(false, "not implemented");
 }
@@ -42,6 +44,7 @@ void TensorRTEngine::Execute(int batch_size) {
   PADDLE_ENFORCE_NOT_NULL(stream_);
   infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
   cudaStreamSynchronize(*stream_);
+  SetRuntimeBatch(batch_size);
 }
 
 TensorRTEngine::~TensorRTEngine() {
@@ -80,17 +83,17 @@ void TensorRTEngine::FreezeNetwork() {
       auto dims = infer_engine_->getBindingDimensions(slot_offset);
       item.second = kDataTypeSize[static_cast<int>(
                         infer_engine_->getBindingDataType(slot_offset))] *
-                    analysis::AccuDims(dims.d, dims.nbDims);
+                    analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
       PADDLE_ENFORCE_GT(item.second, 0);
     }
 
     auto &buf = buffer(item.first);
     buf.max_size = item.second * max_batch_;
     CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
-    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, buf.max_size));
-    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
-    // buf.size will changed in the runtime.
+
+    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
     buf.size = 0;
+    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
     buf.device = DeviceType::GPU;
   }
 }
@@ -105,7 +108,7 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
   auto *input = infer_network_->addInput(name.c_str(), dtype, dims);
   PADDLE_ENFORCE(input, "infer network add input %s failed", name);
   buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
-                        analysis::AccuDims(dims.d, dims.nbDims);
+                        analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
   PADDLE_ENFORCE(input->isNetworkInput());
   TensorRTEngine::SetITensor(name, input);
   return input;
@@ -149,35 +152,42 @@ void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
 
 void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
                                     size_t max_size) {
   // determine data size
+  auto *output = TensorRTEngine::GetITensor(name);
+  nvinfer1::Dims dims = output->getDimensions();
+  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
+  size_t dst_size = dim_size * runtime_batch_ *
+                    kDataTypeSize[static_cast<int>(output->getType())];
+
   auto it = buffer_sizes_.find(name);
   PADDLE_ENFORCE(it != buffer_sizes_.end());
   PADDLE_ENFORCE_GT(it->second, 0);
-  PADDLE_ENFORCE_GE(max_size, it->second);
+  PADDLE_ENFORCE_LE(dst_size, it->second);
+  PADDLE_ENFORCE_GE(max_size, dst_size);
   auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
-  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
+  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
                                     cudaMemcpyDeviceToDevice, *stream_),
                     0);
 }
 
 void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
                                     size_t max_size) {
-  VLOG(4) << "get output in cpu";
-  auto &buf = buffer(name);
-
-  // Update needed buffer size.
-  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
-  auto dims = infer_engine_->getBindingDimensions(slot_offset);
-  buf.size = kDataTypeSize[static_cast<int>(
-                 infer_engine_->getBindingDataType(slot_offset))] *
-             analysis::AccuDims(dims.d, dims.nbDims);
-  PADDLE_ENFORCE_LE(buf.size, buf.max_size);
   // determine data size
+
+  auto *output = TensorRTEngine::GetITensor(name);
+  nvinfer1::Dims dims = output->getDimensions();
+  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
+  size_t dst_size = dim_size * runtime_batch_ *
+                    kDataTypeSize[static_cast<int>(output->getType())];
+  auto it = buffer_sizes_.find(name);
+  PADDLE_ENFORCE(it != buffer_sizes_.end());
+  PADDLE_ENFORCE_GT(it->second, 0);
+  PADDLE_ENFORCE_LE(dst_size, it->second);
+  PADDLE_ENFORCE_GE(max_size, dst_size);
+  auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
-  // DEBUG
-  memset(dst, 0, buf.size);
-  PADDLE_ENFORCE_EQ(
-      0, cudaMemcpy(dst, buf.buffer, buf.size, cudaMemcpyDeviceToHost));
+  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
+                                       cudaMemcpyDeviceToHost, *stream_));
 }
 
 Buffer &TensorRTEngine::buffer(const std::string &name) {
@@ -225,6 +235,12 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
   return itensor_map_[name];
 }
 
+void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
+  runtime_batch_ = batch_size;
+}
+
+int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
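Note: the engine.cc changes separate how much to allocate from how much to copy. Device buffers are sized once in FreezeNetwork() for the maximum batch, while GetOutputInGPU/GetOutputInCPU copy only what the most recent Execute(batch_size) produced, tracked through runtime_batch_. A sketch of the two sizing rules (helper names are mine, not the patch's):

    #include <cstddef>

    // Allocation happens once, pessimistically, for the largest batch allowed.
    size_t AllocBytes(size_t per_sample_bytes, size_t max_batch) {
      return per_sample_bytes * max_batch;
    }

    // Copies after Execute() cover only the batch that actually ran.
    size_t CopyBytes(size_t per_sample_bytes, size_t runtime_batch) {
      return per_sample_bytes * runtime_batch;
    }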
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 7064d333f6db754f88c0ac6956a9527a48bf866c..694468c419c20089de1cdecff1a903ad0cc6e99f 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -117,10 +117,14 @@ class TensorRTEngine : public EngineBase {
   nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
   nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
+  void SetRuntimeBatch(size_t batch_size);
+  int GetRuntimeBatch();
 
  private:
   // the max batch size
   int max_batch_;
+  // the runtime batch size
+  static int runtime_batch_;
   // the max memory size the engine uses
   int max_workspace_;
diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc
index fca3488008ed83418b5e28b8af42d8019aaaa2a4..f8732e51b66bdc78aa35d06ba9651f1942a74b01 100644
--- a/paddle/fluid/inference/tensorrt/test_engine.cc
+++ b/paddle/fluid/inference/tensorrt/test_engine.cc
@@ -28,7 +28,7 @@ class TensorRTEngineTest : public ::testing::Test {
  protected:
   void SetUp() override {
     ASSERT_EQ(0, cudaStreamCreate(&stream_));
-    engine_ = new TensorRTEngine(1, 1 << 10, &stream_);
+    engine_ = new TensorRTEngine(10, 1 << 10, &stream_);
     engine_->InitNetwork();
   }
@@ -71,7 +71,7 @@ TEST_F(TensorRTEngineTest, add_layer) {
 
   LOG(INFO) << "to get output";
   float y_cpu;
-  engine_->GetOutputInCPU("y", &y_cpu, sizeof(float));
+  engine_->GetOutputInCPU("y", &y_cpu, 1 * sizeof(float));
 
   LOG(INFO) << "to checkout output";
   ASSERT_EQ(y_cpu, x_v * 2 + 3);
@@ -103,15 +103,49 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
 
   LOG(INFO) << "to get output";
   float y_cpu[2] = {-1., -1.};
+
   auto dims = engine_->GetITensor("y")->getDimensions();
   ASSERT_EQ(dims.nbDims, 3);
   ASSERT_EQ(dims.d[0], 2);
   ASSERT_EQ(dims.d[1], 1);
-  engine_->GetOutputInCPU("y", &y_cpu[0], sizeof(float) * 2);
+  engine_->GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float));
+
   ASSERT_EQ(y_cpu[0], 4.5);
   ASSERT_EQ(y_cpu[1], 14.5);
 }
 
+TEST_F(TensorRTEngineTest, test_conv2d_temp) {
+  // Weight in CPU memory.
+  float raw_weight[9] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
+  float raw_bias[1] = {0};
+
+  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 9);
+  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 1);
+  auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
+                                  nvinfer1::Dims3{1, 3, 3});
+  auto* conv_layer =
+      TRT_ENGINE_ADD_LAYER(engine_, Convolution, *x, 1, nvinfer1::DimsHW{3, 3},
+                           weight.get(), bias.get());
+  PADDLE_ENFORCE(conv_layer != nullptr);
+  conv_layer->setStride(nvinfer1::DimsHW{1, 1});
+  conv_layer->setPadding(nvinfer1::DimsHW{1, 1});
+
+  engine_->DeclareOutput(conv_layer, 0, "y");
+  engine_->FreezeNetwork();
+  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);
+
+  float x_v[18] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+                   1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
+  engine_->SetInputFromCPU("x", reinterpret_cast<void*>(&x_v),
+                           18 * sizeof(float));
+  engine_->Execute(2);
+
+  LOG(INFO) << "to get output";
+  float* y_cpu = new float[18];
+  engine_->GetOutputInCPU("y", &y_cpu[0], 18 * sizeof(float));
+  ASSERT_EQ(y_cpu[0], 4.0);
+  ASSERT_EQ(y_cpu[1], 6.0);
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
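Note: the expected values in test_conv2d_temp can be checked by hand. An all-ones 3x3 input convolved with an all-ones 3x3 kernel at stride 1 and zero padding 1 just counts the in-bounds kernel taps, so a corner output sees a 2x2 window (4.0) and an edge output a 2x3 window (6.0). A TensorRT-free CPU cross-check of those two numbers:

    #include <cstdio>

    int main() {
      float out[3][3];
      for (int i = 0; i < 3; ++i) {
        for (int j = 0; j < 3; ++j) {
          // The output equals the number of kernel taps inside the image,
          // because both the input and the 3x3 kernel are all ones.
          float sum = 0.f;
          for (int di = -1; di <= 1; ++di) {
            for (int dj = -1; dj <= 1; ++dj) {
              int r = i + di, c = j + dj;
              if (r >= 0 && r < 3 && c >= 0 && c < 3) sum += 1.f;
            }
          }
          out[i][j] = sum;
        }
      }
      std::printf("%.1f %.1f\n", out[0][0], out[0][1]);  // prints 4.0 6.0
      return 0;
    }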
diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc
index 43672d6db92a981f0fbe6e8f7079dafc6ae4052e..db641a4bc2c637e0babee6b6bc6e67b068759ff5 100644
--- a/paddle/fluid/operators/tensorrt_engine_op.cc
+++ b/paddle/fluid/operators/tensorrt_engine_op.cc
@@ -55,13 +55,14 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
                  "TensorRT' tensor input requires at least 2 dimensions");
   PADDLE_ENFORCE_LE(shape.size(), 4UL,
                     "TensorRT' tensor input requires at most 4 dimensions");
+
   switch (shape.size()) {
     case 2:
-      return nvinfer1::Dims2(shape[0], shape[1]);
+      return nvinfer1::Dims2(1, shape[1]);
     case 3:
-      return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
+      return nvinfer1::Dims3(1, shape[1], shape[2]);
     case 4:
-      return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]);
+      return nvinfer1::Dims4(1, shape[1], shape[2], shape[3]);
     default:
       return nvinfer1::Dims();
   }
diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h
index a332d70030ffa6a033f6b2b33487a4fd279b7016..32d10fd8a5687ebaae1d7d75af531cbc45ef4245 100644
--- a/paddle/fluid/operators/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt_engine_op.h
@@ -93,13 +93,15 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
       auto* fluid_v = context.scope().FindVar(y);
       PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
       auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
-      auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
+
       fluid_t->Resize(framework::make_ddim(ddim));
 
       // TODO(Superjomn) find some way to determine which device to output the
       // tensor.
       // if (platform::is_cpu_place(fluid_t->place())) {
       // TODO(Superjomn) change this float to dtype size.
+      auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) *
+                  FLAGS_tensorrt_engine_batch_size;
       engine->GetOutputInCPU(y,
                              fluid_t->mutable_data<float>(platform::CPUPlace()),
                              size * sizeof(float));
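Note: Vec2TRT_Dims now pins the leading dimension to 1 because a TensorRT binding describes a single sample; the batch count is supplied separately at Execute(batch_size) time (here via FLAGS_tensorrt_engine_batch_size). A tiny sketch of the new mapping (the function name is mine, and plain ints stand in for the nvinfer1 dim types):

    #include <cstdio>
    #include <vector>

    // Keep the rank, but force the leading (batch) extent to 1.
    std::vector<int> ToTrtDimsSketch(const std::vector<int>& shape) {
      std::vector<int> d(shape);
      d[0] = 1;
      return d;
    }

    int main() {
      for (int v : ToTrtDimsSketch({8, 3, 2, 2})) std::printf("%d ", v);  // 1 3 2 2
      std::printf("\n");
      return 0;
    }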
diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc
index 82a16361e40513aeaf6f510e450f58989369fcdb..7cb1e47a1516c32fb31a7818e7203b498e31e431 100644
--- a/paddle/fluid/operators/tensorrt_engine_op_test.cc
+++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc
@@ -64,36 +64,37 @@ TEST(TensorRTEngineOp, manual) {
   LOG(INFO) << "create block desc";
   framework::BlockDesc block_desc(&program, block_);
-  LOG(INFO) << "create mul op";
-  auto* mul = block_desc.AppendOp();
-  mul->SetType("mul");
-  mul->SetInput("X", std::vector<std::string>({"x"}));     // 2 x 4
-  mul->SetInput("Y", std::vector<std::string>({"y"}));     // 4 x 6
-  mul->SetOutput("Out", std::vector<std::string>({"z"}));  // 2 x 6
+  LOG(INFO) << "create fc op";
+  auto* fc0 = block_desc.AppendOp();
+  fc0->SetType("fc");
+  fc0->SetInput("X", std::vector<std::string>({"x"}));     // 4 x 1 x 1
+  fc0->SetInput("Y", std::vector<std::string>({"y"}));     // 4 x 6
+  fc0->SetOutput("Out", std::vector<std::string>({"z"}));  // 6 x 1 x 1
 
   LOG(INFO) << "create fc op";
-  auto* fc = block_desc.AppendOp();
-  fc->SetType("mul");
-  fc->SetInput("X", std::vector<std::string>({"z"}));
-  fc->SetInput("Y", std::vector<std::string>({"y0"}));     // 6 x 8
-  fc->SetOutput("Out", std::vector<std::string>({"z0"}));  // 2 x 8
+  auto* fc1 = block_desc.AppendOp();
+  fc1->SetType("fc");
+  fc1->SetInput("X", std::vector<std::string>({"z"}));
+  fc1->SetInput("Y", std::vector<std::string>({"y0"}));     // 6 x 8
+  fc1->SetOutput("Out", std::vector<std::string>({"z0"}));  // 8 x 1 x 1
 
   // Set inputs' variable shape in BlockDesc
-  AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4}));
+  // the batch size is 2, so the dims of 'x' is {2, 4, 1, 1}
+  AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4, 1, 1}));
   AddTensorToBlockDesc(block_, "y", std::vector<int64_t>({4, 6}));
   AddTensorToBlockDesc(block_, "y0", std::vector<int64_t>({6, 8}));
   AddTensorToBlockDesc(block_, "z", std::vector<int64_t>({2, 6}));
 
   // It is wired, need to copy manually.
-  *block_->add_ops() = *mul->Proto();
-  *block_->add_ops() = *fc->Proto();
+  *block_->add_ops() = *fc0->Proto();
+  *block_->add_ops() = *fc1->Proto();
   ASSERT_EQ(block_->ops_size(), 2);
 
   LOG(INFO) << "create tensorrt desc";
   framework::OpDesc engine_op_desc(nullptr);
   engine_op_desc.SetType("tensorrt_engine");
-  engine_op_desc.SetInput("Xs", std::vector<std::string>({"x", "y", "y0"}));
+  engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
   engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
 
   SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
                        block_->SerializeAsString());
@@ -207,5 +208,4 @@ TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); }
 
 }  // namespace operators
 }  // namespace paddle
 
-USE_TRT_CONVERTER(mul)
 USE_TRT_CONVERTER(fc)