Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6fc74bba
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6fc74bba
编写于
9月 25, 2020
作者:
S
ShenLiang
提交者:
GitHub
9月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fp16 for matmul (#27523)
* add fp16 for matmul
上级
fab4e6d0
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
142 addition
and
55 deletion
+142
-55
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+29
-0
paddle/fluid/operators/matmul_v2_op.cu
paddle/fluid/operators/matmul_v2_op.cu
+6
-4
paddle/fluid/operators/matmul_v2_op.h
paddle/fluid/operators/matmul_v2_op.h
+31
-24
python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
+74
-25
python/paddle/tensor/linalg.py
python/paddle/tensor/linalg.py
+2
-2
未找到文件。
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
6fc74bba
...
...
@@ -420,6 +420,22 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
});
}
template
<
>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMV
(
bool
trans_a
,
int
M
,
int
N
,
platform
::
float16
alpha
,
const
platform
::
float16
*
A
,
const
platform
::
float16
*
B
,
platform
::
float16
beta
,
platform
::
float16
*
C
)
const
{
// Because cublas doesn't support half gemv, we use cublasHgemm to achieve it.
if
(
trans_a
)
{
this
->
template
GEMM
<
platform
::
float16
>(
CblasNoTrans
,
CblasNoTrans
,
1
,
N
,
M
,
alpha
,
B
,
A
,
beta
,
C
);
}
else
{
this
->
template
GEMM
<
platform
::
float16
>(
CblasNoTrans
,
CblasNoTrans
,
M
,
1
,
N
,
alpha
,
A
,
B
,
beta
,
C
);
}
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
BatchedGEMM
(
...
...
@@ -479,6 +495,19 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
}
}
template
<
>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
BatchedGEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
platform
::
float16
alpha
,
const
platform
::
float16
**
A
,
const
platform
::
float16
**
B
,
platform
::
float16
beta
,
platform
::
float16
**
C
,
int
batchCount
)
const
{
for
(
int
k
=
0
;
k
<
batchCount
;
++
k
)
{
this
->
template
GEMM
<
platform
::
float16
>(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
[
k
],
B
[
k
],
beta
,
C
[
k
]);
}
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
TRSM
(
CBLAS_SIDE
side
,
CBLAS_UPLO
uplo
,
...
...
paddle/fluid/operators/matmul_v2_op.cu
浏览文件 @
6fc74bba
...
...
@@ -17,10 +17,12 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
namespace
plf
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
matmul_v2
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
float
>
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
matmul_v2
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
float
>
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
double
>
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
plf
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
matmul_v2_grad
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
float
>
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
double
>
);
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
double
>
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
plf
::
float16
>
);
paddle/fluid/operators/matmul_v2_op.h
浏览文件 @
6fc74bba
...
...
@@ -163,17 +163,20 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if
(
trans_y
)
{
const
int
M
=
Y
->
numel
()
/
N
;
VLOG
(
3
)
<<
"MatMul's case 2"
;
blas
.
GEMV
(
false
,
M
,
N
,
1.
,
y_data
,
x_data
,
0.
,
Out
->
data
<
T
>
());
blas
.
GEMV
(
false
,
M
,
N
,
static_cast
<
T
>
(
1
),
y_data
,
x_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
());
}
else
{
const
int
M
=
y_dims
[
y_ndim
-
1
];
const
int
batch_size
=
Y
->
numel
()
/
(
M
*
N
);
if
(
batch_size
==
1
)
{
VLOG
(
3
)
<<
"MatMul's case 3"
;
blas
.
GEMV
(
true
,
N
,
M
,
1.
,
y_data
,
x_data
,
0.
,
Out
->
data
<
T
>
());
blas
.
GEMV
(
true
,
N
,
M
,
static_cast
<
T
>
(
1
),
y_data
,
x_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
());
}
else
{
VLOG
(
3
)
<<
"MatMul's case 4"
;
blas
.
BatchedGEMM
(
CblasTrans
,
CblasNoTrans
,
M
,
1
,
N
,
1.0
f
,
y_data
,
x_data
,
0
,
Out
->
data
<
T
>
(),
batch_size
,
M
*
N
,
0
);
blas
.
BatchedGEMM
(
CblasTrans
,
CblasNoTrans
,
M
,
1
,
N
,
static_cast
<
T
>
(
1
),
y_data
,
x_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
(),
batch_size
,
M
*
N
,
0
);
}
}
return
;
...
...
@@ -205,16 +208,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
const
int
batch_size
=
X
->
numel
()
/
(
M
*
N
);
if
(
batch_size
==
1
)
{
VLOG
(
3
)
<<
"MatMul's case 5"
;
blas
.
GEMV
(
true
,
N
,
M
,
1.0
f
,
x_data
,
y_data
,
0.0
f
,
Out
->
data
<
T
>
());
blas
.
GEMV
(
true
,
N
,
M
,
static_cast
<
T
>
(
1
),
x_data
,
y_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
());
}
else
{
VLOG
(
3
)
<<
"MatMul's case 6"
;
blas
.
BatchedGEMM
(
CblasTrans
,
CblasNoTrans
,
M
,
1
,
N
,
1.0
f
,
x_data
,
y_data
,
0
,
Out
->
data
<
T
>
(),
batch_size
,
M
*
N
,
0
);
blas
.
BatchedGEMM
(
CblasTrans
,
CblasNoTrans
,
M
,
1
,
N
,
static_cast
<
T
>
(
1
),
x_data
,
y_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
(),
batch_size
,
M
*
N
,
0
);
}
}
else
{
const
int
M
=
X
->
numel
()
/
N
;
VLOG
(
3
)
<<
"MatMul's case 7"
;
blas
.
GEMV
(
false
,
M
,
N
,
1.0
f
,
x_data
,
y_data
,
0.0
f
,
Out
->
data
<
T
>
());
blas
.
GEMV
(
false
,
M
,
N
,
static_cast
<
T
>
(
1
),
x_data
,
y_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
());
}
return
;
}
...
...
@@ -263,37 +269,38 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if
(
x_batch_size
==
1
&&
y_batch_size
==
1
)
{
VLOG
(
3
)
<<
"MatMul's case 8"
;
blas
.
GEMM
(
trans_x
?
CblasTrans
:
CblasNoTrans
,
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
x_data
,
y_data
,
0.0
f
,
Out
->
data
<
T
>
());
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
static_cast
<
T
>
(
1
)
,
x_data
,
y_data
,
static_cast
<
T
>
(
0
)
,
Out
->
data
<
T
>
());
}
else
if
(
x_batch_size
==
1
)
{
if
(
M
==
1
&&
trans_y
)
{
VLOG
(
3
)
<<
"MatMul's case 9"
;
blas
.
GEMV
(
false
,
y_batch_size
*
N
,
K
,
1.0
f
,
y_data
,
x_data
,
0.0
f
,
Out
->
data
<
T
>
());
blas
.
GEMV
(
false
,
y_batch_size
*
N
,
K
,
static_cast
<
T
>
(
1
),
y_data
,
x_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
());
}
else
{
VLOG
(
3
)
<<
"MatMul's case 10"
;
blas
.
BatchedGEMM
(
trans_x
?
CblasTrans
:
CblasNoTrans
,
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
x_data
,
y_data
,
0
,
Out
->
data
<
T
>
(),
out_batch_size
,
0
,
K
*
N
);
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
static_cast
<
T
>
(
1
),
x_data
,
y_data
,
static_cast
<
T
>
(
0
)
,
Out
->
data
<
T
>
(),
out_batch_size
,
0
,
K
*
N
);
}
}
else
if
(
y_batch_size
==
1
)
{
if
(
!
trans_x
)
{
VLOG
(
3
)
<<
"MatMul's case 11"
;
blas
.
GEMM
(
CblasNoTrans
,
trans_y
?
CblasTrans
:
CblasNoTrans
,
x_batch_size
*
M
,
N
,
K
,
1.0
f
,
x_data
,
y_data
,
0.0
f
,
Out
->
data
<
T
>
());
x_batch_size
*
M
,
N
,
K
,
static_cast
<
T
>
(
1
),
x_data
,
y_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
());
}
else
{
VLOG
(
3
)
<<
"MatMul's case 12"
;
blas
.
BatchedGEMM
(
CblasTrans
,
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
x_data
,
y_data
,
0
,
Out
->
data
<
T
>
(),
out_batch_size
,
M
*
K
,
0
);
static_cast
<
T
>
(
1
),
x_data
,
y_data
,
static_cast
<
T
>
(
0
)
,
Out
->
data
<
T
>
(),
out_batch_size
,
M
*
K
,
0
);
}
}
else
if
(
!
is_broadcast_dims
)
{
VLOG
(
3
)
<<
"MatMul's case 13"
;
blas
.
BatchedGEMM
(
trans_x
?
CblasTrans
:
CblasNoTrans
,
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
x_data
,
y_data
,
0
,
Out
->
data
<
T
>
(),
out_batch_size
,
M
*
K
,
K
*
N
);
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
static_cast
<
T
>
(
1
),
x_data
,
y_data
,
static_cast
<
T
>
(
0
),
Out
->
data
<
T
>
(),
out_batch_size
,
M
*
K
,
K
*
N
);
}
else
{
// in the case, can't use stridedgemm
std
::
vector
<
const
T
*>
x_ptr
(
out_batch_size
);
...
...
@@ -314,9 +321,9 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
}
VLOG
(
3
)
<<
"MatMul's case 14"
;
blas
.
BatchedGEMM
(
trans_x
?
CblasTrans
:
CblasNoTrans
,
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
1.0
f
,
x_ptr
.
data
(),
y_ptr
.
data
(),
0.0
f
,
out
_ptr
.
data
(),
out_batch_size
);
trans_y
?
CblasTrans
:
CblasNoTrans
,
M
,
N
,
K
,
static_cast
<
T
>
(
1
),
x_ptr
.
data
(),
y
_ptr
.
data
(),
static_cast
<
T
>
(
0
),
out_ptr
.
data
(),
out_batch_size
);
}
}
...
...
python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
浏览文件 @
6fc74bba
...
...
@@ -65,15 +65,21 @@ class TestMatMulV2Op(OpTest):
self
.
y_shape
=
(
100
,
)
self
.
trans_x
=
False
self
.
trans_y
=
False
def
init_kernel_type
(
self
):
self
.
dtype
=
"float64"
def
setUp
(
self
):
self
.
init_kernel_type
()
self
.
config
()
self
.
op_type
=
"matmul_v2"
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
dtype
)
y
=
np
.
random
.
random
(
self
.
y_shape
).
astype
(
self
.
dtype
)
# -0.1 ~ 0.1
x
=
-
0.1
+
0.2
*
x
y
=
-
0.1
+
0.2
*
y
result
=
reference_matmul
(
x
,
y
,
self
.
trans_x
,
self
.
trans_y
)
result
=
result
.
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
,
...
...
@@ -98,7 +104,6 @@ class TestMatMuklOp2(TestMatMulV2Op):
self
.
y_shape
=
(
1
,
3
,
2
,
100
)
self
.
trans_x
=
False
self
.
trans_y
=
True
self
.
dtype
=
"float64"
class
TestMatMuklOp3
(
TestMatMulV2Op
):
...
...
@@ -111,7 +116,6 @@ class TestMatMuklOp3(TestMatMulV2Op):
self
.
y_shape
=
(
1
,
1
,
100
,
2
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp4
(
TestMatMulV2Op
):
...
...
@@ -124,7 +128,6 @@ class TestMatMuklOp4(TestMatMulV2Op):
self
.
y_shape
=
(
1
,
2
,
100
,
2
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp5
(
TestMatMulV2Op
):
...
...
@@ -133,11 +136,10 @@ class TestMatMuklOp5(TestMatMulV2Op):
"""
def
config
(
self
):
self
.
x_shape
=
(
1
,
1
,
100
,
2
)
self
.
x_shape
=
(
1
,
1
,
100
,
1
)
self
.
y_shape
=
(
100
,
)
self
.
trans_x
=
True
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp6
(
TestMatMulV2Op
):
...
...
@@ -150,7 +152,6 @@ class TestMatMuklOp6(TestMatMulV2Op):
self
.
y_shape
=
(
100
,
)
self
.
trans_x
=
True
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp7
(
TestMatMulV2Op
):
...
...
@@ -163,7 +164,6 @@ class TestMatMuklOp7(TestMatMulV2Op):
self
.
y_shape
=
(
100
,
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp8
(
TestMatMulV2Op
):
...
...
@@ -176,7 +176,6 @@ class TestMatMuklOp8(TestMatMulV2Op):
self
.
y_shape
=
(
1
,
1
,
100
,
2
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp9
(
TestMatMulV2Op
):
...
...
@@ -189,7 +188,6 @@ class TestMatMuklOp9(TestMatMulV2Op):
self
.
y_shape
=
(
2
,
1
,
2
,
100
)
self
.
trans_x
=
False
self
.
trans_y
=
True
self
.
dtype
=
"float64"
class
TestMatMuklOp10
(
TestMatMulV2Op
):
...
...
@@ -198,11 +196,10 @@ class TestMatMuklOp10(TestMatMulV2Op):
"""
def
config
(
self
):
self
.
x_shape
=
(
1
,
1
,
2
,
100
)
self
.
y_shape
=
(
1
,
2
,
100
,
2
)
self
.
x_shape
=
(
1
,
1
,
2
5
,
4
)
self
.
y_shape
=
(
1
,
2
,
4
,
25
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp11
(
TestMatMulV2Op
):
...
...
@@ -215,7 +212,6 @@ class TestMatMuklOp11(TestMatMulV2Op):
self
.
y_shape
=
(
1
,
1
,
100
,
2
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp12
(
TestMatMulV2Op
):
...
...
@@ -224,11 +220,10 @@ class TestMatMuklOp12(TestMatMulV2Op):
"""
def
config
(
self
):
self
.
x_shape
=
(
2
,
1
,
100
,
2
)
self
.
y_shape
=
(
1
,
1
,
100
,
2
)
self
.
x_shape
=
(
2
,
1
,
4
,
25
)
self
.
y_shape
=
(
1
,
1
,
4
,
25
)
self
.
trans_x
=
True
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp13
(
TestMatMulV2Op
):
...
...
@@ -237,11 +232,10 @@ class TestMatMuklOp13(TestMatMulV2Op):
"""
def
config
(
self
):
self
.
x_shape
=
(
2
,
2
,
100
,
2
)
self
.
y_shape
=
(
2
,
2
,
100
,
2
)
self
.
x_shape
=
(
2
,
2
,
2
,
50
)
self
.
y_shape
=
(
2
,
2
,
2
,
50
)
self
.
trans_x
=
True
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp14
(
TestMatMulV2Op
):
...
...
@@ -254,7 +248,6 @@ class TestMatMuklOp14(TestMatMulV2Op):
self
.
y_shape
=
(
1
,
2
,
2
,
100
,
2
)
self
.
trans_x
=
True
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp15
(
TestMatMulV2Op
):
...
...
@@ -267,7 +260,6 @@ class TestMatMuklOp15(TestMatMulV2Op):
self
.
y_shape
=
(
1
,
2
,
2
,
100
,
1
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp16
(
TestMatMulV2Op
):
...
...
@@ -277,10 +269,9 @@ class TestMatMuklOp16(TestMatMulV2Op):
def
config
(
self
):
self
.
x_shape
=
(
100
)
self
.
y_shape
=
(
1
,
2
,
2
,
100
,
1
)
self
.
y_shape
=
(
1
,
2
,
2
,
100
,
2
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
class
TestMatMuklOp17
(
TestMatMulV2Op
):
...
...
@@ -293,7 +284,54 @@ class TestMatMuklOp17(TestMatMulV2Op):
self
.
y_shape
=
(
100
)
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
dtype
=
"float64"
#--------------------test matmul fp16--------------------
def
create_test_fp16_class
(
parent
,
atol
=
0.001
,
max_relative_error
=
1.0
):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestMatMulOpFp16Case
(
parent
):
def
init_kernel_type
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
atol
)
def
test_check_grad
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_grad_with_place
(
place
,
[
'X'
,
'Y'
],
'Out'
,
max_relative_error
=
max_relative_error
)
cls_name
=
"{0}_{1}"
.
format
(
parent
.
__name__
,
"Fp16"
)
TestMatMulOpFp16Case
.
__name__
=
cls_name
globals
()[
cls_name
]
=
TestMatMulOpFp16Case
create_test_fp16_class
(
TestMatMulV2Op
)
create_test_fp16_class
(
TestMatMuklOp2
)
create_test_fp16_class
(
TestMatMuklOp3
)
create_test_fp16_class
(
TestMatMuklOp4
)
create_test_fp16_class
(
TestMatMuklOp5
)
create_test_fp16_class
(
TestMatMuklOp6
)
create_test_fp16_class
(
TestMatMuklOp7
)
create_test_fp16_class
(
TestMatMuklOp8
)
create_test_fp16_class
(
TestMatMuklOp9
)
create_test_fp16_class
(
TestMatMuklOp10
)
create_test_fp16_class
(
TestMatMuklOp11
)
create_test_fp16_class
(
TestMatMuklOp12
)
create_test_fp16_class
(
TestMatMuklOp13
)
create_test_fp16_class
(
TestMatMuklOp14
)
create_test_fp16_class
(
TestMatMuklOp15
)
create_test_fp16_class
(
TestMatMuklOp16
)
create_test_fp16_class
(
TestMatMuklOp17
)
class
TestMatMulV2API
(
unittest
.
TestCase
):
...
...
@@ -331,6 +369,17 @@ class TestMatMulV2API(unittest.TestCase):
y
=
paddle
.
to_tensor
(
input_y
)
result
=
paddle
.
matmul
(
x
,
y
)
def
test_dygraph_fp16
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
with
fluid
.
dygraph
.
guard
(
place
):
input_x
=
np
.
random
.
random
([
4
,
3
]).
astype
(
"float16"
)
input_y
=
np
.
random
.
random
([
3
,
4
]).
astype
(
"float16"
)
x
=
paddle
.
to_tensor
(
input_x
)
y
=
paddle
.
to_tensor
(
input_y
)
result
=
paddle
.
matmul
(
x
,
y
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/tensor/linalg.py
浏览文件 @
6fc74bba
...
...
@@ -156,8 +156,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
def
__check_input
(
x
,
y
):
var_names
=
{
'x'
:
x
,
'y'
:
y
}
for
name
,
val
in
var_names
.
items
():
check_variable_and_dtype
(
val
,
name
,
[
'float32'
,
'float64'
],
'matmul'
)
check_variable_and_dtype
(
val
,
name
,
[
'float16'
,
'float32'
,
'float64'
],
'matmul'
)
__check_input
(
x
,
y
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录