提交 3de45566 编写于 作者: N nhzlx

concat op && map cnn model support

上级 557be6fc
...@@ -72,7 +72,7 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -72,7 +72,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
auto trt_teller = [&](const Node* node) { auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax", {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax",
"depthwise_conv2d", "batch_norm"}); "depthwise_conv2d", "batch_norm", "concat"});
if (!node->IsFunction()) return false; if (!node->IsFunction()) return false;
const auto* func = static_cast<const Function*>(node); const auto* func = static_cast<const Function*>(node);
......
...@@ -32,6 +32,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { ...@@ -32,6 +32,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
: NativePaddlePredictor(config), config_(config) {} : NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) { bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
FLAGS_IA_enable_tensorrt_subgraph_engine = true;
VLOG(3) << "Predictor::init()"; VLOG(3) << "Predictor::init()";
FLAGS_tensorrt_max_batch_size = config_.max_batch_size; FLAGS_tensorrt_max_batch_size = config_.max_batch_size;
FLAGS_tensorrt_workspace_size = config_.workspace_size; FLAGS_tensorrt_workspace_size = config_.workspace_size;
...@@ -161,3 +162,4 @@ USE_TRT_CONVERTER(fc); ...@@ -161,3 +162,4 @@ USE_TRT_CONVERTER(fc);
USE_TRT_CONVERTER(pool2d); USE_TRT_CONVERTER(pool2d);
USE_TRT_CONVERTER(softmax); USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm); USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
# Add TRT tests # Add TRT tests
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry) DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
...@@ -18,12 +18,12 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc ...@@ -18,12 +18,12 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL)
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL)
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL)
nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
*/
class ConcatOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
std::vector<nvinfer1::ITensor*> itensors;
for (auto& input_name : op_desc.Input("X")) {
itensors.push_back(engine_->GetITensor(input_name));
}
int axis = boost::get<int>(op_desc.GetAttr("axis"));
PADDLE_ENFORCE(axis > 0,
"The axis attr of Concat op should be large than 0 for trt");
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
itensors.size());
axis = axis - 1; // Remove batch dim
layer->setAxis(axis);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(concat, ConcatOpConverter);
...@@ -79,6 +79,14 @@ class OpConverter { ...@@ -79,6 +79,14 @@ class OpConverter {
it = it =
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_tensor"); Registry<OpConverter>::Lookup("elementwise_" + op_type + "_tensor");
} }
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
}
if (op_desc.Type() == "depthwise_conv2d") {
it = Registry<OpConverter>::Lookup("conv2d");
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
} }
if (!it) { if (!it) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(concat_op, test) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("concat_x1", nvinfer1::DimsCHW(10, 3, 1));
validator.DeclInputVar("concat_x2", nvinfer1::DimsCHW(3, 3, 1));
validator.DeclInputVar("concat_x3", nvinfer1::DimsCHW(7, 3, 1));
validator.DeclOutputVar("concat_out", nvinfer1::DimsCHW(20, 3, 1));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("concat");
desc.SetInput("X", {"concat_x1", "concat_x2", "concat_x3"});
desc.SetOutput("Out", {"concat_out"});
int axis = 1;
desc.SetAttr("axis", axis);
validator.SetOp(*desc.Proto());
validator.Execute(5);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(concat);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册