From 4d49e61ab8100cd0739d7ca41cc88e657f40482d Mon Sep 17 00:00:00 2001 From: nhzlx Date: Tue, 24 Jul 2018 11:42:44 +0000 Subject: [PATCH] fix comments --- .../inference/tensorrt/convert/CMakeLists.txt | 4 +- .../fluid/inference/tensorrt/convert/fc_op.cc | 3 +- .../inference/tensorrt/convert/mul_op.cc | 3 +- .../inference/tensorrt/convert/test_fc_op.cc | 2 +- .../inference/tensorrt/convert/test_mul_op.cc | 49 +++++++++++++++++++ .../operators/tensorrt_engine_op_test.cc | 6 +-- 6 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/test_mul_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 36bd3904e0..748f5a084e 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,12 +1,14 @@ # Add TRT tests nv_library(tensorrt_converter - SRCS conv2d_op.cc fc_op.cc + SRCS mul_op.cc conv2d_op.cc fc_op.cc DEPS tensorrt_engine mul_op) nv_test(test_op_converter SRCS test_op_converter.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) +nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index 498a3547de..409efac679 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -116,5 +116,4 @@ class FcOpConverter : public OpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(mul, FcOpConverter); -USE_OP(mul); +REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc index 9623ac27e2..3c34295736 100644 --- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -50,5 +50,4 @@ class MulOpConverter : public OpConverter { } // namespace paddle USE_OP(mul); -// TODO(xingzhaolong): change the name to mul then -REGISTER_TRT_OP_CONVERTER(mul_temp, MulOpConverter); +REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc index bd90e0ee7a..081f4d6059 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_fc_op.cc @@ -24,7 +24,6 @@ TEST(fc_op, test) { std::unordered_set parameters({"mul-Y"}); framework::Scope scope; TRTConvertValidation validator(10, parameters, scope, 1000); - validator.DeclInputVar("mul-X", nvinfer1::Dims4(1, 10, 1, 1)); validator.DeclParamVar("mul-Y", nvinfer1::Dims2(10, 2)); // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2)); @@ -45,3 +44,4 @@ TEST(fc_op, test) { } // namespace tensorrt } // namespace inference } // namespace paddle +USE_OP(mul); diff --git a/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc b/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc new file mode 100644 index 0000000000..674f37f2fd --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_mul_op.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(MulOpConverter, main) { + framework::Scope scope; + std::unordered_set parameters; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6)); + validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10)); + validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("mul"); + desc.SetInput("X", {"mul-X"}); + desc.SetInput("Y", {"mul-Y"}); + desc.SetOutput("Out", {"mul-Out"}); + + LOG(INFO) << "set OP"; + validator.SetOp(*desc.Proto()); + LOG(INFO) << "execute"; + + validator.Execute(1); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(mul); diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc index 9b46fbb72b..7cb1e47a15 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -66,14 +66,14 @@ TEST(TensorRTEngineOp, manual) { framework::BlockDesc block_desc(&program, block_); LOG(INFO) << "create fc op"; auto* fc0 = block_desc.AppendOp(); - fc0->SetType("mul"); + fc0->SetType("fc"); fc0->SetInput("X", std::vector({"x"})); // 4 x 1 x 1 fc0->SetInput("Y", std::vector({"y"})); // 4 x 6 fc0->SetOutput("Out", std::vector({"z"})); // 6 x 1 x 1 LOG(INFO) << "create fc op"; auto* fc1 = block_desc.AppendOp(); - fc1->SetType("mul"); + fc1->SetType("fc"); fc1->SetInput("X", std::vector({"z"})); fc1->SetInput("Y", std::vector({"y0"})); // 6 x 8 fc1->SetOutput("Out", std::vector({"z0"})); // 8 x 1 x 1 @@ -208,4 +208,4 @@ TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); } } // namespace operators } // namespace paddle -USE_TRT_CONVERTER(mul) +USE_TRT_CONVERTER(fc) -- GitLab