Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0fd70d71
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0fd70d71
编写于
11月 25, 2021
作者:
W
Wangzheee
提交者:
GitHub
11月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix_matmul_op_int8_plugin (#37525)
上级
2a905f6b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
4 addition
and
4 deletion
+4
-4
paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.cu
.../fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.cu
+4
-4
未找到文件。
paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.cu
浏览文件 @
0fd70d71
...
@@ -299,13 +299,13 @@ void MatmulPlugin::configurePlugin(const nvinfer1::PluginTensorDesc* inputs,
...
@@ -299,13 +299,13 @@ void MatmulPlugin::configurePlugin(const nvinfer1::PluginTensorDesc* inputs,
matmulDesc_
,
CUBLASLT_MATMUL_DESC_POINTER_MODE
,
&
matmul_model
,
matmulDesc_
,
CUBLASLT_MATMUL_DESC_POINTER_MODE
,
&
matmul_model
,
sizeof
(
matmul_model
)));
sizeof
(
matmul_model
)));
float
alpha_tem
[
n_
]
;
std
::
vector
<
float
>
alpha_tem
(
n_
,
0
)
;
for
(
int
i
=
0
;
i
<
n_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
n_
;
i
++
)
{
alpha_tem
[
i
]
=
alpha_
*
inscale_0
*
inscale_1
/
outscale
;
alpha_tem
[
i
]
=
alpha_
*
inscale_0
*
inscale_1
/
outscale
;
}
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMalloc
((
void
**
)
&
alpha_scale_
,
n_
*
sizeof
(
float
)));
cudaMalloc
((
void
**
)
&
alpha_scale_
,
n_
*
sizeof
(
float
)));
cudaMemcpyAsync
(
alpha_scale_
,
alpha_tem
,
n_
*
sizeof
(
float
),
cudaMemcpyAsync
(
alpha_scale_
,
&
alpha_tem
[
0
]
,
n_
*
sizeof
(
float
),
cudaMemcpyHostToDevice
);
cudaMemcpyHostToDevice
);
float
zero_tem
=
zero
;
float
zero_tem
=
zero
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
...
@@ -624,13 +624,13 @@ void MatmulPluginDynamic::configurePlugin(
...
@@ -624,13 +624,13 @@ void MatmulPluginDynamic::configurePlugin(
sizeof
(
int8_t
)
*
((
m_max
+
32
-
1
)
/
32
*
32
)
/
32
*
ldctransform
));
sizeof
(
int8_t
)
*
((
m_max
+
32
-
1
)
/
32
*
32
)
/
32
*
ldctransform
));
if
(
type_
==
nvinfer1
::
DataType
::
kINT8
)
{
if
(
type_
==
nvinfer1
::
DataType
::
kINT8
)
{
float
alpha_tem
[
n_max
]
;
std
::
vector
<
float
>
alpha_tem
(
n_max
,
0
)
;
for
(
int
i
=
0
;
i
<
n_max
;
i
++
)
{
for
(
int
i
=
0
;
i
<
n_max
;
i
++
)
{
alpha_tem
[
i
]
=
alpha_
*
inscale_0
*
inscale_1
/
outscale
;
alpha_tem
[
i
]
=
alpha_
*
inscale_0
*
inscale_1
/
outscale
;
}
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMalloc
((
void
**
)
&
alpha_scale_
,
n_max
*
sizeof
(
float
)));
cudaMalloc
((
void
**
)
&
alpha_scale_
,
n_max
*
sizeof
(
float
)));
cudaMemcpyAsync
(
alpha_scale_
,
alpha_tem
,
n_max
*
sizeof
(
float
),
cudaMemcpyAsync
(
alpha_scale_
,
&
alpha_tem
[
0
]
,
n_max
*
sizeof
(
float
),
cudaMemcpyHostToDevice
);
cudaMemcpyHostToDevice
);
float
zero_tem
=
zero
;
float
zero_tem
=
zero
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录