From e43f710212a2b8f46a203ddb51fc88d90129888d Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Mon, 16 Jan 2023 16:20:08 +0800 Subject: [PATCH] [Paddle-TRT] support nhwc (#49633) * add trt_support_nhwc_pass --- .../fluid/framework/data_layout_transform.cc | 24 +- .../fluid/framework/data_layout_transform.h | 21 +- .../framework/data_layout_transform_test.cc | 1 + paddle/fluid/framework/ir/CMakeLists.txt | 4 +- .../ir/conv2d_fusion_layout_transfer_pass.cc | 72 +--- .../framework/ir/trt_support_nhwc_pass.cc | 365 ++++++++++++++++++ .../framework/ir/trt_support_nhwc_pass.h | 35 ++ .../inference/api/paddle_pass_builder.cc | 3 +- .../test_trt_convert_bilinear_interp_v2.py | 6 +- .../inference/test_trt_support_nhwc_pass.py | 132 +++++++ 10 files changed, 588 insertions(+), 75 deletions(-) create mode 100644 paddle/fluid/framework/ir/trt_support_nhwc_pass.cc create mode 100644 paddle/fluid/framework/ir/trt_support_nhwc_pass.h create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_support_nhwc_pass.py diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 73ce635f57c..3b7d5fb4d8c 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/framework/data_layout_transform.h" -#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -61,6 +61,18 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var, platform::errors::PreconditionNotMet( "TransDataLayout only support DataLayout transform on same place.")); + TransDataLayout(kernel_type_for_var.layout(), + expected_kernel_type.layout(), + place, + in, + out); +} + +void TransDataLayout(DataLayout from_layout, + DataLayout to_layout, + phi::Place place, + const phi::DenseTensor& in, + phi::DenseTensor* out) { PADDLE_ENFORCE_EQ( arity(in.dims()), 4, @@ -73,8 +85,7 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var, auto src_dim = in.dims(); std::vector dst_dim; - auto axis = - GetAxis(kernel_type_for_var.layout(), expected_kernel_type.layout()); + auto axis = GetAxis(from_layout, to_layout); dst_dim.resize(axis.size()); for (size_t i = 0; i < axis.size(); i++) { dst_dim[i] = src_dim[axis[i]]; @@ -83,10 +94,11 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var, out->Resize(phi::make_ddim(dst_dim)); out->mutable_data(place, in.dtype()); - framework::VisitDataType(framework::TransToProtoVarType(in.dtype()), - CastDataLayout(pool.Get(place), axis, in, out)); + framework::VisitDataType( + static_cast(phi::TransToProtoVarType(in.dtype())), + CastDataLayout(pool.Get(place), axis, in, out)); - out->set_layout(expected_kernel_type.layout()); + out->set_layout(to_layout); } } // namespace framework diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h index 3bc55b8ad86..2881953c810 100644 --- a/paddle/fluid/framework/data_layout_transform.h +++ b/paddle/fluid/framework/data_layout_transform.h @@ -14,21 +14,14 @@ #pragma once -#include -#include #include -#include "paddle/fluid/framework/op_kernel_type.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_factory.h" #include 
"paddle/phi/kernels/funcs/data_layout_transform.h" -namespace paddle { -namespace framework { -class OpKernelType; -} // namespace framework -} // namespace paddle - #ifdef PADDLE_WITH_MKLDNN #include "paddle/phi/backends/onednn/onednn_helper.h" #endif @@ -60,5 +53,11 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var, phi::DenseTensor* out, const phi::Place& place); +void TransDataLayout(phi::DataLayout from_layout, + phi::DataLayout to_layout, + phi::Place place, + const phi::DenseTensor& in, + phi::DenseTensor* out); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_layout_transform_test.cc b/paddle/fluid/framework/data_layout_transform_test.cc index 880fa5b057d..b57cc54fb04 100644 --- a/paddle/fluid/framework/data_layout_transform_test.cc +++ b/paddle/fluid/framework/data_layout_transform_test.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/data_layout_transform.h" #include "gtest/gtest.h" +#include "paddle/fluid/platform/bfloat16.h" TEST(DataTransform, DataLayoutFunction) { auto place = paddle::platform::CPUPlace(); diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 476881f0725..6eda1f4b23f 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -141,11 +141,9 @@ if(WITH_TENSORRT) pass_library(layernorm_shift_partition_fuse_pass inference) pass_library(reverse_roll_fuse_pass inference) pass_library(preln_layernorm_x_fuse_pass inference) + pass_library(trt_support_nhwc_pass inference) pass_library(elementwise_groupnorm_act_pass inference) pass_library(preln_elementwise_groupnorm_act_pass inference) -endif() - -if(WITH_TENSORRT) pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference) pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference) endif() diff --git a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc index 8e7d435cb5a..4058547e70c 100644 --- a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc +++ b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc @@ -18,7 +18,6 @@ #include #include -#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/phi/common/layout.h" @@ -30,43 +29,11 @@ namespace framework { namespace ir { namespace { -void TransDataLayout(DataLayout from_layout, - DataLayout to_layout, - const phi::DenseTensor &in, - phi::DenseTensor *out) { - PADDLE_ENFORCE_EQ( - arity(in.dims()), - 4, - platform::errors::InvalidArgument( - "Input dimension arity only can be 4, the input dimension is %s.", - in.dims())); - - auto &pool = platform::DeviceContextPool::Instance(); - - auto src_dim = in.dims(); - std::vector dst_dim; - - auto axis = GetAxis(from_layout, to_layout); - dst_dim.resize(axis.size()); - for (size_t i = 0; i < axis.size(); i++) { - dst_dim[i] = src_dim[axis[i]]; - } - - out->Resize(phi::make_ddim(dst_dim)); - out->mutable_data(phi::CPUPlace(), in.dtype()); - - framework::VisitDataType( - framework::TransToProtoVarType(in.dtype()), - CastDataLayout(pool.Get(phi::CPUPlace()), axis, in, out)); - - out->set_layout(to_layout); -} - void InsertLayoutTransOp(ir::Graph *graph, ir::Node *prev_node, ir::Node *next_node, - DataLayout from_layout, - DataLayout to_layout, + phi::DataLayout from_layout, + phi::DataLayout to_layout, framework::BlockDesc *block_desc, 
std::unordered_map *cache) { auto do_insert = [&](const std::string &in_var_name, @@ -91,7 +58,7 @@ void InsertLayoutTransOp(ir::Graph *graph, op_out_var_desc->SetPersistable(false); op_out_var_desc->SetDataType(prev_node->Var()->GetDataType()); auto to_shape = prev_node->Var()->GetShape(); - if (from_layout == DataLayout::kNCHW) { + if (from_layout == phi::DataLayout::kNCHW) { auto n = to_shape[0]; auto c = to_shape[1]; auto h = to_shape[2]; @@ -117,12 +84,13 @@ void InsertLayoutTransOp(ir::Graph *graph, IR_NODE_UNLINK(prev_node, next_node); }; - if (from_layout == DataLayout::kNCHW && to_layout == DataLayout::kNHWC) { + if (from_layout == phi::DataLayout::kNCHW && + to_layout == phi::DataLayout::kNHWC) { auto in_var_name = prev_node->Var()->Name(); auto out_var_name = in_var_name + "_nchw_to_nhwc"; do_insert(in_var_name, out_var_name); - } else if (from_layout == DataLayout::kNHWC && - to_layout == DataLayout::kNCHW) { + } else if (from_layout == phi::DataLayout::kNHWC && + to_layout == phi::DataLayout::kNCHW) { auto in_var_name = prev_node->Var()->Name(); auto out_var_name = in_var_name + "_nhwc_to_nchw"; do_insert(in_var_name, out_var_name); @@ -135,7 +103,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be nullptr.")); - FusePassBase::Init("data_layout_transfer", graph); + FusePassBase::Init("conv2d_fusion_layout_transfer", graph); auto *scope = param_scope(); // only float16 compute precision need insert transfer_layout. @@ -170,7 +138,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { // Not support multiple block now. std::unordered_map cache; - auto op_nodes = ir::TopologySortOperations(*graph); + auto op_nodes = TopologySortOperations(*graph); auto iter = op_nodes.cbegin(); auto *block_desc = (*iter)->Op()->Block(); @@ -186,7 +154,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { op_node->Op()->GetAttrIfExists("data_format"); if (data_format != "NCHW") return false; auto filter_names = op_node->Op()->Input("Filter"); - constexpr int CUTLASS_NHWC_ALIGNMENT = 8; + constexpr int NHWC_ALIGNMENT = 8; // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc. 
for (const auto &filter_name : filter_names) { auto *filter_var = scope->FindLocalVar(filter_name); @@ -195,7 +163,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { int oc = filter_tensor.dims()[0]; int ic = filter_tensor.dims()[1]; bool cutlass_can_support = - oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0; + oc % NHWC_ALIGNMENT == 0 && ic % NHWC_ALIGNMENT == 0; if (!cutlass_can_support) { return false; } @@ -229,8 +197,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { if (cuDNNIsValid(op_node)) { valid_ops.insert(op_node); auto *op_desc = op_node->Op(); - auto nhwc_attr = framework::Attribute(std::string("NHWC")); - op_desc->SetAttr("data_format", nhwc_attr); + op_desc->SetAttr("data_format", std::string{"NHWC"}); if (cutlass_enable && CutlassIsValid(op_node)) { op_desc->SetType("conv2d_fusion_cutlass"); } @@ -244,8 +211,11 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { phi::DenseTensor temp_tensor = *filter_tensor; filter_tensor->clear(); - TransDataLayout( - DataLayout::kNCHW, DataLayout::kNHWC, temp_tensor, filter_tensor); + framework::TransDataLayout(phi::DataLayout::kNCHW, + phi::DataLayout::kNHWC, + phi::CPUPlace{}, + temp_tensor, + filter_tensor); } auto op_inputs = op_node->inputs; for (auto *in_var_node : op_inputs) { @@ -290,8 +260,8 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { InsertLayoutTransOp(graph, in_var_node, op_node, - DataLayout::kNCHW, - DataLayout::kNHWC, + phi::DataLayout::kNCHW, + phi::DataLayout::kNHWC, block_desc, &cache); } @@ -304,8 +274,8 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { InsertLayoutTransOp(graph, in_var_node, op_node, - DataLayout::kNHWC, - DataLayout::kNCHW, + phi::DataLayout::kNHWC, + phi::DataLayout::kNCHW, block_desc, &cache); } diff --git a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc new file mode 100644 index 00000000000..3e56200dcaa --- /dev/null +++ b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc @@ -0,0 +1,365 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/trt_support_nhwc_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/data_layout_transform.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/errors.h" + +namespace paddle { +namespace framework { +namespace ir { + +namespace { + +void DoInsertTransposeOp(ir::Graph *graph, + ir::Node *prev_node, + ir::Node *next_node, + phi::DataLayout from_layout, + phi::DataLayout to_layout, + framework::BlockDesc *block_desc, + std::unordered_map *cache) { + auto do_insert = [&](const std::string &in_var_name, + const std::string &out_var_name) { + auto update_op_desc = [&](framework::OpDesc &desc, + const std::string &x_name, + const std::string &out_name, + const std::vector &axis_attr) { + desc.SetType("transpose"); + desc.SetInput("X", {x_name}); + desc.SetOutput("Out", {out_name}); + desc.SetAttr("axis", axis_attr); + desc.SetAttr("use_mkldnn", false); + desc.SetAttr("data_format", std::string{"AnyLayout"}); + desc.SetAttr("use_quantizer", false); + desc.SetAttr("mkldnn_data_type", std::string{"float32"}); + desc.Flush(); + }; + CHECK_NOTNULL(block_desc); + if (cache->count(prev_node) == 0) { + framework::OpDesc op_desc(block_desc); + if (from_layout == phi::DataLayout::kNCHW) { + update_op_desc(op_desc, in_var_name, out_var_name, {0, 2, 3, 1}); + } else if (from_layout == phi::DataLayout::kNHWC) { + update_op_desc(op_desc, in_var_name, out_var_name, {0, 3, 1, 2}); + } + auto *op_node = graph->CreateOpNode(&op_desc); + auto *op_out_var_desc = block_desc->Var(out_var_name); + + op_out_var_desc->SetPersistable(false); + op_out_var_desc->SetDataType(prev_node->Var()->GetDataType()); + auto to_shape = prev_node->Var()->GetShape(); + if (from_layout == phi::DataLayout::kNCHW) { + auto n = to_shape[0]; + auto c = to_shape[1]; + auto h = to_shape[2]; + auto w = to_shape[3]; + op_out_var_desc->SetShape({n, h, w, c}); + } else if (from_layout == phi::DataLayout::kNHWC) { + auto n = to_shape[0]; + auto h = to_shape[1]; + auto w = to_shape[2]; + auto c = to_shape[3]; + op_out_var_desc->SetShape({n, c, h, w}); + } + + auto *op_out_var_node = graph->CreateVarNode(op_out_var_desc); + IR_NODE_LINK_TO(op_node, op_out_var_node); + cache->insert(std::make_pair(prev_node, op_out_var_node)); + } + next_node->Op()->RenameInput(prev_node->Name(), + cache->at(prev_node)->Name()); + IR_NODE_LINK_TO(prev_node, cache->at(prev_node)->inputs.front()); + IR_NODE_LINK_TO(cache->at(prev_node), next_node); + + IR_NODE_UNLINK(prev_node, next_node); + }; + + if (from_layout == phi::DataLayout::kNCHW && + to_layout == phi::DataLayout::kNHWC) { + auto in_var_name = prev_node->Var()->Name(); + auto out_var_name = in_var_name + "_nchw_to_nhwc"; + do_insert(in_var_name, out_var_name); + } else if (from_layout == phi::DataLayout::kNHWC && + to_layout == phi::DataLayout::kNCHW) { + auto in_var_name = prev_node->Var()->Name(); + auto out_var_name = in_var_name + "_nhwc_to_nchw"; + do_insert(in_var_name, out_var_name); + } +} + +bool ModelLayoutIsNHWC(const std::vector &op_nodes) { + for (auto *op_node : op_nodes) { + if (op_node->IsOp()) { + auto *op_desc = op_node->Op(); + std::string data_format; + if (op_desc->HasAttr("data_format")) { + data_format = op_desc->GetAttrIfExists("data_format"); + } else if (op_desc->HasAttr("data_layout")) { + data_format = 
op_desc->GetAttrIfExists<std::string>("data_layout");
+      }
+      if (data_format == "NHWC") {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+}  // namespace
+
+void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::PreconditionNotMet(
+                              "During the trt_support_nhwc_pass, the graph "
+                              "should not be null."));
+  FusePassBase::Init("trt_support_nhwc_pass", graph);
+  auto *scope = param_scope();
+
+  auto op_nodes = TopologySortOperations(*graph);
+
+  if (!ModelLayoutIsNHWC(op_nodes)) {
+    return;
+  }
+
+  //
+  //
+  // TODO(liuyuanle): Add other ops if needed!
+  //
+  //
+  std::unordered_set<std::string> need_trans_weights{"prelu"};
+  std::unordered_set<std::string> not_trans_weights{"conv2d",
+                                                    "pool2d",
+                                                    "batch_norm",
+                                                    "bilinear_interp",
+                                                    "bilinear_interp_v2",
+                                                    "nearest_interp",
+                                                    "nearest_interp_v2"};
+  // These ops must keep their original layout even though they have a
+  // data_format/data_layout attribute; transposing them would be troublesome.
+  std::unordered_set<std::string> must_original_layout_ops{"affine_channel",
+                                                           "softmax"};
+  // Layout-agnostic ops simply follow the layout of their input vars.
+  std::unordered_set<std::string> any_layout_ops{"relu"};
+  //
+  //
+  // TODO(liuyuanle): Add other ops if needed!
+  //
+  //
+
+  // Ops whose "data_format" or "data_layout" attribute value is "NHWC"
+  std::unordered_set<ir::Node *> transposed_ops;
+  std::unordered_set<ir::Node *> vars_to_nchw;
+
+  std::unordered_map<ir::Node *, ir::Node *> cache;
+
+  // Multiple blocks are not supported yet.
+  auto iter = op_nodes.cbegin();
+  auto *block_desc = (*iter)->Op()->Block();
+
+  for (auto *op_node : op_nodes) {
+    CHECK_EQ(op_node->IsOp(), true);
+    auto *op_desc = op_node->Op();
+
+    std::string data_format;
+    if (op_desc->HasAttr("data_format")) {
+      data_format = op_desc->GetAttrIfExists<std::string>("data_format");
+    } else if (op_desc->HasAttr("data_layout")) {
+      data_format = op_desc->GetAttrIfExists<std::string>("data_layout");
+    }
+
+    bool input_shape_4{true};
+    auto op_inputs = op_node->inputs;
+    for (auto *in_var_node : op_inputs) {
+      CHECK_EQ(in_var_node->IsVar(), true);
+      if (in_var_node->Var()->Persistable()) continue;
+
+      auto input_shape = in_var_node->Var()->GetShape();
+      input_shape_4 &= (input_shape.size() == 4);
+    }
+
+    if (data_format != "NHWC" || !input_shape_4 ||
+        any_layout_ops.count(op_desc->Type()) ||
+        must_original_layout_ops.count(op_desc->Type())) {
+      continue;
+    }
+    // Transpose NHWC --> NCHW
+    //
+    // Update the current op
+    transposed_ops.insert(op_node);
+    if (op_desc->HasAttr("data_format")) {
+      op_desc->SetAttr("data_format", std::string{"NCHW"});
+      op_desc->Flush();
+    } else if (op_desc->HasAttr("data_layout")) {
+      op_desc->SetAttr("data_layout", std::string{"NCHW"});
+      op_desc->Flush();
+    }
+
+    auto UpdateOutputVars = [&] {
+      // Update the output vars of the current op
+      auto op_outputs = op_node->outputs;
+      for (auto *out_var_node : op_outputs) {
+        CHECK_EQ(out_var_node->IsVar(), true);
+        if (out_var_node->Var()->Persistable()) continue;
+
+        auto from_shape = out_var_node->Var()->GetShape();
+        if (from_shape.size() == 4) {
+          out_var_node->Var()->SetShape(
+              {from_shape[0], from_shape[3], from_shape[1], from_shape[2]});
+          vars_to_nchw.insert(out_var_node);
+        }
+      }
+    };
+
+    if (not_trans_weights.count(op_desc->Type())) {
+      UpdateOutputVars();
+    } else if (need_trans_weights.count(op_desc->Type())) {
+      std::vector<std::string> weights;
+      if (op_desc->Type() == "prelu") {
+        weights.push_back("Alpha");
+      }
+      auto UpdateWeightVars = [&] {
+        for (auto const &weight : weights) {
+          // Transpose the weight tensors (NHWC -> NCHW)
+          auto weight_names = op_desc->Input(weight);
+          for (const auto &weight_name : 
weight_names) { + auto *weight_var = scope->FindLocalVar(weight_name); + auto *weight_tensor = weight_var->GetMutable(); + if (weight_tensor->dims().size() == 4) { + phi::DenseTensor temp_tensor = *weight_tensor; + weight_tensor->clear(); + + framework::TransDataLayout(phi::DataLayout::kNHWC, + phi::DataLayout::kNCHW, + phi::CPUPlace{}, + temp_tensor, + weight_tensor); + } + } + auto op_inputs = op_node->inputs; + for (auto *in_var_node : op_inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + if (in_var_node->Var()->Persistable()) { + if (std::find(weight_names.cbegin(), + weight_names.cend(), + in_var_node->Var()->Name()) != + weight_names.cend()) { + auto from_shape = in_var_node->Var()->GetShape(); + in_var_node->Var()->SetShape({from_shape[0], + from_shape[2], + from_shape[3], + from_shape[1]}); + } + } + } + } + }; + UpdateWeightVars(); + UpdateOutputVars(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "During the trt_support_nhwc_pass, %s op not supported. Please " + "update the supported op lists.", + op_desc->Type())); + } + } + + auto ProcessAnyLayoutOps = [&] { + // Process any layout ops + for (auto *op_node : op_nodes) { + CHECK_EQ(op_node->IsOp(), true); + auto op_inputs = op_node->inputs; + for (auto *in_var_node : op_inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + if (transposed_ops.count(op_node)) continue; + + if (vars_to_nchw.count(in_var_node) && + any_layout_ops.count(op_node->Op()->Type())) { + transposed_ops.insert(op_node); + // Update output var of current op + auto op_outputs = op_node->outputs; + for (auto *out_var_node : op_outputs) { + CHECK_EQ(out_var_node->IsVar(), true); + if (out_var_node->Var()->Persistable()) continue; + + auto from_shape = out_var_node->Var()->GetShape(); + if (from_shape.size() == 4) { + out_var_node->Var()->SetShape( + {from_shape[0], from_shape[3], from_shape[1], from_shape[2]}); + vars_to_nchw.insert(out_var_node); + } + } + } + } + } + }; + ProcessAnyLayoutOps(); + + auto InsertTransposeOp = [&] { + // Insert transpose op + for (auto *op_node : op_nodes) { + CHECK_EQ(op_node->IsOp(), true); + + if (transposed_ops.count(op_node)) { + auto op_inputs = op_node->inputs; + for (auto *in_var_node : op_inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + + if (in_var_node->Var()->Persistable()) continue; + if (vars_to_nchw.count(in_var_node)) continue; + + DoInsertTransposeOp(graph, + in_var_node, + op_node, + phi::DataLayout::kNHWC, + phi::DataLayout::kNCHW, + block_desc, + &cache); + } + } else { + auto op_inputs = op_node->inputs; + for (auto *in_var_node : op_inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + + if (vars_to_nchw.count(in_var_node)) { + DoInsertTransposeOp(graph, + in_var_node, + op_node, + phi::DataLayout::kNCHW, + phi::DataLayout::kNHWC, + block_desc, + &cache); + } + } + } + } + }; + InsertTransposeOp(); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(trt_support_nhwc_pass, paddle::framework::ir::TrtSupportNHWCPass); diff --git a/paddle/fluid/framework/ir/trt_support_nhwc_pass.h b/paddle/fluid/framework/ir/trt_support_nhwc_pass.h new file mode 100644 index 00000000000..e45ce7922a8 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_support_nhwc_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { + +class TrtSupportNHWCPass : public FusePassBase { + public: + TrtSupportNHWCPass() = default; + ~TrtSupportNHWCPass() = default; + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 5ea79d20fac..9f28343525c 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -85,7 +85,8 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) { void PaddlePassBuilder::ClearPasses() { passes_.clear(); } const std::vector kTRTSubgraphPasses({ - "adaptive_pool2d_convert_global_pass", // + "trt_support_nhwc_pass", + "adaptive_pool2d_convert_global_pass", // "shuffle_channel_detect_pass", // "quant_conv2d_dequant_fuse_pass", // "delete_fill_constant_op_pass", // diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py index 69d5dade40c..f93b598de71 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bilinear_interp_v2.py @@ -44,10 +44,10 @@ class TrtConvertBilinearInterpV2Test(TrtLayerAutoScanTest): for data_layout in ["NCHW", "NHWC"]: for scale_y in [2.0, 1.0]: - for scale_x in [2.0, 1.0]: + for scale_x in [2.0]: scale = [scale_y, scale_x] - for out_h in [32, 64, 128, 192]: - for out_w in [32, 64]: + for out_h in [32, 128]: + for out_w in [64]: dics = [ { "data_layout": data_layout, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_support_nhwc_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_support_nhwc_pass.py new file mode 100644 index 00000000000..179b191ec38 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_support_nhwc_pass.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import shutil +import unittest + +import numpy as np + +import paddle +import paddle.inference as inference +import paddle.nn as nn +import paddle.static as static + +paddle.enable_static() + + +class SimpleNet(nn.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.conv1 = nn.Conv2D( + in_channels=4, + out_channels=4, + kernel_size=3, + stride=2, + padding=0, + data_format='NHWC', + ) + self.relu1 = nn.ReLU() + self.conv2 = nn.Conv2D( + in_channels=4, + out_channels=2, + kernel_size=3, + stride=2, + padding=0, + data_format='NHWC', + ) + self.relu2 = nn.ReLU() + self.conv3 = nn.Conv2D( + in_channels=2, + out_channels=1, + kernel_size=3, + stride=2, + padding=0, + data_format='NHWC', + ) + self.relu3 = nn.ReLU() + self.flatten = nn.Flatten() + self.fc = nn.Linear(729, 10) + self.softmax = nn.Softmax() + + def forward(self, x): + x = self.conv1(x) + x = self.relu1(x) + x = self.conv2(x) + x = self.relu2(x) + x = self.conv3(x) + x = self.relu3(x) + x = self.flatten(x) + x = self.fc(x) + x = self.softmax(x) + return x + + +class TRTNHWCConvertTest(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.path = './inference_pass/nhwc_convert/infer_model' + + def create_model(self): + image = static.data( + name='img', shape=[None, 224, 224, 4], dtype='float32' + ) + predict = SimpleNet()(image) + exe = paddle.static.Executor(self.place) + exe.run(paddle.static.default_startup_program()) + paddle.static.save_inference_model(self.path, [image], [predict], exe) + + def create_predictor(self): + config = paddle.inference.Config( + self.path + '.pdmodel', self.path + '.pdiparams' + ) + config.enable_memory_optim() + config.enable_use_gpu(100, 0) + config.enable_tensorrt_engine( + workspace_size=1 << 30, + max_batch_size=1, + min_subgraph_size=3, + precision_mode=inference.PrecisionType.Float32, + use_static=False, + use_calib_mode=False, + ) + predictor = inference.create_predictor(config) + return predictor + + def infer(self, predictor, img): + input_names = predictor.get_input_names() + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + input_tensor.reshape(img[i].shape) + input_tensor.copy_from_cpu(img[i].copy()) + predictor.run() + results = [] + output_names = predictor.get_output_names() + for i, name in enumerate(output_names): + output_tensor = predictor.get_output_handle(name) + output_data = output_tensor.copy_to_cpu() + results.append(output_data) + return results + + def test_nhwc_convert(self): + self.create_model() + predictor = self.create_predictor() + img = np.ones((1, 224, 224, 4), dtype=np.float32) + result = self.infer(predictor, img=[img]) + + def tearDown(self): + shutil.rmtree('./inference_pass/nhwc_convert/') + + +if __name__ == '__main__': + unittest.main() -- GitLab
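Appendix (not part of the patch): a minimal standalone C++ sketch of the shape bookkeeping the pass relies on. The transpose ops inserted by DoInsertTransposeOp use axis {0, 3, 1, 2} for NHWC -> NCHW and {0, 2, 3, 1} for NCHW -> NHWC, and TransDataLayout computes destination dims the same way (dst_dim[i] = src_dim[axis[i]]). PermuteShape below is a hypothetical helper written only for this illustration, not a Paddle API.

#include <cstdint>
#include <iostream>
#include <vector>

// Permute a shape the same way TransDataLayout does: dst[i] = src[axis[i]].
std::vector<int64_t> PermuteShape(const std::vector<int64_t> &src,
                                  const std::vector<int> &axis) {
  std::vector<int64_t> dst(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) dst[i] = src[axis[i]];
  return dst;
}

int main() {
  // Input shape used by the unit test above: NHWC = [1, 224, 224, 4].
  const std::vector<int64_t> nhwc{1, 224, 224, 4};
  const std::vector<int> nhwc_to_nchw{0, 3, 1, 2};  // transpose axis NHWC -> NCHW
  const std::vector<int> nchw_to_nhwc{0, 2, 3, 1};  // transpose axis NCHW -> NHWC

  auto nchw = PermuteShape(nhwc, nhwc_to_nchw);  // {1, 4, 224, 224}
  auto back = PermuteShape(nchw, nchw_to_nhwc);  // {1, 224, 224, 4}

  for (auto d : nchw) std::cout << d << ' ';
  std::cout << '\n';
  for (auto d : back) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}

Running this prints "1 4 224 224" followed by "1 224 224 4": the two permutations round-trip the shape, which is why the pass only needs to insert a transpose at the NHWC/NCHW boundaries it records in vars_to_nchw.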