Unverified commit 8b063030, authored by ming1753, committed by GitHub

paddle-TRT support float64 (#55520)

* Paddle-TRT: support float64 in/out types; support fill_any_like_op with int64
Parent 51ebcf68
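TensorRT has no native float64, so the strategy throughout this patch is boundary casting: float64 tensors are narrowed to float32 before entering the TRT engine and widened back to float64 on the way out. A minimal numpy sketch (not part of this commit) of that round trip and the precision it costs:

```python
import numpy as np

# Narrow at the engine input boundary, widen back at the output boundary;
# results carry only float32 precision (~7 significant digits).
x64 = np.random.random((2, 3)).astype(np.float64)
x32 = x64.astype(np.float32)              # what paddle-trt feeds the engine
y64 = (x32 * 2.0).astype(np.float64)      # stand-in for TRT subgraph + cast-back
print(np.max(np.abs(y64 - x64 * 2.0)))    # on the order of 1e-8, not 1e-16
```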
......@@ -336,9 +336,9 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
x->outputs.size() <= 1) {
params_not_shared.push_back(x->Name());
}
// When TRT Engine's input is INT64, we need to do some extra work.
// So we reserved a name for later use when casting INT64 -> INT32.
// We must check whether the scope already has a var with the same name!
// When TRT Engine's input is INT64 or FP64, we need to do some extra work.
// So we reserved a name for later use when casting INT64 -> INT32 or
// FP64 -> FP32. We must check whether the scope already has a var with the
// same name!
if (x->Var()->GetDataType() == framework::proto::VarType::INT64) {
std::string tmp_name = x->Name() + "_cast_to_INT32";
LOG(WARNING)
......@@ -353,6 +353,20 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
tmp_name));
*/
scope->Var(tmp_name);
} else if (x->Var()->GetDataType() == framework::proto::VarType::FP64) {
std::string tmp_name = x->Name() + "_cast_to_FP32";
LOG(WARNING) << "tensorrt_subgraph's input named " << x->Name()
<< " having float64 dtype in pdmodel description, we will "
"cast them to "
"float32 dtype to feed them into paddle-trt.";
/*
PADDLE_ENFORCE_EQ(scope->FindVar(tmp_name),
nullptr,
platform::errors::InvalidArgument(
"The var name %s has exists in scope.",
tmp_name));
*/
scope->Var(tmp_name);
}
}
......@@ -489,9 +503,9 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
renamed_output_rank.push_back(origin_name_output_rank[name]);
origin_outputs_dtype.push_back(map_origin_outputs_dtype[name]);
// When TRT Engine's output is INT64, we need to do some extra work.
// So we reserved a name for later use when casting INT32 -> INT64.
// We must check whether the scope already has a var with the same name!
// When TRT Engine's output is INT64 or FP64, we need to do some extra work.
// So we reserved a name for later use when casting INT32 -> INT64 or FP32
// -> FP64. We must check whether the scope already has a var with the same
// name!
if (static_cast<framework::proto::VarType_Type>(
map_origin_outputs_dtype[name]) ==
framework::proto::VarType::INT64) {
......@@ -506,6 +520,21 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
platform::errors::InvalidArgument(
"The var name %s has exists in scope.", tmp_name));
scope->Var(tmp_name);
} else if (static_cast<framework::proto::VarType_Type>(
map_origin_outputs_dtype[name]) ==
framework::proto::VarType::FP64) {
std::string tmp_name = name + "_cast_to_FP64";
LOG(WARNING)
<< "tensorrt_subgraph's output named " << name
<< " has float64 dtype in the pdmodel description, but it is "
"actually float32 after executing this tensorrt_subgraph, so we "
"need to cast it back to float64.";
PADDLE_ENFORCE_EQ(scope->FindVar(tmp_name),
nullptr,
platform::errors::InvalidArgument(
"The var name %s has exists in scope.", tmp_name));
scope->Var(tmp_name);
}
}
PADDLE_ENFORCE_EQ(output_mapping.empty(),
......
......@@ -46,6 +46,7 @@ class CastOpConverter : public OpConverter {
layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
break;
case 5: // FP32 = 5
case 6: // FP64 = 6
layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
break;
......
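The bare case labels here, and the dtype integers swept in the tests below, are framework::proto::VarType::Type codes. A reference sketch of the low-numbered values (assuming the standard framework.proto numbering):

```python
# framework::proto::VarType::Type codes used as bare integers in this patch
# (assuming the standard framework.proto numbering):
VAR_TYPE = {0: "BOOL", 1: "INT16", 2: "INT32", 3: "INT64",
            4: "FP16", 5: "FP32", 6: "FP64"}
# So `case 6` above routes an FP64 cast target to TensorRT's kFLOAT,
# since TensorRT exposes no 64-bit float type.
```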
......@@ -37,6 +37,11 @@ class FillAnyLikeOpConverter : public OpConverter {
(dtype == -1 && input->getType() == nvinfer1::DataType::kINT32)) {
value_tensor = Add1DConstantLayer(static_cast<int32_t>(value),
output_name + "_value_tensor_");
} else if (dtype == 3) {
LOG(WARNING) << "the fill_any_like has int64 dtype, it "
"will be cast to int32.";
value_tensor = Add1DConstantLayer(static_cast<int32_t>(value),
output_name + "_value_tensor_");
} else {
value_tensor = Add1DConstantLayer(value, output_name + "_value_tensor_");
}
......
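One caveat with the new int64 branch: static_cast<int32_t>(value) silently wraps for fill values outside the int32 range, which is presumably why the converter logs a WARNING. A quick numpy check (values hypothetical):

```python
import numpy as np

# int64 -> int32 narrowing is exact up to 2**31 - 1 and wraps beyond it.
for v in (7, 2**31 - 1, 2**31):
    print(v, "->", np.int64(v).astype(np.int32))
# 2147483648 -> -2147483648: fill_any_like values must fit in int32.
```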
......@@ -117,6 +117,7 @@ namespace { // NOLINT
TRT_DT FluidDataType2TRT(FluidDT type) {
switch (type) {
case FluidDT::VarType_Type_FP32:
case FluidDT::VarType_Type_FP64:
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
case FluidDT::VarType_Type_INT64:
......
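After this hunk, both 64-bit element types fold onto their 32-bit TensorRT counterparts; the INT64 -> kINT32 return is truncated above but implied by the INT64 casting machinery earlier in the patch. A sketch of the resulting mapping:

```python
# FluidDataType2TRT after this patch: the real precision change happens in
# the boundary casts, not here.
FLUID_TO_TRT = {
    "FP32": "kFLOAT",
    "FP64": "kFLOAT",   # added by this patch
    "INT32": "kINT32",
    "INT64": "kINT32",  # per the INT64 -> INT32 handling above
}
```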
......@@ -100,37 +100,6 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
// Don't allow fp64!
{
auto inputs = desc.Inputs();
for (auto iter : inputs) {
for (auto var_name : iter.second) {
auto* block = desc.Block();
if (block) {
auto* var_desc = block->FindVar(var_name);
auto dtype = var_desc->GetDataType();
if (dtype == framework::proto::VarType::FP64) {
return false;
}
}
}
}
auto outputs = desc.Outputs();
for (auto iter : outputs) {
for (auto var_name : iter.second) {
auto* block = desc.Block();
if (block) {
auto* var_desc = block->FindVar(var_name);
auto dtype = var_desc->GetDataType();
if (dtype == framework::proto::VarType::FP64) {
return false;
}
}
}
}
}
// do not support the op which is labeled the `skip_quant`
if ((desc.HasAttr("namescope") &&
PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) ==
......@@ -425,7 +394,8 @@ struct SimpleOpTypeSetTeller : public Teller {
auto start_var_name = desc.Input("Start")[0];
auto* start_var_desc = block->FindVar(start_var_name);
auto start_dtype = start_var_desc->GetDataType();
if (start_dtype == framework::proto::VarType::FP32) {
if (start_dtype == framework::proto::VarType::FP32 ||
start_dtype == framework::proto::VarType::FP64) {
return false;
}
#endif
......@@ -751,7 +721,8 @@ struct SimpleOpTypeSetTeller : public Teller {
auto x_dtype = x_var_desc->GetDataType();
if (!(x_dtype == framework::proto::VarType::FP32 ||
x_dtype == framework::proto::VarType::FP16)) {
x_dtype == framework::proto::VarType::FP16 ||
x_dtype == framework::proto::VarType::FP64)) {
return false;
}
......@@ -1229,16 +1200,18 @@ struct SimpleOpTypeSetTeller : public Teller {
const auto x_shape = x_var_desc->GetShape();
auto dtype = x_var_desc->GetDataType();
if (!with_dynamic_shape) {
// At present, only support float32 or float16 into trt.
// At present, only support float32 or float16 or float64 into trt.
if (!(dtype == framework::proto::VarType::FP32 ||
dtype == framework::proto::VarType::FP64 ||
dtype == framework::proto::VarType::FP16)) {
return false;
}
} else {
// At present, only support float32 or float16 or int32 or int64 into
// trt.
// At present, only support float32 or float16 or float64 or int32 or
// int64 into trt.
if (!(dtype == framework::proto::VarType::FP32 ||
dtype == framework::proto::VarType::FP16 ||
dtype == framework::proto::VarType::FP64 ||
dtype == framework::proto::VarType::INT32 ||
dtype == framework::proto::VarType::INT64)) {
return false;
......@@ -1339,15 +1312,19 @@ struct SimpleOpTypeSetTeller : public Teller {
return true;
}
#endif
if (dtype != -1 && dtype != 2 && dtype != 5) {
VLOG(3) << "the fill_any_like only supports int32 and float32 by "
"trt8.4 below";
if (dtype != -1 && dtype != 2 && dtype != 3 && dtype != 5 && dtype != 6) {
VLOG(3)
<< "the fill_any_like only supports int32/int64/float32/float64 by"
"trt8.4 below";
return false;
}
if (dtype == -1) {
if (input_type != framework::proto::VarType::INT32 &&
input_type != framework::proto::VarType::FP32) {
VLOG(3) << "the fill_any_like only supports int32 and float32 by "
input_type != framework::proto::VarType::INT64 &&
input_type != framework::proto::VarType::FP32 &&
input_type != framework::proto::VarType::FP64) {
VLOG(3) << "the fill_any_like only supports "
"int32/int64/float32/float64 by"
"trt8.4 below";
return false;
}
......@@ -2245,13 +2222,19 @@ struct SimpleOpTypeSetTeller : public Teller {
} else {
#if IS_TRT_VERSION_GE(7000)
if (dtype != framework::proto::VarType::INT32 &&
dtype != framework::proto::VarType::FP32) {
VLOG(3) << "reduce op input data type must be int32 or float32";
dtype != framework::proto::VarType::INT64 &&
dtype != framework::proto::VarType::FP32 &&
dtype != framework::proto::VarType::FP64) {
VLOG(3) << "reduce op input data type must be int32 or int64 or "
"float32 or "
"float64";
return false;
}
#else
if (dtype != framework::proto::VarType::FP32) {
VLOG(3) << "reduce op input data type must be float32 using TensorRT "
if (dtype != framework::proto::VarType::FP32 &&
dtype != framework::proto::VarType::FP64) {
VLOG(3) << "reduce op input data type must be float32 or float64 "
"using TensorRT "
"< 7.0";
return false;
}
......
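Taken together, the op_teller hunks relax the generic dtype gate. A condensed sketch (function name hypothetical) of the accept logic after this patch:

```python
# Condensed view of the generic input-dtype gate in op_teller.
def input_dtype_allowed(dtype: str, with_dynamic_shape: bool) -> bool:
    static_ok = {"FP16", "FP32", "FP64"}           # FP64 newly accepted
    dynamic_ok = static_ok | {"INT32", "INT64"}
    return dtype in (dynamic_ok if with_dynamic_shape else static_ok)

assert input_dtype_allowed("FP64", False)
assert not input_dtype_allowed("INT64", False)     # ints still need dynamic shape
```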
......@@ -702,6 +702,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (t.dtype() == phi::DataType::FLOAT32) {
buffers[bind_index] = static_cast<void *>(t.data<float>());
} else if (t.dtype() == phi::DataType::FLOAT64) {
auto fp32_tensor =
scope.FindVar(x + "_cast_to_FP32")->GetMutable<phi::DenseTensor>();
*fp32_tensor = phi::Cast<double>(
reinterpret_cast<const phi::GPUContext &>(dev_ctx),
t,
phi::DataType::FLOAT32);
buffers[bind_index] = static_cast<void *>(fp32_tensor->data<float>());
} else if (t.dtype() == phi::DataType::INT64) {
auto int32_tensor =
scope.FindVar(x + "_cast_to_INT32")->GetMutable<phi::DenseTensor>();
......@@ -722,7 +730,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
} else {
PADDLE_THROW(platform::errors::Fatal(
"The TRT Engine OP only support "
"float/int32_t/int64_t/float16/bool input."));
"float/double/int32_t/int64_t/float16/bool input."));
}
}
......@@ -828,6 +836,19 @@ class TensorRTEngineOp : public framework::OperatorBase {
reinterpret_cast<const phi::GPUContext &>(dev_ctx),
*int32_tensor,
phi::DataType::INT64);
} else if (type == framework::proto::VarType::FP64) {
auto y = Outputs("Ys")[i];
auto *fluid_v = scope.FindVar(y);
auto *fluid_t = fluid_v->GetMutable<phi::DenseTensor>();
auto fp32_tensor =
scope.FindVar(y + "_cast_to_FP64")->GetMutable<phi::DenseTensor>();
fp32_tensor->Resize(fluid_t->dims());
dev_ctx.Alloc<float>(fp32_tensor);
framework::TensorCopy(*fluid_t, dev_place, dev_ctx, fp32_tensor);
*fluid_t =
phi::Cast<float>(reinterpret_cast<const phi::GPUContext &>(dev_ctx),
*fp32_tensor,
phi::DataType::FLOAT64);
}
}
}
......
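The engine op is where the scratch variables reserved by the subgraph pass meet the actual casts. A hypothetical Python analogue (helper names invented, not Paddle API) of the flow for one FP64 tensor:

```python
import numpy as np

scope = {}  # stand-in for framework::Scope

def bind_input(name, t):
    # Mirrors the FLOAT64 branch above: stash an FP32 copy under the
    # reserved "<name>_cast_to_FP32" var and bind that to the engine.
    if t.dtype == np.float64:
        scope[name + "_cast_to_FP32"] = t.astype(np.float32)
        return scope[name + "_cast_to_FP32"]
    return t

def fetch_output(name, engine_out_fp32, declared_fp64):
    # Mirrors the FP64 output branch: park the FP32 engine result in the
    # reserved "<name>_cast_to_FP64" var, then widen it back to float64.
    if declared_fp64:
        scope[name + "_cast_to_FP64"] = engine_out_fp32
        return engine_out_fp32.astype(np.float64)
    return engine_out_fp32

x32 = bind_input("x", np.full((2, 2), 1.5, np.float64))
y = fetch_output("y", x32 * 3.0, declared_fp64=True)
print(y.dtype, sorted(scope))  # float64 ['x_cast_to_FP32', 'y_cast_to_FP64']
```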
......@@ -25,7 +25,7 @@ import paddle.inference as paddle_infer
class TrtConvertExpandV2Test(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
if self.dtype in [0, 3, 4]:
if self.dtype in [0, 1, 4]:
return False
if self.dims != 4 and self.dtype != 2:
return False
......@@ -37,14 +37,20 @@ class TrtConvertExpandV2Test(TrtLayerAutoScanTest):
self.input_shape = [1, 1, 4, 6]
if self.dtype == 0:
return np.random.random([1, 1, 4, 6]).astype(np.bool_)
elif self.dtype == 2 or self.dtype == -1:
elif self.dtype == 1:
return np.random.random([1, 1, 4, 6]).astype(np.int16)
elif self.dtype == 2:
return np.random.random([1, 1, 4, 6]).astype(np.int32)
elif self.dtype == 3:
return np.random.random([1, 1, 4, 6]).astype(np.int64)
elif self.dtype == 4:
return np.random.random([1, 1, 4, 6]).astype(np.float16)
else:
elif self.dtype == 5:
return np.random.random([1, 1, 4, 6]).astype(np.float32)
elif self.dtype == 6:
return np.random.random([1, 1, 4, 6]).astype(np.float64)
else:
return np.random.random([1, 1, 4, 6]).astype(np.int32)
elif self.dims == 3:
self.input_shape = [1, 8, 6]
return np.random.random([1, 8, 6]).astype(np.int32)
......@@ -66,7 +72,7 @@ class TrtConvertExpandV2Test(TrtLayerAutoScanTest):
for dims in [1, 2, 3, 4]:
for value in [2]:
for dtype in [-1, 0, 2, 3, 4, 5]:
for dtype in [-1, 0, 1, 2, 3, 4, 5, 6]:
dics = [
{
"value": value,
......
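The widened dtype sweep interacts with the filters in is_program_valid; a quick enumeration (mirroring the test's own two checks) of which (dims, dtype) combinations actually run:

```python
# Mirrors the two filters at the top of is_program_valid: bool/int16/fp16
# are rejected outright, and non-4D inputs only run with int32.
valid = [(dims, dt)
         for dims in (1, 2, 3, 4)
         for dt in (-1, 0, 1, 2, 3, 4, 5, 6)
         if dt not in (0, 1, 4) and (dims == 4 or dt == 2)]
print(valid)  # float64 (dtype 6) is exercised only in the 4D case
```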
......@@ -53,6 +53,10 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
return np.random.random([1, 3, 64, 64]).astype(np.int32)
elif dtype == 0:
return np.random.random([1, 3, 64, 64]).astype(np.bool_)
elif dtype == 3:
return np.random.random([1, 3, 64, 64]).astype(np.int64)
elif dtype == 6:
return np.random.random([1, 3, 64, 64]).astype(np.float64)
for keep_dim in [True, False]:
for dim in [
......@@ -67,7 +71,7 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
[3, 4, 5],
]:
for reduce_all in [True, False]:
for out_dtype in [-1, 0, 2, 5]:
for out_dtype in [-1, 0, 2, 5, 3, 6]:
if out_dtype != 0:
reduce_type_list = [
"reduce_max",
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import Any, Dict, List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtFloat64Test1(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
weights = program_config.weights
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
out_shape = list(inputs['input_data'].shape)
for x in range(len(attrs[0]["axes"])):
start = 0
end = 0
if attrs[0]["starts"][x] < 0:
start = (
attrs[0]["starts"][x]
+ inputs['input_data'].shape[attrs[0]["axes"][x]]
)
else:
start = attrs[0]["starts"][x]
if attrs[0]["ends"][x] < 0:
end = (
attrs[0]["ends"][x]
+ inputs['input_data'].shape[attrs[0]["axes"][x]]
)
else:
end = attrs[0]["ends"][x]
start = max(0, start)
end = max(0, end)
out_shape[attrs[0]["axes"][x]] = end - start
if start >= end:
return False
for x in attrs[0]["decrease_axis"]:
if x < 0:
return False
if out_shape[x] != 1:
return False
return True
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]):
return (10 * np.random.random([6, 6, 64, 64])).astype(np.float64)
for axes in [[0, 1], [1, 3], [2, 3]]:
for starts in [[0, 1]]:
for ends in [[2, 2], [5, 5], [1, -1]]:
for decrease_axis in [[], [1], [2], [-1], [-100]]:
for infer_flags in [[-1]]:
dics = [
{
"axes": axes,
"starts": starts,
"ends": ends,
"decrease_axis": decrease_axis,
"infer_flags": infer_flags,
}
]
ops_config = [
{
"op_type": "slice",
"op_inputs": {"Input": ["input_data"]},
"op_outputs": {
"Out": ["slice_output_data"]
},
"op_attrs": dics[0],
"outputs_dtype": {
"slice_output_data": np.float64
},
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input1, dics)
)
},
outputs=["slice_output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
self.dynamic_shape.max_input_shape = {"input_data": [8, 8, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [6, 6, 64, 64]}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-3
def test(self):
self.run_test()
class TrtFloat64Test2(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(shape, op_type):
return np.random.randint(low=1, high=10000, size=shape).astype(
np.float64
)
for shape in [[2, 32, 16], [1, 8, 16, 32]]:
for op_type in [
"elementwise_add",
"elementwise_mul",
"elementwise_sub",
]:
for axis in [0, -1]:
self.dims = len(shape)
dics = [{"axis": axis}]
ops_config = [
{
"op_type": op_type,
"op_inputs": {
"X": ["input_data1"],
"Y": ["input_data2"],
},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": dics[0],
"outputs_dtype": {"slice_output_data": np.float64},
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input, shape, op_type)
),
"input_data2": TensorConfig(
data_gen=partial(generate_input, shape, op_type)
),
},
outputs=["output_data"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 3:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 4, 4],
"input_data2": [1, 4, 4],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [128, 128, 256],
"input_data2": [128, 128, 256],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 32, 16],
"input_data2": [2, 32, 16],
}
elif self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 4, 4, 4],
"input_data2": [1, 4, 4, 4],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [8, 128, 64, 128],
"input_data2": [8, 128, 64, 128],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 64, 32, 32],
"input_data2": [2, 64, 32, 32],
}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 3
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 3), (1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), (1e-3, 1e-3)
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()