Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4cdeab7b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4cdeab7b
编写于
12月 21, 2022
作者:
Y
Yuanle Liu
提交者:
GitHub
12月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix get trt weight (#49197)
上级
7f0eb2e3
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
46 addition
and
14 deletion
+46
-14
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+46
-14
未找到文件。
paddle/fluid/inference/tensorrt/engine.cc
浏览文件 @
4cdeab7b
...
...
@@ -582,10 +582,8 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
TensorRTEngine
::
Weight
weight
;
weight
.
SetCount
(
weight_tensor
.
numel
());
weight
.
SetDataType
(
nvinfer1
::
DataType
::
kHALF
);
// weight_tensor.dims().;
// if trt not support dtype, we need to cast to
fp16.
// if trt not support dtype, we need to cast to fp16.
if
(
weight_tensor
.
dtype
()
==
phi
::
DataType
::
BFLOAT16
)
{
phi
::
DenseTensor
bf16_tensor
;
bf16_tensor
.
clear
();
...
...
@@ -593,13 +591,14 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
weight_tensor
,
platform
::
CPUPlace
(),
&
bf16_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT16
);
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
.
dims
());
auto
*
fp16_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
float16
>
(
platform
::
CPUPlace
());
auto
*
bf16_data
=
bf16_tensor
.
mutable_data
<
bfloat16
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
fp16_data
[
i
]
=
static_cast
<
float16
>
(
bf16_data
[
i
]);
}
weight
.
SetDataType
(
phi
::
DataType
::
FLOAT16
);
weight
.
SetValues
(
fp16_data
);
}
else
if
(
weight_tensor
.
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
phi
::
DenseTensor
fp32_tensor
;
fp32_tensor
.
clear
();
...
...
@@ -607,18 +606,35 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
weight_tensor
,
platform
::
CPUPlace
(),
&
fp32_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT16
);
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
.
dims
());
auto
*
fp16_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
float16
>
(
platform
::
CPUPlace
());
auto
*
fp32_data
=
fp32_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
fp16_data
[
i
]
=
static_cast
<
float16
>
(
fp32_data
[
i
]);
}
weight
.
SetDataType
(
phi
::
DataType
::
FLOAT16
);
weight
.
SetValues
(
fp16_data
);
}
else
if
(
weight_tensor
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
phi
::
DenseTensor
int64_tensor
;
int64_tensor
.
clear
();
paddle
::
framework
::
TensorCopySync
(
weight_tensor
,
platform
::
CPUPlace
(),
&
int64_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
INT32
);
auto
*
int32_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
int32_t
>
(
platform
::
CPUPlace
());
auto
*
int64_data
=
int64_tensor
.
mutable_data
<
int64_t
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
int32_data
[
i
]
=
int64_data
[
i
];
}
weight
.
SetDataType
(
phi
::
DataType
::
INT32
);
weight
.
SetValues
(
int32_data
);
}
else
{
paddle
::
framework
::
TensorCopySync
(
weight_tensor
,
cpu_place
,
weight_map
[
name_with_suffix
].
get
());
weight
.
SetDataType
(
weight_tensor
.
dtype
());
weight
.
SetValues
(
weight_map
[
name_with_suffix
]
->
data
());
}
weight
.
SetValues
(
weight_map
[
name_with_suffix
]
->
data
());
name_suffix_counter
+=
1
;
return
weight
;
}
...
...
@@ -642,10 +658,8 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
TensorRTEngine
::
Weight
weight
;
weight
.
SetCount
(
weight_tensor
.
numel
());
weight
.
SetDataType
(
nvinfer1
::
DataType
::
kFLOAT
);
// weight_tensor.dims().;
// if trt not support dtype, we need to cast to
fp32.
// if trt not support dtype, we need to cast to fp32.
if
(
weight_tensor
.
dtype
()
==
phi
::
DataType
::
BFLOAT16
)
{
phi
::
DenseTensor
bf16_tensor
;
bf16_tensor
.
clear
();
...
...
@@ -653,13 +667,14 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
weight_tensor
,
platform
::
CPUPlace
(),
&
bf16_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT32
);
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
.
dims
());
auto
*
fp32_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
bf16_data
=
bf16_tensor
.
mutable_data
<
bfloat16
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
fp32_data
[
i
]
=
static_cast
<
float
>
(
bf16_data
[
i
]);
}
weight
.
SetDataType
(
phi
::
DataType
::
FLOAT32
);
weight
.
SetValues
(
fp32_data
);
}
else
if
(
weight_tensor
.
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
phi
::
DenseTensor
fp16_tensor
;
fp16_tensor
.
clear
();
...
...
@@ -667,18 +682,35 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
weight_tensor
,
platform
::
CPUPlace
(),
&
fp16_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
FLOAT32
);
weight_map
[
name_with_suffix
]
->
Resize
(
weight_tensor
.
dims
());
auto
*
fp32_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
*
fp16_data
=
fp16_tensor
.
mutable_data
<
float16
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
fp32_data
[
i
]
=
static_cast
<
float
>
(
fp16_data
[
i
]);
}
weight
.
SetDataType
(
phi
::
DataType
::
FLOAT32
);
weight
.
SetValues
(
fp32_data
);
}
else
if
(
weight_tensor
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
phi
::
DenseTensor
int64_tensor
;
int64_tensor
.
clear
();
paddle
::
framework
::
TensorCopySync
(
weight_tensor
,
platform
::
CPUPlace
(),
&
int64_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
INT32
);
auto
*
int32_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
int32_t
>
(
platform
::
CPUPlace
());
auto
*
int64_data
=
int64_tensor
.
mutable_data
<
int64_t
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
int32_data
[
i
]
=
int64_data
[
i
];
}
weight
.
SetDataType
(
phi
::
DataType
::
INT32
);
weight
.
SetValues
(
int32_data
);
}
else
{
paddle
::
framework
::
TensorCopySync
(
weight_tensor
,
cpu_place
,
weight_map
[
name_with_suffix
].
get
());
weight
.
SetDataType
(
weight_tensor
.
dtype
());
weight
.
SetValues
(
weight_map
[
name_with_suffix
]
->
data
());
}
weight
.
SetValues
(
weight_map
[
name_with_suffix
]
->
data
());
name_suffix_counter
+=
1
;
return
weight
;
}
...
...
@@ -729,8 +761,8 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
weight_tensor
,
platform
::
CPUPlace
(),
&
int64_tensor
);
weight_map
[
name_with_suffix
]
->
set_type
(
paddle
::
experimental
::
DataType
::
INT32
);
auto
*
int32_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
auto
*
int32_data
=
weight_map
[
name_with_suffix
]
->
mutable_data
<
int32_t
>
(
platform
::
CPUPlace
());
auto
*
int64_data
=
int64_tensor
.
mutable_data
<
int64_t
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
.
numel
();
i
++
)
{
int32_data
[
i
]
=
int64_data
[
i
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录