Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
2924c92a
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2924c92a
编写于
5月 11, 2018
作者:
Y
Yu Yang
提交者:
GitHub
5月 11, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10569 from reyoung/feature/matmul_support_float16_double
matmul support float16/double
上级
5ce2df9b
05a96db6
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
66 addition
and
41 deletion
+66
-41
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+15
-3
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+3
-3
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+4
-3
paddle/fluid/operators/matmul_op.cc
paddle/fluid/operators/matmul_op.cc
+44
-32
未找到文件。
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
2924c92a
...
@@ -96,10 +96,22 @@ struct CUBlas<platform::float16> {
...
@@ -96,10 +96,22 @@ struct CUBlas<platform::float16> {
reinterpret_cast
<
__half
*>
(
C
),
ldc
));
reinterpret_cast
<
__half
*>
(
C
),
ldc
));
}
}
template
<
typename
...
ARGS
>
static
void
GEMM_BATCH
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
static
void
GEMM_BATCH
(
ARGS
...
args
)
{
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
float16
*
alpha
,
const
float16
*
A
,
int
lda
,
long
long
int
strideA
,
const
float16
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
float16
*
beta
,
float16
*
C
,
int
ldc
,
long
long
int
strideC
,
// NOLINT
int
batchCount
)
{
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemmStridedBatched
(
args
...));
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasHgemmStridedBatched
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
reinterpret_cast
<
const
__half
*>
(
alpha
),
reinterpret_cast
<
const
__half
*>
(
A
),
lda
,
strideA
,
reinterpret_cast
<
const
__half
*>
(
B
),
ldb
,
strideB
,
reinterpret_cast
<
const
__half
*>
(
beta
),
reinterpret_cast
<
__half
*>
(
C
),
ldc
,
strideC
,
batchCount
));
#else
#else
PADDLE_THROW
(
"HgemmStridedBatched is not supported on cuda <= 7.5"
);
PADDLE_THROW
(
"HgemmStridedBatched is not supported on cuda <= 7.5"
);
#endif
#endif
...
...
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
2924c92a
...
@@ -172,9 +172,9 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
...
@@ -172,9 +172,9 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
c_array
.
data
(),
&
ldc
,
1
/* group_count */
,
&
batchCount
);
c_array
.
data
(),
&
ldc
,
1
/* group_count */
,
&
batchCount
);
#else
#else
for
(
int
k
=
0
;
k
<
batchCount
;
++
k
)
{
for
(
int
k
=
0
;
k
<
batchCount
;
++
k
)
{
const
float
*
Ak
=
&
A
[
k
*
strideA
];
auto
*
Ak
=
&
A
[
k
*
strideA
];
const
float
*
Bk
=
&
B
[
k
*
strideB
];
auto
*
Bk
=
&
B
[
k
*
strideB
];
float
*
Ck
=
&
C
[
k
*
M
*
N
];
auto
*
Ck
=
&
C
[
k
*
M
*
N
];
this
->
template
GEMM
<
T
>(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
Ak
,
Bk
,
beta
,
Ck
);
this
->
template
GEMM
<
T
>(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
Ak
,
Bk
,
beta
,
Ck
);
}
}
#endif
#endif
...
...
paddle/fluid/operators/math/math_function.cu
浏览文件 @
2924c92a
...
@@ -35,7 +35,8 @@ template struct SetConstant<platform::CUDADeviceContext, bool>;
...
@@ -35,7 +35,8 @@ template struct SetConstant<platform::CUDADeviceContext, bool>;
#define DEFINE_GPU_TRANS(RANK) \
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>;
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>;
DEFINE_GPU_TRANS
(
1
);
DEFINE_GPU_TRANS
(
1
);
DEFINE_GPU_TRANS
(
2
);
DEFINE_GPU_TRANS
(
2
);
...
...
paddle/fluid/operators/matmul_op.cc
浏览文件 @
2924c92a
...
@@ -25,7 +25,7 @@ namespace operators {
...
@@ -25,7 +25,7 @@ namespace operators {
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* original x_dim is returned.
* original x_dim is returned.
*/
*/
static
framework
::
DDim
RowMatrixFromVector
(
const
framework
::
DDim
&
x_dim
)
{
static
framework
::
DDim
RowMatrixFromVector
(
const
framework
::
DDim
&
x_dim
)
{
if
(
x_dim
.
size
()
>
1
)
{
if
(
x_dim
.
size
()
>
1
)
{
return
x_dim
;
return
x_dim
;
}
}
...
@@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
...
@@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
* original y_dim is returned.
* original y_dim is returned.
*/
*/
static
framework
::
DDim
ColumnMatrixFromVector
(
const
framework
::
DDim
&
y_dim
)
{
static
framework
::
DDim
ColumnMatrixFromVector
(
const
framework
::
DDim
&
y_dim
)
{
if
(
y_dim
.
size
()
>
1
)
{
if
(
y_dim
.
size
()
>
1
)
{
return
y_dim
;
return
y_dim
;
}
}
...
@@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
...
@@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MatMulKernel
:
public
framework
::
OpKernel
<
T
>
{
class
MatMulKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
&
x
=
auto
&
x
=
detail
::
Ref
(
context
.
Input
<
framework
::
Tensor
>
(
"X"
),
"Cannot find X"
);
detail
::
Ref
(
context
.
Input
<
framework
::
Tensor
>
(
"X"
),
"Cannot find X"
);
auto
&
y
=
auto
&
y
=
detail
::
Ref
(
context
.
Input
<
framework
::
Tensor
>
(
"Y"
),
"Cannot find Y"
);
detail
::
Ref
(
context
.
Input
<
framework
::
Tensor
>
(
"Y"
),
"Cannot find Y"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
...
@@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> {
...
@@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> {
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
// Identity op if the tensor is not of rank 3.
static
framework
::
Tensor
FoldInitDims
(
const
framework
::
Tensor
&
input
)
{
static
framework
::
Tensor
FoldInitDims
(
const
framework
::
Tensor
&
input
)
{
auto
output
=
input
;
auto
output
=
input
;
auto
in_dims
=
input
.
dims
();
auto
in_dims
=
input
.
dims
();
if
(
in_dims
.
size
()
==
3
)
{
if
(
in_dims
.
size
()
==
3
)
{
...
@@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
...
@@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
// (Warning: This requires transposing data and writes into new memory.)
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
// Identity op if the tensor is not of rank 3.
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
static
framework
::
Tensor
FoldHeadAndLastDims
(
const
DeviceContext
&
context
,
static
framework
::
Tensor
FoldHeadAndLastDims
(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
)
{
const
framework
::
Tensor
&
input
)
{
auto
in_dims
=
input
.
dims
();
auto
in_dims
=
input
.
dims
();
if
(
in_dims
.
size
()
!=
3
)
{
if
(
in_dims
.
size
()
!=
3
)
{
return
input
;
return
input
;
...
@@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
...
@@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
* If transposed, `H,W` will be swapped.
* If transposed, `H,W` will be swapped.
*/
*/
static
void
ReshapeTensorIntoMatrixSequence
(
static
void
ReshapeTensorIntoMatrixSequence
(
framework
::
Tensor
*
x
,
const
math
::
MatDescriptor
&
descriptor
)
{
framework
::
Tensor
*
x
,
const
math
::
MatDescriptor
&
descriptor
)
{
int64_t
h
,
w
;
int64_t
h
,
w
;
h
=
descriptor
.
height_
;
h
=
descriptor
.
height_
;
w
=
descriptor
.
width_
;
w
=
descriptor
.
width_
;
...
@@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence(
...
@@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence(
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize.
* BatchSize.
*/
*/
static
void
ReshapeXYOutIntoMatrixSequence
(
framework
::
Tensor
*
x
,
static
void
ReshapeXYOutIntoMatrixSequence
(
framework
::
Tensor
*
x
,
framework
::
Tensor
*
y
,
framework
::
Tensor
*
y
,
framework
::
Tensor
*
out
,
bool
trans_x
,
framework
::
Tensor
*
out
,
bool
trans_x
,
bool
trans_y
)
{
bool
trans_y
)
{
auto
x_dim
=
RowMatrixFromVector
(
x
->
dims
());
auto
x_dim
=
RowMatrixFromVector
(
x
->
dims
());
auto
y_dim
=
ColumnMatrixFromVector
(
y
->
dims
());
auto
y_dim
=
ColumnMatrixFromVector
(
y
->
dims
());
...
@@ -177,10 +177,10 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
...
@@ -177,10 +177,10 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MatMulGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
MatMulGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
MatMul
(
const
framework
::
ExecutionContext
&
context
,
void
MatMul
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
a
,
bool
trans_a
,
const
framework
::
Tensor
&
a
,
bool
trans_a
,
const
framework
::
Tensor
&
b
,
bool
trans_b
,
const
framework
::
Tensor
&
b
,
bool
trans_b
,
framework
::
Tensor
*
out
)
const
{
framework
::
Tensor
*
out
)
const
{
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
auto
mat_dim_a
=
math
::
CreateMatrixDescriptor
(
a
.
dims
(),
0
,
trans_a
);
auto
mat_dim_a
=
math
::
CreateMatrixDescriptor
(
a
.
dims
(),
0
,
trans_a
);
...
@@ -188,18 +188,18 @@ class MatMulGradKernel : public framework::OpKernel<T> {
...
@@ -188,18 +188,18 @@ class MatMulGradKernel : public framework::OpKernel<T> {
blas
.
MatMul
(
a
,
mat_dim_a
,
b
,
mat_dim_b
,
T
(
1
),
out
,
T
(
0
));
blas
.
MatMul
(
a
,
mat_dim_a
,
b
,
mat_dim_b
,
T
(
1
),
out
,
T
(
0
));
}
}
void
CalcInputGrad
(
const
framework
::
ExecutionContext
&
context
,
void
CalcInputGrad
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
a
,
bool
trans_a
,
const
framework
::
Tensor
&
a
,
bool
trans_a
,
bool
is_fold_init_dims_a
,
const
framework
::
Tensor
&
b
,
bool
is_fold_init_dims_a
,
const
framework
::
Tensor
&
b
,
bool
trans_b
,
bool
is_fold_init_dims_b
,
bool
trans_b
,
bool
is_fold_init_dims_b
,
framework
::
Tensor
*
out
)
const
{
framework
::
Tensor
*
out
)
const
{
if
(
out
==
nullptr
)
return
;
if
(
out
==
nullptr
)
return
;
bool
need_combine
=
(
a
.
dims
().
size
()
==
3
||
b
.
dims
().
size
()
==
3
)
&&
bool
need_combine
=
(
a
.
dims
().
size
()
==
3
||
b
.
dims
().
size
()
==
3
)
&&
out
->
dims
().
size
()
==
2
;
out
->
dims
().
size
()
==
2
;
if
(
!
need_combine
)
{
if
(
!
need_combine
)
{
MatMul
(
context
,
a
,
trans_a
,
b
,
trans_b
,
out
);
MatMul
(
context
,
a
,
trans_a
,
b
,
trans_b
,
out
);
}
else
{
}
else
{
auto
&
ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
ctx
=
context
.
template
device_context
<
DeviceContext
>();
MatMul
(
context
,
is_fold_init_dims_a
MatMul
(
context
,
is_fold_init_dims_a
?
FoldInitDims
(
a
)
?
FoldInitDims
(
a
)
:
FoldHeadAndLastDims
<
DeviceContext
,
T
>
(
ctx
,
a
),
:
FoldHeadAndLastDims
<
DeviceContext
,
T
>
(
ctx
,
a
),
...
@@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
...
@@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
}
}
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
x
=
*
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
x
=
*
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
y
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
y
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
dout
=
auto
dout
=
*
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
*
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
dy
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
bool
transpose_x
=
context
.
Attr
<
bool
>
(
"transpose_X"
);
bool
transpose_x
=
context
.
Attr
<
bool
>
(
"transpose_X"
);
bool
transpose_y
=
context
.
Attr
<
bool
>
(
"transpose_Y"
);
bool
transpose_y
=
context
.
Attr
<
bool
>
(
"transpose_Y"
);
...
@@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel {
...
@@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
*
context
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
),
"Input(X) of MatMulOp should not be null."
);
"Input(X) of MatMulOp should not be null."
);
PADDLE_ENFORCE
(
context
->
HasInput
(
"Y"
),
PADDLE_ENFORCE
(
context
->
HasInput
(
"Y"
),
...
@@ -375,7 +375,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
...
@@ -375,7 +375,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
protected:
void
InferShape
(
framework
::
InferShapeContext
*
context
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
context
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
context
->
HasInput
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
context
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
context
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
...
@@ -401,7 +401,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -401,7 +401,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
retv
=
new
framework
::
OpDesc
();
auto
*
retv
=
new
framework
::
OpDesc
();
retv
->
SetType
(
"matmul_grad"
);
retv
->
SetType
(
"matmul_grad"
);
retv
->
SetInput
(
"X"
,
Input
(
"X"
));
retv
->
SetInput
(
"X"
,
Input
(
"X"
));
retv
->
SetInput
(
"Y"
,
Input
(
"Y"
));
retv
->
SetInput
(
"Y"
,
Input
(
"Y"
));
...
@@ -420,15 +420,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
...
@@ -420,15 +420,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
ops
::
MatMulOpGradMaker
);
ops
::
MatMulOpGradMaker
);
REGISTER_OPERATOR
(
matmul_grad
,
ops
::
MatMulOpGrad
);
REGISTER_OPERATOR
(
matmul_grad
,
ops
::
MatMulOpGrad
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
matmul
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
matmul
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
float16
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
matmul_grad
,
matmul_grad
,
ops
::
MatMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
MatMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MatMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
MatMulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
float16
>
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
matmul
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
matmul
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
MatMulKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
matmul_grad
,
matmul_grad
,
ops
::
MatMulGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
MatMulGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MatMulGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
MatMulGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
);
#endif
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录