Unverified commit e43f7102 authored by Yuanle Liu, committed by GitHub

[Paddle-TRT] support nhwc (#49633)

* add trt_support_nhwc_pass
Parent 7de9420a
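Both the reworked TransDataLayout overload and the new trt_support_nhwc_pass below rely on a fixed axis permutation between NCHW and NHWC: {0, 2, 3, 1} maps NCHW to NHWC, and {0, 3, 1, 2} maps it back. A minimal standalone C++ sketch of that mapping (PermuteDims is a hypothetical helper; the real code uses GetAxis and CastDataLayout inside Paddle):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone helper mirroring the loop in TransDataLayout:
// dst_dim[i] = src_dim[axis[i]], with axis {0, 2, 3, 1} for NCHW -> NHWC
// and {0, 3, 1, 2} for NHWC -> NCHW.
std::vector<int64_t> PermuteDims(const std::vector<int64_t>& src_dim,
                                 const std::vector<int>& axis) {
  std::vector<int64_t> dst_dim(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) dst_dim[i] = src_dim[axis[i]];
  return dst_dim;
}

int main() {
  std::vector<int64_t> nchw{1, 4, 224, 224};
  auto nhwc = PermuteDims(nchw, {0, 2, 3, 1});  // {1, 224, 224, 4}
  auto back = PermuteDims(nhwc, {0, 3, 1, 2});  // {1, 4, 224, 224}
  for (auto d : nhwc) std::cout << d << ' ';
  std::cout << '\n';
  for (auto d : back) std::cout << d << ' ';
  std::cout << '\n';
}

These two axis vectors are exactly the "axis" attributes the pass assigns to the transpose ops it inserts.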
@@ -14,7 +14,7 @@
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
@@ -61,6 +61,18 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
platform::errors::PreconditionNotMet(
"TransDataLayout only support DataLayout transform on same place."));
TransDataLayout(kernel_type_for_var.layout(),
expected_kernel_type.layout(),
place,
in,
out);
}
void TransDataLayout(DataLayout from_layout,
DataLayout to_layout,
phi::Place place,
const phi::DenseTensor& in,
phi::DenseTensor* out) {
PADDLE_ENFORCE_EQ(
arity(in.dims()),
4,
@@ -73,8 +85,7 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
auto src_dim = in.dims();
std::vector<int64_t> dst_dim;
auto axis =
GetAxis(kernel_type_for_var.layout(), expected_kernel_type.layout());
auto axis = GetAxis(from_layout, to_layout);
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
@@ -83,10 +94,11 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
out->Resize(phi::make_ddim(dst_dim));
out->mutable_data(place, in.dtype());
framework::VisitDataType(framework::TransToProtoVarType(in.dtype()),
CastDataLayout(pool.Get(place), axis, in, out));
framework::VisitDataType(
static_cast<proto::VarType::Type>(phi::TransToProtoVarType(in.dtype())),
CastDataLayout(pool.Get(place), axis, in, out));
out->set_layout(expected_kernel_type.layout());
out->set_layout(to_layout);
}
} // namespace framework
......
@@ -14,21 +14,14 @@
#pragma once
#include <map>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
namespace paddle {
namespace framework {
class OpKernelType;
} // namespace framework
} // namespace paddle
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/phi/backends/onednn/onednn_helper.h"
#endif
@@ -60,5 +53,11 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
phi::DenseTensor* out,
const phi::Place& place);
void TransDataLayout(phi::DataLayout from_layout,
phi::DataLayout to_layout,
phi::Place place,
const phi::DenseTensor& in,
phi::DenseTensor* out);
} // namespace framework
} // namespace paddle
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/data_layout_transform.h"
#include "gtest/gtest.h"
#include "paddle/fluid/platform/bfloat16.h"
TEST(DataTransform, DataLayoutFunction) {
auto place = paddle::platform::CPUPlace();
......
@@ -141,11 +141,9 @@ if(WITH_TENSORRT)
pass_library(layernorm_shift_partition_fuse_pass inference)
pass_library(reverse_roll_fuse_pass inference)
pass_library(preln_layernorm_x_fuse_pass inference)
pass_library(trt_support_nhwc_pass inference)
pass_library(elementwise_groupnorm_act_pass inference)
pass_library(preln_elementwise_groupnorm_act_pass inference)
endif()
if(WITH_TENSORRT)
pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
endif()
......
@@ -18,7 +18,6 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/phi/common/layout.h"
@@ -30,43 +29,11 @@ namespace framework {
namespace ir {
namespace {
void TransDataLayout(DataLayout from_layout,
DataLayout to_layout,
const phi::DenseTensor &in,
phi::DenseTensor *out) {
PADDLE_ENFORCE_EQ(
arity(in.dims()),
4,
platform::errors::InvalidArgument(
"Input dimension arity only can be 4, the input dimension is %s.",
in.dims()));
auto &pool = platform::DeviceContextPool::Instance();
auto src_dim = in.dims();
std::vector<int64_t> dst_dim;
auto axis = GetAxis(from_layout, to_layout);
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
}
out->Resize(phi::make_ddim(dst_dim));
out->mutable_data(phi::CPUPlace(), in.dtype());
framework::VisitDataType(
framework::TransToProtoVarType(in.dtype()),
CastDataLayout(pool.Get(phi::CPUPlace()), axis, in, out));
out->set_layout(to_layout);
}
void InsertLayoutTransOp(ir::Graph *graph,
ir::Node *prev_node,
ir::Node *next_node,
DataLayout from_layout,
DataLayout to_layout,
phi::DataLayout from_layout,
phi::DataLayout to_layout,
framework::BlockDesc *block_desc,
std::unordered_map<ir::Node *, ir::Node *> *cache) {
auto do_insert = [&](const std::string &in_var_name,
@@ -91,7 +58,7 @@ void InsertLayoutTransOp(ir::Graph *graph,
op_out_var_desc->SetPersistable(false);
op_out_var_desc->SetDataType(prev_node->Var()->GetDataType());
auto to_shape = prev_node->Var()->GetShape();
if (from_layout == DataLayout::kNCHW) {
if (from_layout == phi::DataLayout::kNCHW) {
auto n = to_shape[0];
auto c = to_shape[1];
auto h = to_shape[2];
@@ -117,12 +84,13 @@ void InsertLayoutTransOp(ir::Graph *graph,
IR_NODE_UNLINK(prev_node, next_node);
};
if (from_layout == DataLayout::kNCHW && to_layout == DataLayout::kNHWC) {
if (from_layout == phi::DataLayout::kNCHW &&
to_layout == phi::DataLayout::kNHWC) {
auto in_var_name = prev_node->Var()->Name();
auto out_var_name = in_var_name + "_nchw_to_nhwc";
do_insert(in_var_name, out_var_name);
} else if (from_layout == DataLayout::kNHWC &&
to_layout == DataLayout::kNCHW) {
} else if (from_layout == phi::DataLayout::kNHWC &&
to_layout == phi::DataLayout::kNCHW) {
auto in_var_name = prev_node->Var()->Name();
auto out_var_name = in_var_name + "_nhwc_to_nchw";
do_insert(in_var_name, out_var_name);
@@ -135,7 +103,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::PreconditionNotMet("graph should not be nullptr."));
FusePassBase::Init("data_layout_transfer", graph);
FusePassBase::Init("conv2d_fusion_layout_transfer", graph);
auto *scope = param_scope();
// only float16 compute precision need insert transfer_layout.
@@ -170,7 +138,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
// Not support multiple block now.
std::unordered_map<ir::Node *, ir::Node *> cache;
auto op_nodes = ir::TopologySortOperations(*graph);
auto op_nodes = TopologySortOperations(*graph);
auto iter = op_nodes.cbegin();
auto *block_desc = (*iter)->Op()->Block();
@@ -186,7 +154,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
op_node->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format != "NCHW") return false;
auto filter_names = op_node->Op()->Input("Filter");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
constexpr int NHWC_ALIGNMENT = 8;
// If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
for (const auto &filter_name : filter_names) {
auto *filter_var = scope->FindLocalVar(filter_name);
@@ -195,7 +163,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
int oc = filter_tensor.dims()[0];
int ic = filter_tensor.dims()[1];
bool cutlass_can_support =
oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
oc % NHWC_ALIGNMENT == 0 && ic % NHWC_ALIGNMENT == 0;
if (!cutlass_can_support) {
return false;
}
@@ -229,8 +197,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
if (cuDNNIsValid(op_node)) {
valid_ops.insert(op_node);
auto *op_desc = op_node->Op();
auto nhwc_attr = framework::Attribute(std::string("NHWC"));
op_desc->SetAttr("data_format", nhwc_attr);
op_desc->SetAttr("data_format", std::string{"NHWC"});
if (cutlass_enable && CutlassIsValid(op_node)) {
op_desc->SetType("conv2d_fusion_cutlass");
}
@@ -244,8 +211,11 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
phi::DenseTensor temp_tensor = *filter_tensor;
filter_tensor->clear();
TransDataLayout(
DataLayout::kNCHW, DataLayout::kNHWC, temp_tensor, filter_tensor);
framework::TransDataLayout(phi::DataLayout::kNCHW,
phi::DataLayout::kNHWC,
phi::CPUPlace{},
temp_tensor,
filter_tensor);
}
auto op_inputs = op_node->inputs;
for (auto *in_var_node : op_inputs) {
@@ -290,8 +260,8 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
InsertLayoutTransOp(graph,
in_var_node,
op_node,
DataLayout::kNCHW,
DataLayout::kNHWC,
phi::DataLayout::kNCHW,
phi::DataLayout::kNHWC,
block_desc,
&cache);
}
@@ -304,8 +274,8 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
InsertLayoutTransOp(graph,
in_var_node,
op_node,
DataLayout::kNHWC,
DataLayout::kNCHW,
phi::DataLayout::kNHWC,
phi::DataLayout::kNCHW,
block_desc,
&cache);
}
......
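The CutlassIsValid gate in the pass above only keeps conv2d_fusion at NHWC when the filter's output and input channel counts are both multiples of NHWC_ALIGNMENT (8). A minimal standalone sketch of that check, with a hypothetical helper name (the real pass reads oc and ic from the filter tensor's dims):

#include <iostream>

// Same alignment constant as in the pass above.
constexpr int NHWC_ALIGNMENT = 8;

// Hypothetical standalone version of the channel-alignment gate:
// conv2d_fusion only runs at NHWC when both counts are multiples of 8.
bool CutlassCanSupport(int oc, int ic) {
  return oc % NHWC_ALIGNMENT == 0 && ic % NHWC_ALIGNMENT == 0;
}

int main() {
  std::cout << CutlassCanSupport(64, 32) << '\n';  // 1: stays NHWC-eligible
  std::cout << CutlassCanSupport(60, 32) << '\n';  // 0: the pass bails out
}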
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/trt_support_nhwc_pass.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
void DoInsertTransposeOp(ir::Graph *graph,
ir::Node *prev_node,
ir::Node *next_node,
phi::DataLayout from_layout,
phi::DataLayout to_layout,
framework::BlockDesc *block_desc,
std::unordered_map<ir::Node *, ir::Node *> *cache) {
auto do_insert = [&](const std::string &in_var_name,
const std::string &out_var_name) {
auto update_op_desc = [&](framework::OpDesc &desc,
const std::string &x_name,
const std::string &out_name,
const std::vector<int> &axis_attr) {
desc.SetType("transpose");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("axis", axis_attr);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("data_format", std::string{"AnyLayout"});
desc.SetAttr("use_quantizer", false);
desc.SetAttr("mkldnn_data_type", std::string{"float32"});
desc.Flush();
};
CHECK_NOTNULL(block_desc);
if (cache->count(prev_node) == 0) {
framework::OpDesc op_desc(block_desc);
if (from_layout == phi::DataLayout::kNCHW) {
update_op_desc(op_desc, in_var_name, out_var_name, {0, 2, 3, 1});
} else if (from_layout == phi::DataLayout::kNHWC) {
update_op_desc(op_desc, in_var_name, out_var_name, {0, 3, 1, 2});
}
auto *op_node = graph->CreateOpNode(&op_desc);
auto *op_out_var_desc = block_desc->Var(out_var_name);
op_out_var_desc->SetPersistable(false);
op_out_var_desc->SetDataType(prev_node->Var()->GetDataType());
auto to_shape = prev_node->Var()->GetShape();
if (from_layout == phi::DataLayout::kNCHW) {
auto n = to_shape[0];
auto c = to_shape[1];
auto h = to_shape[2];
auto w = to_shape[3];
op_out_var_desc->SetShape({n, h, w, c});
} else if (from_layout == phi::DataLayout::kNHWC) {
auto n = to_shape[0];
auto h = to_shape[1];
auto w = to_shape[2];
auto c = to_shape[3];
op_out_var_desc->SetShape({n, c, h, w});
}
auto *op_out_var_node = graph->CreateVarNode(op_out_var_desc);
IR_NODE_LINK_TO(op_node, op_out_var_node);
cache->insert(std::make_pair(prev_node, op_out_var_node));
}
next_node->Op()->RenameInput(prev_node->Name(),
cache->at(prev_node)->Name());
IR_NODE_LINK_TO(prev_node, cache->at(prev_node)->inputs.front());
IR_NODE_LINK_TO(cache->at(prev_node), next_node);
IR_NODE_UNLINK(prev_node, next_node);
};
if (from_layout == phi::DataLayout::kNCHW &&
to_layout == phi::DataLayout::kNHWC) {
auto in_var_name = prev_node->Var()->Name();
auto out_var_name = in_var_name + "_nchw_to_nhwc";
do_insert(in_var_name, out_var_name);
} else if (from_layout == phi::DataLayout::kNHWC &&
to_layout == phi::DataLayout::kNCHW) {
auto in_var_name = prev_node->Var()->Name();
auto out_var_name = in_var_name + "_nhwc_to_nchw";
do_insert(in_var_name, out_var_name);
}
}
bool ModelLayoutIsNHWC(const std::vector<ir::Node *> &op_nodes) {
for (auto *op_node : op_nodes) {
if (op_node->IsOp()) {
auto *op_desc = op_node->Op();
std::string data_format;
if (op_desc->HasAttr("data_format")) {
data_format = op_desc->GetAttrIfExists<std::string>("data_format");
} else if (op_desc->HasAttr("data_layout")) {
data_format = op_desc->GetAttrIfExists<std::string>("data_layout");
}
if (data_format == "NHWC") {
return true;
}
}
}
return false;
}
} // namespace
void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::PreconditionNotMet(
"During the trt_support_nhwc_pass, the graph "
"should not be null."));
FusePassBase::Init("trt_support_nhwc_pass", graph);
auto *scope = param_scope();
auto op_nodes = TopologySortOperations(*graph);
if (!ModelLayoutIsNHWC(op_nodes)) {
return;
}
//
//
// TODO(liuyuanle): Add other op if needed!
//
//
std::unordered_set<std::string> need_trans_weights{"prelu"};
std::unordered_set<std::string> not_trans_weights{"conv2d",
"pool2d",
"batch_norm",
"bilinear_interp",
"bilinear_interp_v2",
"nearest_interp",
"nearest_interp_v2"};
// Ops must run under the original layout even though it has
// data_format/data_layout attribute, otherwise it will be very troublesome!
std::unordered_set<std::string> must_original_layout_ops{"affine_channel",
"softmax"};
// OPs unrelated to layout are consistent according to the layout of input
// var!
std::unordered_set<std::string> any_layout_ops{"relu"};
//
//
// TODO(liuyuanle): Add other op if needed!
//
//
// Ops with "data_format" or "data_layout" attribute value of "NHWC"
std::unordered_set<ir::Node *> transposed_ops;
std::unordered_set<ir::Node *> vars_to_nchw;
std::unordered_map<ir::Node *, ir::Node *> cache;
// Not support multiple block now
auto iter = op_nodes.cbegin();
auto *block_desc = (*iter)->Op()->Block();
for (auto *op_node : op_nodes) {
CHECK_EQ(op_node->IsOp(), true);
auto *op_desc = op_node->Op();
std::string data_format;
if (op_desc->HasAttr("data_format")) {
data_format = op_desc->GetAttrIfExists<std::string>("data_format");
} else if (op_desc->HasAttr("data_layout")) {
data_format = op_desc->GetAttrIfExists<std::string>("data_layout");
}
bool input_shape_4{true};
auto op_inputs = op_node->inputs;
for (auto *in_var_node : op_inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) continue;
auto input_shape = in_var_node->Var()->GetShape();
input_shape_4 &= (input_shape.size() == 4);
}
if (data_format != "NHWC" || !input_shape_4 ||
any_layout_ops.count(op_desc->Type()) ||
must_original_layout_ops.count(op_desc->Type())) {
continue;
}
// Transpose NHWC --> NCHW
//
// Update current op
transposed_ops.insert(op_node);
if (op_desc->HasAttr("data_format")) {
op_desc->SetAttr("data_format", std::string{"NCHW"});
op_desc->Flush();
} else if (op_desc->HasAttr("data_layout")) {
op_desc->SetAttr("data_layout", std::string{"NCHW"});
op_desc->Flush();
}
auto UpdateOutputVars = [&] {
// Update output var of current op
auto op_outputs = op_node->outputs;
for (auto *out_var_node : op_outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue;
auto from_shape = out_var_node->Var()->GetShape();
if (from_shape.size() == 4) {
out_var_node->Var()->SetShape(
{from_shape[0], from_shape[3], from_shape[1], from_shape[2]});
vars_to_nchw.insert(out_var_node);
}
}
};
if (not_trans_weights.count(op_desc->Type())) {
UpdateOutputVars();
} else if (need_trans_weights.count(op_desc->Type())) {
std::vector<std::string> weights;
if (op_desc->Type() == "prelu") {
weights.push_back("Alpha");
}
auto UpdateWeightVars = [&] {
for (auto const &weight : weights) {
// transfer weights
auto weight_names = op_desc->Input(weight);
for (const auto &weight_name : weight_names) {
auto *weight_var = scope->FindLocalVar(weight_name);
auto *weight_tensor = weight_var->GetMutable<phi::DenseTensor>();
if (weight_tensor->dims().size() == 4) {
phi::DenseTensor temp_tensor = *weight_tensor;
weight_tensor->clear();
framework::TransDataLayout(phi::DataLayout::kNHWC,
phi::DataLayout::kNCHW,
phi::CPUPlace{},
temp_tensor,
weight_tensor);
}
}
auto op_inputs = op_node->inputs;
for (auto *in_var_node : op_inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) {
if (std::find(weight_names.cbegin(),
weight_names.cend(),
in_var_node->Var()->Name()) !=
weight_names.cend()) {
auto from_shape = in_var_node->Var()->GetShape();
in_var_node->Var()->SetShape({from_shape[0],
from_shape[2],
from_shape[3],
from_shape[1]});
}
}
}
}
};
UpdateWeightVars();
UpdateOutputVars();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"During the trt_support_nhwc_pass, %s op not supported. Please "
"update the supported op lists.",
op_desc->Type()));
}
}
auto ProcessAnyLayoutOps = [&] {
// Process any layout ops
for (auto *op_node : op_nodes) {
CHECK_EQ(op_node->IsOp(), true);
auto op_inputs = op_node->inputs;
for (auto *in_var_node : op_inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (transposed_ops.count(op_node)) continue;
if (vars_to_nchw.count(in_var_node) &&
any_layout_ops.count(op_node->Op()->Type())) {
transposed_ops.insert(op_node);
// Update output var of current op
auto op_outputs = op_node->outputs;
for (auto *out_var_node : op_outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue;
auto from_shape = out_var_node->Var()->GetShape();
if (from_shape.size() == 4) {
out_var_node->Var()->SetShape(
{from_shape[0], from_shape[3], from_shape[1], from_shape[2]});
vars_to_nchw.insert(out_var_node);
}
}
}
}
}
};
ProcessAnyLayoutOps();
auto InsertTransposeOp = [&] {
// Insert transpose op
for (auto *op_node : op_nodes) {
CHECK_EQ(op_node->IsOp(), true);
if (transposed_ops.count(op_node)) {
auto op_inputs = op_node->inputs;
for (auto *in_var_node : op_inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) continue;
if (vars_to_nchw.count(in_var_node)) continue;
DoInsertTransposeOp(graph,
in_var_node,
op_node,
phi::DataLayout::kNHWC,
phi::DataLayout::kNCHW,
block_desc,
&cache);
}
} else {
auto op_inputs = op_node->inputs;
for (auto *in_var_node : op_inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (vars_to_nchw.count(in_var_node)) {
DoInsertTransposeOp(graph,
in_var_node,
op_node,
phi::DataLayout::kNCHW,
phi::DataLayout::kNHWC,
block_desc,
&cache);
}
}
}
}
};
InsertTransposeOp();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(trt_support_nhwc_pass, paddle::framework::ir::TrtSupportNHWCPass);
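Note the caching in DoInsertTransposeOp above: one transpose node is created per source variable, and later consumers of the same variable are rewired to the cached output instead of getting a second transpose. A hypothetical standalone sketch of that pattern (names are illustrative; the real cache maps ir::Node* to ir::Node*):

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical sketch of the per-variable cache in DoInsertTransposeOp:
// the first consumer of a variable creates the transposed copy, and every
// later consumer reuses it instead of inserting another transpose op.
std::unordered_map<std::string, std::string> cache;

std::string GetOrInsertTranspose(const std::string& var) {
  auto it = cache.find(var);
  if (it == cache.end()) {
    it = cache.emplace(var, var + "_nhwc_to_nchw").first;
    std::cout << "insert transpose producing " << it->second << '\n';
  }
  return it->second;  // consumers are renamed to read this variable
}

int main() {
  GetOrInsertTranspose("conv_out");  // inserts the transpose once
  GetOrInsertTranspose("conv_out");  // cache hit: no second transpose
}

This keeps the number of inserted transposes proportional to the number of layout-boundary variables rather than the number of consumer edges.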
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class TrtSupportNHWCPass : public FusePassBase {
public:
TrtSupportNHWCPass() = default;
~TrtSupportNHWCPass() = default;
protected:
void ApplyImpl(Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -85,7 +85,8 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
const std::vector<std::string> kTRTSubgraphPasses({
"adaptive_pool2d_convert_global_pass", //
"trt_support_nhwc_pass",
"adaptive_pool2d_convert_global_pass", //
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_fill_constant_op_pass", //
......
@@ -44,10 +44,10 @@ class TrtConvertBilinearInterpV2Test(TrtLayerAutoScanTest):
for data_layout in ["NCHW", "NHWC"]:
for scale_y in [2.0, 1.0]:
for scale_x in [2.0, 1.0]:
for scale_x in [2.0]:
scale = [scale_y, scale_x]
for out_h in [32, 64, 128, 192]:
for out_w in [32, 64]:
for out_h in [32, 128]:
for out_w in [64]:
dics = [
{
"data_layout": data_layout,
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import unittest
import numpy as np
import paddle
import paddle.inference as inference
import paddle.nn as nn
import paddle.static as static
paddle.enable_static()
class SimpleNet(nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2D(
in_channels=4,
out_channels=4,
kernel_size=3,
stride=2,
padding=0,
data_format='NHWC',
)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
in_channels=4,
out_channels=2,
kernel_size=3,
stride=2,
padding=0,
data_format='NHWC',
)
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2D(
in_channels=2,
out_channels=1,
kernel_size=3,
stride=2,
padding=0,
data_format='NHWC',
)
self.relu3 = nn.ReLU()
self.flatten = nn.Flatten()
self.fc = nn.Linear(729, 10)
self.softmax = nn.Softmax()
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.conv3(x)
x = self.relu3(x)
x = self.flatten(x)
x = self.fc(x)
x = self.softmax(x)
return x
class TRTNHWCConvertTest(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.path = './inference_pass/nhwc_convert/infer_model'
def create_model(self):
image = static.data(
name='img', shape=[None, 224, 224, 4], dtype='float32'
)
predict = SimpleNet()(image)
exe = paddle.static.Executor(self.place)
exe.run(paddle.static.default_startup_program())
paddle.static.save_inference_model(self.path, [image], [predict], exe)
def create_predictor(self):
config = paddle.inference.Config(
self.path + '.pdmodel', self.path + '.pdiparams'
)
config.enable_memory_optim()
config.enable_use_gpu(100, 0)
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=3,
precision_mode=inference.PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)
predictor = inference.create_predictor(config)
return predictor
def infer(self, predictor, img):
input_names = predictor.get_input_names()
for i, name in enumerate(input_names):
input_tensor = predictor.get_input_handle(name)
input_tensor.reshape(img[i].shape)
input_tensor.copy_from_cpu(img[i].copy())
predictor.run()
results = []
output_names = predictor.get_output_names()
for i, name in enumerate(output_names):
output_tensor = predictor.get_output_handle(name)
output_data = output_tensor.copy_to_cpu()
results.append(output_data)
return results
def test_nhwc_convert(self):
self.create_model()
predictor = self.create_predictor()
img = np.ones((1, 224, 224, 4), dtype=np.float32)
result = self.infer(predictor, img=[img])
def tearDown(self):
shutil.rmtree('./inference_pass/nhwc_convert/')
if __name__ == '__main__':
unittest.main()