未验证 提交 fd373579 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT] add cast between int64 tensor and Paddle-TRT (#45547)

* Add cast between int64 tensor and Paddle-TRT
* Add Unit testing.
上级 2935ce07
......@@ -253,6 +253,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// problem, so we filter them out.
std::vector<std::string> params_not_shared;
auto *scope = param_scope();
// The node->inputs contains input tensors and parameters.
for (auto *x : node->inputs) {
input_names.insert(x->Name());
......@@ -264,6 +265,21 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
x->outputs.size() <= 1) {
params_not_shared.push_back(x->Name());
}
// When TRT Engine's input is INT64, we need do some extra work.
// So we reserved a name for later use when casting INT64 -> INT32.
// We must check whether scope has had the same name var!
if (x->Var()->GetDataType() == framework::proto::VarType::INT64) {
std::string tmp_name = x->Name() + "_cast_to_INT32";
LOG(WARNING)
<< "tensorrt_subgraph's input named " << tmp_name
<< " having int64 dtype in pdmodel description, we will cast them to "
"int32 dtype to feed them into paddle-trt.";
PADDLE_ENFORCE_EQ(scope->FindVar(tmp_name),
nullptr,
platform::errors::InvalidArgument(
"The var name %s has exists in scope.", tmp_name));
scope->Var(tmp_name);
}
}
auto model_precision =
......@@ -273,13 +289,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::set<std::string> output_names;
std::set<std::string> output_names_with_id;
std::map<std::string, int> origin_name_output_dims;
std::map<std::string, int> origin_name_output_rank;
std::unordered_set<Node *> trt_outputs;
// record the origin output data type
std::vector<int> origin_outputs_dtype;
std::map<std::string, int> map_origin_outputs_dtype;
for (auto *x : node->outputs) {
output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id()));
origin_name_output_dims[x->Name()] = x->Var()->GetShape().size();
origin_name_output_rank[x->Name()] = x->Var()->GetShape().size();
trt_outputs.insert(x);
map_origin_outputs_dtype[x->Name()] =
static_cast<int>(x->Var()->GetDataType());
}
OutputProcess(
......@@ -353,14 +374,34 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// output_mapping help us copy the data from the renamed ITensor
// to Tensor.
std::vector<std::string> output_mapping;
std::vector<int> renamed_output_dims;
std::vector<int> renamed_output_rank;
for (auto name : output_names) {
PADDLE_ENFORCE_NE(output_name_map.count(name),
0,
platform::errors::PreconditionNotMet(
"The output_name_map should have %s", name));
output_mapping.push_back(output_name_map[name]);
renamed_output_dims.push_back(origin_name_output_dims[name]);
renamed_output_rank.push_back(origin_name_output_rank[name]);
origin_outputs_dtype.push_back(map_origin_outputs_dtype[name]);
// When TRT Engine's output is INT64, we need do some extra work.
// So we reserved a name for later use when casting INT32 -> INT64.
// We must check whether scope has had the same name var!
if (static_cast<framework::proto::VarType_Type>(
map_origin_outputs_dtype[name]) ==
framework::proto::VarType::INT64) {
std::string tmp_name = name + "_cast_to_INT64";
LOG(WARNING) << "tensorrt_subgraph's output named " << tmp_name
<< " having int64 dtype in pdmodel description, but in fact "
"it is int32 "
"dtype after executing this tensorrt_subgraph, so we "
"need cast them into int64.";
PADDLE_ENFORCE_EQ(scope->FindVar(tmp_name),
nullptr,
platform::errors::InvalidArgument(
"The var name %s has exists in scope.", tmp_name));
scope->Var(tmp_name);
}
}
PADDLE_ENFORCE_EQ(output_mapping.empty(),
false,
......@@ -381,11 +422,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
op_desc->SetBlockAttr("sub_block", new_block);
op_desc->SetAttr("subgraph", block_desc.Proto()->SerializeAsString());
op_desc->SetAttr("origin_outputs_dtype", origin_outputs_dtype);
op_desc->SetAttr("max_batch_size", max_batch_size);
op_desc->SetAttr("workspace_size", Get<int64_t>("workspace_size"));
op_desc->SetAttr("gpu_id", Get<int>("gpu_device_id"));
op_desc->SetAttr("output_name_mapping", output_mapping);
op_desc->SetAttr("origin_output_dims", renamed_output_dims);
op_desc->SetAttr("origin_output_rank", renamed_output_rank);
op_desc->SetAttr("parameters", params);
op_desc->SetAttr("allow_build_at_runtime", allow_build_at_runtime);
op_desc->SetAttr("shape_range_info_path", shape_range_info_path);
......@@ -548,7 +590,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
auto *scope = param_scope();
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
std::unordered_set<std::string> param_set(params.begin(), params.end());
inference::Singleton<inference::tensorrt::OpConverter>::Global()
......
......@@ -60,6 +60,7 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
case FluidDT::VarType_Type_FP32:
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
case FluidDT::VarType_Type_INT64:
return TRT_DT::kINT32;
case FluidDT::VarType_Type_FP16:
return TRT_DT::kHALF;
......@@ -68,10 +69,9 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
return TRT_DT::kBOOL;
#endif
default:
return TRT_DT::kINT32;
PADDLE_THROW(platform::errors::InvalidArgument(
"unknown fluid datatype in TRT op converter"));
}
PADDLE_THROW(platform::errors::InvalidArgument(
"unknown fluid datatype in TRT op converter"));
return TRT_DT::kINT32;
}
......
......@@ -21,13 +21,13 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#ifdef PADDLE_WITH_CUDA
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/fluid/framework/data_device_transform.h"
#include "paddle/fluid/framework/executor.h"
......@@ -596,7 +596,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (type == framework::proto::VarType::FP32) {
buffers[bind_index] = static_cast<void *>(t.data<float>());
} else if (type == framework::proto::VarType::INT64) {
buffers[bind_index] = static_cast<void *>(t.data<int64_t>());
auto int32_tensor =
scope.FindVar(x + "_cast_to_INT32")->GetMutable<phi::DenseTensor>();
*int32_tensor = phi::Cast<int64_t>(
reinterpret_cast<const phi::GPUContext &>(dev_ctx),
t,
phi::DataType::INT32);
buffers[bind_index] =
static_cast<void *>(int32_tensor->data<int32_t>());
} else if (type == framework::proto::VarType::INT32) {
buffers[bind_index] = static_cast<void *>(t.data<int32_t>());
} else if (type == framework::proto::VarType::FP16) {
......@@ -614,8 +621,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
// Bind output tensor to TRT.
int output_index = 0;
std::vector<int> origin_output_dims =
Attr<std::vector<int>>("origin_output_dims");
std::vector<int> origin_output_rank =
Attr<std::vector<int>>("origin_output_rank");
VLOG(4) << "TensorRT Engine Op Outputs:";
for (const auto &y : Outputs("Ys")) {
const int bind_index =
......@@ -636,7 +643,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
for (; nb_dims > 0; nb_dims--) {
// some 'x 1' of shape is normal, no need to remove it
if (dims.d[nb_dims - 1] != 1 ||
nb_dims == origin_output_dims[output_index])
nb_dims == origin_output_rank[output_index])
break;
}
for (int i = 0; i < nb_dims; i++) ddim.push_back(dims.d[i]);
......@@ -694,6 +701,28 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
// Execute the engine.
engine->Execute(runtime_batch, &buffers, stream);
std::vector<int> origin_outputs_dtype =
Attr<std::vector<int>>("origin_outputs_dtype");
for (size_t i = 0; i < Outputs("Ys").size(); i++) {
auto type =
static_cast<framework::proto::VarType_Type>(origin_outputs_dtype[i]);
if (type == framework::proto::VarType::INT64) {
auto y = Outputs("Ys")[i];
auto *fluid_v = scope.FindVar(y);
auto *fluid_t = fluid_v->GetMutable<phi::DenseTensor>();
auto int32_tensor =
scope.FindVar(y + "_cast_to_INT64")->GetMutable<phi::DenseTensor>();
int32_tensor->Resize(fluid_t->dims());
dev_ctx.Alloc<int32_t>(int32_tensor);
framework::TensorCopy(*fluid_t, dev_place, dev_ctx, int32_tensor);
*fluid_t = phi::Cast<int32_t>(
reinterpret_cast<const phi::GPUContext &>(dev_ctx),
*int32_tensor,
phi::DataType::INT64);
}
}
}
TensorRTEngine *GetEngine(const framework::Scope &scope,
......
......@@ -104,6 +104,7 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("origin_outputs_dtype", std::vector<int>{5});
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
engine_op_desc.SetAttr("max_batch_size", static_cast<int>(2));
......@@ -119,7 +120,7 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
engine_op_desc.SetAttr("origin_output_rank", std::vector<int>({2}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
......@@ -274,7 +275,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
engine_op_desc.SetAttr("origin_output_rank", std::vector<int>({2}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import Any, Dict, List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtInt64Test1(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
out_shape = list(inputs['input_data'].shape)
for x in range(len(attrs[0]["axes"])):
start = 0
end = 0
if attrs[0]["starts"][x] < 0:
start = (
attrs[0]["starts"][x]
+ inputs['input_data'].shape[attrs[0]["axes"][x]]
)
else:
start = attrs[0]["starts"][x]
if attrs[0]["ends"][x] < 0:
end = (
attrs[0]["ends"][x]
+ inputs['input_data'].shape[attrs[0]["axes"][x]]
)
else:
end = attrs[0]["ends"][x]
start = max(0, start)
end = max(0, end)
out_shape[attrs[0]["axes"][x]] = end - start
if start >= end:
return False
for x in attrs[0]["decrease_axis"]:
if x < 0:
return False
if out_shape[x] != 1:
return False
return True
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]):
return (10 * np.random.random([6, 6, 64, 64])).astype(np.int64)
for axes in [[0, 1], [1, 3], [2, 3]]:
for starts in [[0, 1]]:
for ends in [[2, 2], [5, 5], [1, -1]]:
for decrease_axis in [[], [1], [2], [-1], [-100]]:
for infer_flags in [[-1]]:
dics = [
{
"axes": axes,
"starts": starts,
"ends": ends,
"decrease_axis": decrease_axis,
"infer_flags": infer_flags,
}
]
ops_config = [
{
"op_type": "slice",
"op_inputs": {"Input": ["input_data"]},
"op_outputs": {
"Out": ["slice_output_data"]
},
"op_attrs": dics[0],
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input1, dics)
)
},
outputs=["slice_output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
self.dynamic_shape.max_input_shape = {"input_data": [8, 8, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [6, 6, 64, 64]}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-3
def test(self):
self.run_test()
class TrtInt64Test2(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(shape, op_type):
return np.random.randint(
low=1, high=10000, size=shape, dtype=np.int64
)
for shape in [[2, 32, 16], [1, 8, 16, 32]]:
for op_type in [
"elementwise_add",
"elementwise_mul",
"elementwise_sub",
]:
for axis in [0, -1]:
self.dims = len(shape)
dics = [{"axis": axis}]
ops_config = [
{
"op_type": op_type,
"op_inputs": {
"X": ["input_data1"],
"Y": ["input_data2"],
},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": dics[0],
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input, shape, op_type)
),
"input_data2": TensorConfig(
data_gen=partial(generate_input, shape, op_type)
),
},
outputs=["output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 3:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 4, 4],
"input_data2": [1, 4, 4],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [128, 128, 256],
"input_data2": [128, 128, 256],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 32, 16],
"input_data2": [2, 32, 16],
}
elif self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 4, 4, 4],
"input_data2": [1, 4, 4, 4],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [8, 128, 64, 128],
"input_data2": [8, 128, 64, 128],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 64, 32, 32],
"input_data2": [2, 64, 32, 32],
}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 3
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 3), (1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), (1e-3, 1e-3)
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册