Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6c07cd7e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6c07cd7e
编写于
5月 26, 2021
作者:
C
chentianyu03
提交者:
GitHub
5月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify matmul Op to complex template types (#33130)
* modify matmul Op to complex template types * remove complex64/128 head file
上级
8259d9bf
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
230 addition
and
207 deletion
+230
-207
paddle/fluid/imperative/gradient_accumulator.cc
paddle/fluid/imperative/gradient_accumulator.cc
+3
-4
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+58
-48
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+94
-92
paddle/fluid/operators/math/blas_impl.hip.h
paddle/fluid/operators/math/blas_impl.hip.h
+52
-42
paddle/fluid/operators/math/selected_rows_functor.cc
paddle/fluid/operators/math/selected_rows_functor.cc
+5
-3
paddle/fluid/operators/matmul_v2_op.cc
paddle/fluid/operators/matmul_v2_op.cc
+4
-4
paddle/fluid/operators/matmul_v2_op.cu
paddle/fluid/operators/matmul_v2_op.cu
+4
-4
paddle/fluid/operators/matmul_v2_op.h
paddle/fluid/operators/matmul_v2_op.h
+10
-10
未找到文件。
paddle/fluid/imperative/gradient_accumulator.cc
浏览文件 @
6c07cd7e
...
...
@@ -24,8 +24,7 @@
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -200,8 +199,8 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
PADDLE_TENSOR_ADD
(
double
);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD
(
platform
::
complex
64
);
PADDLE_TENSOR_ADD
(
platform
::
complex
128
);
PADDLE_TENSOR_ADD
(
platform
::
complex
<
float
>
);
PADDLE_TENSOR_ADD
(
platform
::
complex
<
double
>
);
#endif
#undef PADDLE_TENSOR_ADD
...
...
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
6c07cd7e
...
...
@@ -260,13 +260,13 @@ struct CUBlas<platform::float16> {
};
template
<
>
struct
CUBlas
<
platform
::
complex64
>
{
using
complex64
=
platform
::
complex64
;
struct
CUBlas
<
platform
::
complex
<
float
>>
{
static
void
GEMV
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
int
m
,
int
n
,
const
complex64
*
alpha
,
const
complex64
*
A
,
int
lda
,
const
complex64
*
B
,
int
ldb
,
const
complex64
*
beta
,
complex64
*
C
,
int
ldc
)
{
int
n
,
const
platform
::
complex
<
float
>
*
alpha
,
const
platform
::
complex
<
float
>
*
A
,
int
lda
,
const
platform
::
complex
<
float
>
*
B
,
int
ldb
,
const
platform
::
complex
<
float
>
*
beta
,
platform
::
complex
<
float
>
*
C
,
int
ldc
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasCgemv
(
handle
,
transa
,
m
,
n
,
reinterpret_cast
<
const
cuFloatComplex
*>
(
alpha
),
reinterpret_cast
<
const
cuFloatComplex
*>
(
A
),
lda
,
...
...
@@ -275,9 +275,10 @@ struct CUBlas<platform::complex64> {
reinterpret_cast
<
cuFloatComplex
*>
(
C
),
ldc
));
}
static
void
AXPY
(
cublasHandle_t
handle
,
int
n
,
const
complex64
*
alpha
,
const
complex64
*
X
,
const
int
incX
,
complex64
*
Y
,
const
int
incY
)
{
static
void
AXPY
(
cublasHandle_t
handle
,
int
n
,
const
platform
::
complex
<
float
>
*
alpha
,
const
platform
::
complex
<
float
>
*
X
,
const
int
incX
,
platform
::
complex
<
float
>
*
Y
,
const
int
incY
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasCaxpy
(
handle
,
n
,
reinterpret_cast
<
const
cuFloatComplex
*>
(
alpha
),
reinterpret_cast
<
const
cuFloatComplex
*>
(
X
),
incX
,
...
...
@@ -287,11 +288,13 @@ struct CUBlas<platform::complex64> {
static
void
GEMM_STRIDED_BATCH
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
complex64
*
alpha
,
const
complex64
*
A
,
int
lda
,
long
long
int
strideA
,
// NOLINT
const
complex64
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
complex64
*
beta
,
complex64
*
C
,
int
ldc
,
const
platform
::
complex
<
float
>
*
alpha
,
const
platform
::
complex
<
float
>
*
A
,
int
lda
,
long
long
int
strideA
,
// NOLINT
const
platform
::
complex
<
float
>
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
platform
::
complex
<
float
>
*
beta
,
platform
::
complex
<
float
>
*
C
,
int
ldc
,
long
long
int
strideC
,
// NOLINT
int
batchCount
)
{
#if CUDA_VERSION >= 8000
...
...
@@ -310,9 +313,11 @@ struct CUBlas<platform::complex64> {
static
void
GEMM
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
complex64
*
alpha
,
const
complex64
*
A
,
int
lda
,
const
complex64
*
B
,
int
ldb
,
const
complex64
*
beta
,
complex64
*
C
,
int
ldc
)
{
const
platform
::
complex
<
float
>
*
alpha
,
const
platform
::
complex
<
float
>
*
A
,
int
lda
,
const
platform
::
complex
<
float
>
*
B
,
int
ldb
,
const
platform
::
complex
<
float
>
*
beta
,
platform
::
complex
<
float
>
*
C
,
int
ldc
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasCgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
reinterpret_cast
<
const
cuFloatComplex
*>
(
alpha
),
...
...
@@ -356,13 +361,13 @@ struct CUBlas<platform::complex64> {
};
template
<
>
struct
CUBlas
<
platform
::
complex128
>
{
using
complex128
=
platform
::
complex128
;
struct
CUBlas
<
platform
::
complex
<
double
>>
{
static
void
GEMV
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
int
m
,
int
n
,
const
complex128
*
alpha
,
const
complex128
*
A
,
int
lda
,
const
complex128
*
B
,
int
ldb
,
const
complex128
*
beta
,
complex128
*
C
,
int
ldc
)
{
int
n
,
const
platform
::
complex
<
double
>
*
alpha
,
const
platform
::
complex
<
double
>
*
A
,
int
lda
,
const
platform
::
complex
<
double
>
*
B
,
int
ldb
,
const
platform
::
complex
<
double
>
*
beta
,
platform
::
complex
<
double
>
*
C
,
int
ldc
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasZgemv
(
handle
,
transa
,
m
,
n
,
reinterpret_cast
<
const
cuDoubleComplex
*>
(
alpha
),
reinterpret_cast
<
const
cuDoubleComplex
*>
(
A
),
lda
,
...
...
@@ -371,9 +376,10 @@ struct CUBlas<platform::complex128> {
reinterpret_cast
<
cuDoubleComplex
*>
(
C
),
ldc
));
}
static
void
AXPY
(
cublasHandle_t
handle
,
int
n
,
const
complex128
*
alpha
,
const
complex128
*
X
,
const
int
incX
,
complex128
*
Y
,
const
int
incY
)
{
static
void
AXPY
(
cublasHandle_t
handle
,
int
n
,
const
platform
::
complex
<
double
>
*
alpha
,
const
platform
::
complex
<
double
>
*
X
,
const
int
incX
,
platform
::
complex
<
double
>
*
Y
,
const
int
incY
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasZaxpy
(
handle
,
n
,
reinterpret_cast
<
const
cuDoubleComplex
*>
(
alpha
),
reinterpret_cast
<
const
cuDoubleComplex
*>
(
X
),
incX
,
...
...
@@ -383,11 +389,13 @@ struct CUBlas<platform::complex128> {
static
void
GEMM_STRIDED_BATCH
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
complex128
*
alpha
,
const
complex128
*
A
,
int
lda
,
long
long
int
strideA
,
// NOLINT
const
complex128
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
complex128
*
beta
,
complex128
*
C
,
int
ldc
,
const
platform
::
complex
<
double
>
*
alpha
,
const
platform
::
complex
<
double
>
*
A
,
int
lda
,
long
long
int
strideA
,
// NOLINT
const
platform
::
complex
<
double
>
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
platform
::
complex
<
double
>
*
beta
,
platform
::
complex
<
double
>
*
C
,
int
ldc
,
long
long
int
strideC
,
// NOLINT
int
batchCount
)
{
#if CUDA_VERSION >= 8000
...
...
@@ -406,9 +414,11 @@ struct CUBlas<platform::complex128> {
static
void
GEMM
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
complex128
*
alpha
,
const
complex128
*
A
,
int
lda
,
const
complex128
*
B
,
int
ldb
,
const
complex128
*
beta
,
complex128
*
C
,
int
ldc
)
{
const
platform
::
complex
<
double
>
*
alpha
,
const
platform
::
complex
<
double
>
*
A
,
int
lda
,
const
platform
::
complex
<
double
>
*
B
,
int
ldb
,
const
platform
::
complex
<
double
>
*
beta
,
platform
::
complex
<
double
>
*
C
,
int
ldc
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasZgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
reinterpret_cast
<
const
cuDoubleComplex
*>
(
alpha
),
...
...
@@ -535,9 +545,9 @@ template <>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
platform
::
complex
64
alpha
,
const
platform
::
complex64
*
A
,
const
platform
::
complex
64
*
B
,
platform
::
complex64
beta
,
platform
::
complex
64
*
C
)
const
{
platform
::
complex
<
float
>
alpha
,
const
platform
::
complex
<
float
>
*
A
,
const
platform
::
complex
<
float
>
*
B
,
platform
::
complex
<
float
>
beta
,
platform
::
complex
<
float
>
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
...
...
@@ -565,16 +575,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
auto
&
cuda_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
CUBlas
<
platform
::
complex
64
>::
GEMM_EX
(
CUBlas
<
platform
::
complex
<
float
>
>::
GEMM_EX
(
&
cuda_ctx
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
c_alpha
,
B
,
CUDA_C_32F
,
ldb
,
A
,
CUDA_C_32F
,
lda
,
&
c_beta
,
C
,
CUDA_C_32F
,
N
,
CUDA_C_32F
);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
platform
::
complex
64
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
c_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
c_bet
a
,
h_C
,
N
);
CUBlas
<
platform
::
complex
<
float
>
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
c_alpha
,
h_B
,
ldb
,
h_A
,
ld
a
,
&
c_beta
,
h_C
,
N
);
});
#endif // CUDA_VERSION >= 8000
}
...
...
@@ -583,9 +593,9 @@ template <>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
platform
::
complex
128
alpha
,
const
platform
::
complex128
*
A
,
const
platform
::
complex
128
*
B
,
platform
::
complex128
beta
,
platform
::
complex
128
*
C
)
const
{
platform
::
complex
<
double
>
alpha
,
const
platform
::
complex
<
double
>
*
A
,
const
platform
::
complex
<
double
>
*
B
,
platform
::
complex
<
double
>
beta
,
platform
::
complex
<
double
>
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
...
...
@@ -614,16 +624,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
auto
&
cuda_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
CUBlas
<
platform
::
complex
128
>::
GEMM_EX
(
CUBlas
<
platform
::
complex
<
double
>
>::
GEMM_EX
(
&
cuda_ctx
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
c_alpha
,
B
,
CUDA_C_64F
,
ldb
,
A
,
CUDA_C_64F
,
lda
,
&
c_beta
,
C
,
CUDA_C_64F
,
N
,
CUDA_C_64F
);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
platform
::
complex
128
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
c_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
c_bet
a
,
h_C
,
N
);
CUBlas
<
platform
::
complex
<
double
>
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
c_alpha
,
h_B
,
ldb
,
h_A
,
ld
a
,
&
c_beta
,
h_C
,
N
);
});
#endif // CUDA_VERSION >= 8000
}
...
...
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
6c07cd7e
...
...
@@ -23,8 +23,7 @@
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -324,11 +323,11 @@ struct CBlas<double> {
};
template
<
>
struct
CBlas
<
platform
::
complex
64
>
{
struct
CBlas
<
platform
::
complex
<
float
>
>
{
template
<
typename
...
ARGS
>
static
void
AXPY
(
int
n
,
const
paddle
::
platform
::
complex
64
alpha
,
const
paddle
::
platform
::
complex
64
*
X
,
const
int
incX
,
paddle
::
platform
::
complex
64
*
Y
,
const
int
incY
)
{
static
void
AXPY
(
int
n
,
const
paddle
::
platform
::
complex
<
float
>
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
X
,
const
int
incX
,
paddle
::
platform
::
complex
<
float
>
*
Y
,
const
int
incY
)
{
platform
::
dynload
::
cblas_caxpy
(
n
,
&
alpha
,
X
,
incX
,
Y
,
incY
);
}
...
...
@@ -363,35 +362,35 @@ struct CBlas<platform::complex64> {
*/
template
<
typename
...
ARGS
>
static
void
VADD
(
int
n
,
const
paddle
::
platform
::
complex
64
*
a
,
const
paddle
::
platform
::
complex
64
*
b
,
paddle
::
platform
::
complex
64
*
y
)
{
static
void
VADD
(
int
n
,
const
paddle
::
platform
::
complex
<
float
>
*
a
,
const
paddle
::
platform
::
complex
<
float
>
*
b
,
paddle
::
platform
::
complex
<
float
>
*
y
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
i
]
+
b
[
i
];
}
}
template
<
typename
...
ARGS
>
static
void
VSUB
(
int
n
,
const
paddle
::
platform
::
complex
64
*
a
,
const
paddle
::
platform
::
complex
64
*
b
,
paddle
::
platform
::
complex
64
*
y
)
{
static
void
VSUB
(
int
n
,
const
paddle
::
platform
::
complex
<
float
>
*
a
,
const
paddle
::
platform
::
complex
<
float
>
*
b
,
paddle
::
platform
::
complex
<
float
>
*
y
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
i
]
-
b
[
i
];
}
}
template
<
typename
...
ARGS
>
static
void
VMUL
(
int
n
,
const
paddle
::
platform
::
complex
64
*
a
,
const
paddle
::
platform
::
complex
64
*
b
,
paddle
::
platform
::
complex
64
*
y
)
{
static
void
VMUL
(
int
n
,
const
paddle
::
platform
::
complex
<
float
>
*
a
,
const
paddle
::
platform
::
complex
<
float
>
*
b
,
paddle
::
platform
::
complex
<
float
>
*
y
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
i
]
*
b
[
i
];
}
}
template
<
typename
...
ARGS
>
static
void
VDIV
(
int
n
,
const
paddle
::
platform
::
complex
64
*
a
,
const
paddle
::
platform
::
complex
64
*
b
,
paddle
::
platform
::
complex
64
*
y
)
{
static
void
VDIV
(
int
n
,
const
paddle
::
platform
::
complex
<
float
>
*
a
,
const
paddle
::
platform
::
complex
<
float
>
*
b
,
paddle
::
platform
::
complex
<
float
>
*
y
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
i
]
/
b
[
i
];
}
...
...
@@ -399,11 +398,11 @@ struct CBlas<platform::complex64> {
template
<
typename
...
ARGS
>
static
void
GEMV
(
CBLAS_LAYOUT
layout
,
CBLAS_TRANSPOSE
trans
,
int
M
,
int
N
,
paddle
::
platform
::
complex
64
alpha
,
const
paddle
::
platform
::
complex
64
*
A
,
int
lda
,
const
paddle
::
platform
::
complex
64
*
X
,
int
incx
,
paddle
::
platform
::
complex
64
beta
,
paddle
::
platform
::
complex
64
*
Y
,
int
incy
)
{
paddle
::
platform
::
complex
<
float
>
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
A
,
int
lda
,
const
paddle
::
platform
::
complex
<
float
>
*
X
,
int
incx
,
paddle
::
platform
::
complex
<
float
>
beta
,
paddle
::
platform
::
complex
<
float
>
*
Y
,
int
incy
)
{
const
void
*
a_
=
(
const
void
*
)(
A
);
const
void
*
x_
=
(
const
void
*
)(
X
);
void
*
y_
=
static_cast
<
void
*>
(
Y
);
...
...
@@ -414,11 +413,11 @@ struct CBlas<platform::complex64> {
template
<
typename
...
ARGS
>
static
void
GEMM
(
CBLAS_LAYOUT
layout
,
CBLAS_TRANSPOSE
trans_a
,
CBLAS_TRANSPOSE
trans_b
,
int
M
,
int
N
,
int
K
,
paddle
::
platform
::
complex
64
alpha
,
const
paddle
::
platform
::
complex
64
*
A
,
int
lda
,
const
paddle
::
platform
::
complex
64
*
B
,
int
ldb
,
paddle
::
platform
::
complex
64
beta
,
paddle
::
platform
::
complex
64
*
C
,
int
ldc
)
{
paddle
::
platform
::
complex
<
float
>
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
A
,
int
lda
,
const
paddle
::
platform
::
complex
<
float
>
*
B
,
int
ldb
,
paddle
::
platform
::
complex
<
float
>
beta
,
paddle
::
platform
::
complex
<
float
>
*
C
,
int
ldc
)
{
const
void
*
a_
=
(
const
void
*
)(
A
);
const
void
*
b_
=
(
const
void
*
)(
B
);
void
*
c_
=
static_cast
<
void
*>
(
C
);
...
...
@@ -429,11 +428,12 @@ struct CBlas<platform::complex64> {
template
<
typename
...
ARGS
>
static
void
GEMM_BATCH
(
CBLAS_LAYOUT
layout
,
CBLAS_TRANSPOSE
*
trans_a
,
CBLAS_TRANSPOSE
*
trans_b
,
int
*
M
,
int
*
N
,
int
*
K
,
paddle
::
platform
::
complex64
*
alpha
,
const
paddle
::
platform
::
complex64
**
A
,
const
int
*
lda
,
const
paddle
::
platform
::
complex64
**
B
,
const
int
*
ldb
,
paddle
::
platform
::
complex64
*
beta
,
paddle
::
platform
::
complex64
**
C
,
const
int
*
ldc
,
paddle
::
platform
::
complex
<
float
>
*
alpha
,
const
paddle
::
platform
::
complex
<
float
>
**
A
,
const
int
*
lda
,
const
paddle
::
platform
::
complex
<
float
>
**
B
,
const
int
*
ldb
,
paddle
::
platform
::
complex
<
float
>
*
beta
,
paddle
::
platform
::
complex
<
float
>
**
C
,
const
int
*
ldc
,
int
group_count
,
int
*
group_size
)
{
const
void
**
A_void
=
(
const
void
**
)(
&
(
*
A
));
const
void
**
B_void
=
(
const
void
**
)(
&
(
*
B
));
...
...
@@ -451,11 +451,11 @@ struct CBlas<platform::complex64> {
};
template
<
>
struct
CBlas
<
platform
::
complex
128
>
{
struct
CBlas
<
platform
::
complex
<
double
>
>
{
template
<
typename
...
ARGS
>
static
void
AXPY
(
int
n
,
const
paddle
::
platform
::
complex
128
alpha
,
const
paddle
::
platform
::
complex
128
*
X
,
const
int
incX
,
paddle
::
platform
::
complex
128
*
Y
,
const
int
incY
)
{
static
void
AXPY
(
int
n
,
const
paddle
::
platform
::
complex
<
double
>
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
X
,
const
int
incX
,
paddle
::
platform
::
complex
<
double
>
*
Y
,
const
int
incY
)
{
platform
::
dynload
::
cblas_zaxpy
(
n
,
&
alpha
,
X
,
incX
,
Y
,
incY
);
}
...
...
@@ -490,35 +490,35 @@ struct CBlas<platform::complex128> {
*/
template
<
typename
...
ARGS
>
static
void
VADD
(
int
n
,
const
paddle
::
platform
::
complex
128
*
a
,
const
paddle
::
platform
::
complex
128
*
b
,
paddle
::
platform
::
complex
128
*
y
)
{
static
void
VADD
(
int
n
,
const
paddle
::
platform
::
complex
<
double
>
*
a
,
const
paddle
::
platform
::
complex
<
double
>
*
b
,
paddle
::
platform
::
complex
<
double
>
*
y
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
i
]
+
b
[
i
];
}
}
template
<
typename
...
ARGS
>
static
void
VSUB
(
int
n
,
const
paddle
::
platform
::
complex
128
*
a
,
const
paddle
::
platform
::
complex
128
*
b
,
paddle
::
platform
::
complex
128
*
y
)
{
static
void
VSUB
(
int
n
,
const
paddle
::
platform
::
complex
<
double
>
*
a
,
const
paddle
::
platform
::
complex
<
double
>
*
b
,
paddle
::
platform
::
complex
<
double
>
*
y
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
i
]
-
b
[
i
];
}
}
template
<
typename
...
ARGS
>
static
void
VMUL
(
int
n
,
const
paddle
::
platform
::
complex
128
*
a
,
const
paddle
::
platform
::
complex
128
*
b
,
paddle
::
platform
::
complex
128
*
y
)
{
static
void
VMUL
(
int
n
,
const
paddle
::
platform
::
complex
<
double
>
*
a
,
const
paddle
::
platform
::
complex
<
double
>
*
b
,
paddle
::
platform
::
complex
<
double
>
*
y
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
i
]
*
b
[
i
];
}
}
template
<
typename
...
ARGS
>
static
void
VDIV
(
int
n
,
const
paddle
::
platform
::
complex
128
*
a
,
const
paddle
::
platform
::
complex
128
*
b
,
paddle
::
platform
::
complex
128
*
y
)
{
static
void
VDIV
(
int
n
,
const
paddle
::
platform
::
complex
<
double
>
*
a
,
const
paddle
::
platform
::
complex
<
double
>
*
b
,
paddle
::
platform
::
complex
<
double
>
*
y
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
a
[
i
]
/
b
[
i
];
}
...
...
@@ -526,11 +526,11 @@ struct CBlas<platform::complex128> {
template
<
typename
...
ARGS
>
static
void
GEMV
(
CBLAS_LAYOUT
layout
,
CBLAS_TRANSPOSE
trans
,
int
M
,
int
N
,
paddle
::
platform
::
complex
128
alpha
,
const
paddle
::
platform
::
complex
128
*
A
,
int
lda
,
const
paddle
::
platform
::
complex
128
*
X
,
int
incx
,
paddle
::
platform
::
complex
128
beta
,
paddle
::
platform
::
complex
128
*
Y
,
int
incy
)
{
paddle
::
platform
::
complex
<
double
>
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
A
,
int
lda
,
const
paddle
::
platform
::
complex
<
double
>
*
X
,
int
incx
,
paddle
::
platform
::
complex
<
double
>
beta
,
paddle
::
platform
::
complex
<
double
>
*
Y
,
int
incy
)
{
const
void
*
a_
=
(
const
void
*
)(
A
);
const
void
*
x_
=
(
const
void
*
)(
X
);
void
*
y_
=
static_cast
<
void
*>
(
Y
);
...
...
@@ -541,11 +541,11 @@ struct CBlas<platform::complex128> {
template
<
typename
...
ARGS
>
static
void
GEMM
(
CBLAS_LAYOUT
layout
,
CBLAS_TRANSPOSE
trans_a
,
CBLAS_TRANSPOSE
trans_b
,
int
M
,
int
N
,
int
K
,
paddle
::
platform
::
complex
128
alpha
,
const
paddle
::
platform
::
complex
128
*
A
,
int
lda
,
const
paddle
::
platform
::
complex
128
*
B
,
int
ldb
,
paddle
::
platform
::
complex
128
beta
,
paddle
::
platform
::
complex
128
*
C
,
int
ldc
)
{
paddle
::
platform
::
complex
<
double
>
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
A
,
int
lda
,
const
paddle
::
platform
::
complex
<
double
>
*
B
,
int
ldb
,
paddle
::
platform
::
complex
<
double
>
beta
,
paddle
::
platform
::
complex
<
double
>
*
C
,
int
ldc
)
{
const
void
*
a_
=
(
const
void
*
)(
A
);
const
void
*
b_
=
(
const
void
*
)(
B
);
void
*
c_
=
static_cast
<
void
*>
(
C
);
...
...
@@ -556,11 +556,13 @@ struct CBlas<platform::complex128> {
template
<
typename
...
ARGS
>
static
void
GEMM_BATCH
(
CBLAS_LAYOUT
layout
,
CBLAS_TRANSPOSE
*
trans_a
,
CBLAS_TRANSPOSE
*
trans_b
,
int
*
M
,
int
*
N
,
int
*
K
,
paddle
::
platform
::
complex128
*
alpha
,
const
paddle
::
platform
::
complex128
**
A
,
const
int
*
lda
,
const
paddle
::
platform
::
complex128
**
B
,
const
int
*
ldb
,
paddle
::
platform
::
complex128
*
beta
,
paddle
::
platform
::
complex128
**
C
,
const
int
*
ldc
,
paddle
::
platform
::
complex
<
double
>
*
alpha
,
const
paddle
::
platform
::
complex
<
double
>
**
A
,
const
int
*
lda
,
const
paddle
::
platform
::
complex
<
double
>
**
B
,
const
int
*
ldb
,
paddle
::
platform
::
complex
<
double
>
*
beta
,
paddle
::
platform
::
complex
<
double
>
**
C
,
const
int
*
ldc
,
int
group_count
,
int
*
group_size
)
{
const
void
**
A_void
=
(
const
void
**
)(
&
(
*
A
));
const
void
**
B_void
=
(
const
void
**
)(
&
(
*
B
));
...
...
@@ -636,76 +638,76 @@ struct CBlas<double> {
};
template
<
>
struct
CBlas
<
platform
::
complex
64
>
{
struct
CBlas
<
platform
::
complex
<
float
>
>
{
template
<
typename
...
ARGS
>
static
void
VCOPY
(
ARGS
...
args
)
{
cblas_ccopy
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
AXPY
(
int
n
,
const
paddle
::
platform
::
complex
64
alpha
,
const
paddle
::
platform
::
complex
64
*
X
,
const
int
incX
,
paddle
::
platform
::
complex
64
*
Y
,
const
int
incY
)
{
static
void
AXPY
(
int
n
,
const
paddle
::
platform
::
complex
<
float
>
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
X
,
const
int
incX
,
paddle
::
platform
::
complex
<
float
>
*
Y
,
const
int
incY
)
{
cblas_caxpy
(
n
,
&
alpha
,
X
,
incX
,
Y
,
incY
);
}
template
<
typename
...
ARGS
>
static
void
GEMV
(
const
CBLAS_LAYOUT
layout
,
const
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
paddle
::
platform
::
complex
64
alpha
,
const
paddle
::
platform
::
complex
64
*
A
,
const
int
lda
,
const
paddle
::
platform
::
complex
64
*
X
,
const
int
incX
,
const
paddle
::
platform
::
complex
64
beta
,
paddle
::
platform
::
complex
64
*
Y
,
const
int
incY
)
{
const
paddle
::
platform
::
complex
<
float
>
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
A
,
const
int
lda
,
const
paddle
::
platform
::
complex
<
float
>
*
X
,
const
int
incX
,
const
paddle
::
platform
::
complex
<
float
>
beta
,
paddle
::
platform
::
complex
<
float
>
*
Y
,
const
int
incY
)
{
cblas_cgemv
(
layout
,
TransA
,
M
,
N
,
&
alpha
,
A
,
lda
,
X
,
incX
,
&
beta
,
Y
,
incY
);
}
template
<
typename
...
ARGS
>
static
void
GEMM
(
const
CBLAS_LAYOUT
layout
,
const
CBLAS_TRANSPOSE
TransA
,
const
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
paddle
::
platform
::
complex
64
alpha
,
const
paddle
::
platform
::
complex
64
*
A
,
const
int
lda
,
const
paddle
::
platform
::
complex
64
*
B
,
const
int
ldb
,
const
paddle
::
platform
::
complex
64
beta
,
paddle
::
platform
::
complex
64
*
C
,
const
int
ldc
)
{
const
int
K
,
const
paddle
::
platform
::
complex
<
float
>
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
A
,
const
int
lda
,
const
paddle
::
platform
::
complex
<
float
>
*
B
,
const
int
ldb
,
const
paddle
::
platform
::
complex
<
float
>
beta
,
paddle
::
platform
::
complex
<
float
>
*
C
,
const
int
ldc
)
{
cblas_cgemm
(
layout
,
TransA
,
TransB
,
M
,
N
,
K
,
&
alpha
,
A
,
lda
,
B
,
ldb
,
&
beta
,
C
,
ldc
);
}
};
template
<
>
struct
CBlas
<
platform
::
complex
128
>
{
struct
CBlas
<
platform
::
complex
<
double
>
>
{
template
<
typename
...
ARGS
>
static
void
VCOPY
(
ARGS
...
args
)
{
cblas_zcopy
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
AXPY
(
int
n
,
const
paddle
::
platform
::
complex
128
alpha
,
const
paddle
::
platform
::
complex
128
*
X
,
const
int
incX
,
paddle
::
platform
::
complex
128
*
Y
,
const
int
incY
)
{
static
void
AXPY
(
int
n
,
const
paddle
::
platform
::
complex
<
double
>
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
X
,
const
int
incX
,
paddle
::
platform
::
complex
<
double
>
*
Y
,
const
int
incY
)
{
cblas_zaxpy
(
n
,
&
alpha
,
X
,
incX
,
Y
,
incY
);
}
template
<
typename
...
ARGS
>
static
void
GEMV
(
const
CBLAS_LAYOUT
layout
,
const
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
paddle
::
platform
::
complex
128
alpha
,
const
paddle
::
platform
::
complex
128
*
A
,
const
int
lda
,
const
paddle
::
platform
::
complex
128
*
X
,
const
int
incX
,
const
paddle
::
platform
::
complex
128
beta
,
paddle
::
platform
::
complex
128
*
Y
,
const
int
incY
)
{
const
paddle
::
platform
::
complex
<
double
>
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
A
,
const
int
lda
,
const
paddle
::
platform
::
complex
<
double
>
*
X
,
const
int
incX
,
const
paddle
::
platform
::
complex
<
double
>
beta
,
paddle
::
platform
::
complex
<
double
>
*
Y
,
const
int
incY
)
{
cblas_zgemv
(
layout
,
TransA
,
M
,
N
,
&
alpha
,
A
,
lda
,
X
,
incX
,
&
beta
,
Y
,
incY
);
}
template
<
typename
...
ARGS
>
static
void
GEMM
(
const
CBLAS_LAYOUT
layout
,
const
CBLAS_TRANSPOSE
TransA
,
const
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
paddle
::
platform
::
complex
128
alpha
,
const
paddle
::
platform
::
complex
128
*
A
,
const
int
lda
,
const
paddle
::
platform
::
complex
128
*
B
,
const
int
ldb
,
const
paddle
::
platform
::
complex
128
beta
,
paddle
::
platform
::
complex
128
*
C
,
const
int
ldc
)
{
const
int
K
,
const
paddle
::
platform
::
complex
<
double
>
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
A
,
const
int
lda
,
const
paddle
::
platform
::
complex
<
double
>
*
B
,
const
int
ldb
,
const
paddle
::
platform
::
complex
<
double
>
beta
,
paddle
::
platform
::
complex
<
double
>
*
C
,
const
int
ldc
)
{
cblas_zgemm
(
layout
,
TransA
,
TransB
,
M
,
N
,
K
,
&
alpha
,
A
,
lda
,
B
,
ldb
,
&
beta
,
C
,
ldc
);
}
...
...
paddle/fluid/operators/math/blas_impl.hip.h
浏览文件 @
6c07cd7e
...
...
@@ -213,13 +213,13 @@ struct CUBlas<platform::float16> {
};
template
<
>
struct
CUBlas
<
platform
::
complex64
>
{
using
complex64
=
platform
::
complex64
;
struct
CUBlas
<
platform
::
complex
<
float
>>
{
static
void
GEMV
(
rocblas_handle
handle
,
rocblas_operation
transa
,
int
m
,
int
n
,
const
complex64
*
alpha
,
const
complex64
*
A
,
int
lda
,
const
complex64
*
B
,
int
ldb
,
const
complex64
*
beta
,
complex64
*
C
,
int
ldc
)
{
int
n
,
const
platform
::
complex
<
float
>
*
alpha
,
const
platform
::
complex
<
float
>
*
A
,
int
lda
,
const
platform
::
complex
<
float
>
*
B
,
int
ldb
,
const
platform
::
complex
<
float
>
*
beta
,
platform
::
complex
<
float
>
*
C
,
int
ldc
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
rocblas_cgemv
(
handle
,
transa
,
m
,
n
,
reinterpret_cast
<
const
rocblas_float_complex
*>
(
alpha
),
...
...
@@ -229,9 +229,10 @@ struct CUBlas<platform::complex64> {
reinterpret_cast
<
rocblas_float_complex
*>
(
C
),
ldc
));
}
static
void
AXPY
(
rocblas_handle
handle
,
int
n
,
const
complex64
*
alpha
,
const
complex64
*
X
,
const
int
incX
,
complex64
*
Y
,
const
int
incY
)
{
static
void
AXPY
(
rocblas_handle
handle
,
int
n
,
const
platform
::
complex
<
float
>
*
alpha
,
const
platform
::
complex
<
float
>
*
X
,
const
int
incX
,
platform
::
complex
<
float
>
*
Y
,
const
int
incY
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
rocblas_caxpy
(
handle
,
n
,
reinterpret_cast
<
const
rocblas_float_complex
*>
(
alpha
),
reinterpret_cast
<
const
rocblas_float_complex
*>
(
X
),
incX
,
...
...
@@ -241,11 +242,13 @@ struct CUBlas<platform::complex64> {
static
void
GEMM_STRIDED_BATCH
(
rocblas_handle
handle
,
rocblas_operation
transa
,
rocblas_operation
transb
,
int
m
,
int
n
,
int
k
,
const
complex64
*
alpha
,
const
complex64
*
A
,
int
lda
,
long
long
int
strideA
,
// NOLINT
const
complex64
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
complex64
*
beta
,
complex64
*
C
,
int
ldc
,
const
platform
::
complex
<
float
>
*
alpha
,
const
platform
::
complex
<
float
>
*
A
,
int
lda
,
long
long
int
strideA
,
// NOLINT
const
platform
::
complex
<
float
>
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
platform
::
complex
<
float
>
*
beta
,
platform
::
complex
<
float
>
*
C
,
int
ldc
,
long
long
int
strideC
,
// NOLINT
int
batchCount
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
...
...
@@ -261,9 +264,11 @@ struct CUBlas<platform::complex64> {
static
void
GEMM
(
rocblas_handle
handle
,
rocblas_operation
transa
,
rocblas_operation
transb
,
int
m
,
int
n
,
int
k
,
const
complex64
*
alpha
,
const
complex64
*
A
,
int
lda
,
const
complex64
*
B
,
int
ldb
,
const
complex64
*
beta
,
complex64
*
C
,
int
ldc
)
{
const
platform
::
complex
<
float
>
*
alpha
,
const
platform
::
complex
<
float
>
*
A
,
int
lda
,
const
platform
::
complex
<
float
>
*
B
,
int
ldb
,
const
platform
::
complex
<
float
>
*
beta
,
platform
::
complex
<
float
>
*
C
,
int
ldc
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
rocblas_cgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
reinterpret_cast
<
const
rocblas_float_complex
*>
(
alpha
),
...
...
@@ -293,13 +298,13 @@ struct CUBlas<platform::complex64> {
};
template
<
>
struct
CUBlas
<
platform
::
complex128
>
{
using
complex128
=
platform
::
complex128
;
struct
CUBlas
<
platform
::
complex
<
double
>>
{
static
void
GEMV
(
rocblas_handle
handle
,
rocblas_operation
transa
,
int
m
,
int
n
,
const
complex128
*
alpha
,
const
complex128
*
A
,
int
lda
,
const
complex128
*
B
,
int
ldb
,
const
complex128
*
beta
,
complex128
*
C
,
int
ldc
)
{
int
n
,
const
platform
::
complex
<
double
>
*
alpha
,
const
platform
::
complex
<
double
>
*
A
,
int
lda
,
const
platform
::
complex
<
double
>
*
B
,
int
ldb
,
const
platform
::
complex
<
double
>
*
beta
,
platform
::
complex
<
double
>
*
C
,
int
ldc
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
rocblas_zgemv
(
handle
,
transa
,
m
,
n
,
reinterpret_cast
<
const
rocblas_double_complex
*>
(
alpha
),
...
...
@@ -309,9 +314,10 @@ struct CUBlas<platform::complex128> {
reinterpret_cast
<
rocblas_double_complex
*>
(
C
),
ldc
));
}
static
void
AXPY
(
rocblas_handle
handle
,
int
n
,
const
complex128
*
alpha
,
const
complex128
*
X
,
const
int
incX
,
complex128
*
Y
,
const
int
incY
)
{
static
void
AXPY
(
rocblas_handle
handle
,
int
n
,
const
platform
::
complex
<
double
>
*
alpha
,
const
platform
::
complex
<
double
>
*
X
,
const
int
incX
,
platform
::
complex
<
double
>
*
Y
,
const
int
incY
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
rocblas_zaxpy
(
handle
,
n
,
reinterpret_cast
<
const
rocblas_double_complex
*>
(
alpha
),
reinterpret_cast
<
const
rocblas_double_complex
*>
(
X
),
incX
,
...
...
@@ -321,11 +327,13 @@ struct CUBlas<platform::complex128> {
static
void
GEMM_STRIDED_BATCH
(
rocblas_handle
handle
,
rocblas_operation
transa
,
rocblas_operation
transb
,
int
m
,
int
n
,
int
k
,
const
complex128
*
alpha
,
const
complex128
*
A
,
int
lda
,
long
long
int
strideA
,
// NOLINT
const
complex128
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
complex128
*
beta
,
complex128
*
C
,
int
ldc
,
const
platform
::
complex
<
double
>
*
alpha
,
const
platform
::
complex
<
double
>
*
A
,
int
lda
,
long
long
int
strideA
,
// NOLINT
const
platform
::
complex
<
double
>
*
B
,
// NOLINT
int
ldb
,
long
long
int
strideB
,
// NOLINT
const
platform
::
complex
<
double
>
*
beta
,
platform
::
complex
<
double
>
*
C
,
int
ldc
,
long
long
int
strideC
,
// NOLINT
int
batchCount
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
...
...
@@ -341,9 +349,11 @@ struct CUBlas<platform::complex128> {
static
void
GEMM
(
rocblas_handle
handle
,
rocblas_operation
transa
,
rocblas_operation
transb
,
int
m
,
int
n
,
int
k
,
const
complex128
*
alpha
,
const
complex128
*
A
,
int
lda
,
const
complex128
*
B
,
int
ldb
,
const
complex128
*
beta
,
complex128
*
C
,
int
ldc
)
{
const
platform
::
complex
<
double
>
*
alpha
,
const
platform
::
complex
<
double
>
*
A
,
int
lda
,
const
platform
::
complex
<
double
>
*
B
,
int
ldb
,
const
platform
::
complex
<
double
>
*
beta
,
platform
::
complex
<
double
>
*
C
,
int
ldc
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
rocblas_zgemm
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
reinterpret_cast
<
const
rocblas_double_complex
*>
(
alpha
),
...
...
@@ -434,9 +444,9 @@ template <>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
platform
::
complex
64
alpha
,
const
platform
::
complex64
*
A
,
const
platform
::
complex
64
*
B
,
platform
::
complex64
beta
,
platform
::
complex
64
*
C
)
const
{
platform
::
complex
<
float
>
alpha
,
const
platform
::
complex
<
float
>
*
A
,
const
platform
::
complex
<
float
>
*
B
,
platform
::
complex
<
float
>
beta
,
platform
::
complex
<
float
>
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
...
...
@@ -461,7 +471,7 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
thrust
::
complex
<
float
>
c_beta
=
thrust
::
complex
<
float
>
(
beta
.
real
,
beta
.
imag
);
auto
&
cuda_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
CUBlas
<
platform
::
complex
64
>::
GEMM_EX
(
CUBlas
<
platform
::
complex
<
float
>
>::
GEMM_EX
(
&
cuda_ctx
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
c_alpha
,
B
,
rocblas_datatype_f32_c
,
ldb
,
A
,
rocblas_datatype_f32_c
,
lda
,
&
c_beta
,
C
,
rocblas_datatype_f32_c
,
N
,
rocblas_datatype_f32_c
);
...
...
@@ -471,9 +481,9 @@ template <>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
platform
::
complex
128
alpha
,
const
platform
::
complex128
*
A
,
const
platform
::
complex
128
*
B
,
platform
::
complex128
beta
,
platform
::
complex
128
*
C
)
const
{
platform
::
complex
<
double
>
alpha
,
const
platform
::
complex
<
double
>
*
A
,
const
platform
::
complex
<
double
>
*
B
,
platform
::
complex
<
double
>
beta
,
platform
::
complex
<
double
>
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
...
...
@@ -499,7 +509,7 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
thrust
::
complex
<
double
>
(
beta
.
real
,
beta
.
imag
);
auto
&
cuda_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
CUBlas
<
platform
::
complex
128
>::
GEMM_EX
(
CUBlas
<
platform
::
complex
<
double
>
>::
GEMM_EX
(
&
cuda_ctx
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
c_alpha
,
B
,
rocblas_datatype_f64_c
,
ldb
,
A
,
rocblas_datatype_f64_c
,
lda
,
&
c_beta
,
C
,
rocblas_datatype_f64_c
,
N
,
rocblas_datatype_f64_c
);
...
...
paddle/fluid/operators/math/selected_rows_functor.cc
浏览文件 @
6c07cd7e
...
...
@@ -297,7 +297,9 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
namespace
scatter
{
template
<
typename
T
>
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
||
std
::
is_same
<
T
,
platform
::
complex
<
float
>>::
value
||
std
::
is_same
<
T
,
platform
::
complex
<
double
>>::
value
>::
type
elementwise_add_to
(
BlasT
<
platform
::
CPUDeviceContext
,
T
>*
blas
,
size_t
data_len
,
const
T
*
in
,
T
*
out
)
{
blas
->
AXPY
(
data_len
,
T
(
1.
f
),
in
,
out
);
...
...
@@ -542,9 +544,9 @@ template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
64
>;
paddle
::
platform
::
complex
<
float
>
>
;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
128
>;
paddle
::
platform
::
complex
<
double
>
>
;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>;
...
...
paddle/fluid/operators/matmul_v2_op.cc
浏览文件 @
6c07cd7e
...
...
@@ -204,15 +204,15 @@ REGISTER_OP_CPU_KERNEL(
matmul_v2
,
ops
::
MatMulV2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MatMulV2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
MatMulV2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
64
>
,
paddle
::
platform
::
complex
<
float
>
>
,
ops
::
MatMulV2Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
128
>
);
paddle
::
platform
::
complex
<
double
>
>
);
REGISTER_OP_CPU_KERNEL
(
matmul_v2_grad
,
ops
::
MatMulV2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MatMulV2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
MatMulV2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
64
>
,
paddle
::
platform
::
complex
<
float
>
>
,
ops
::
MatMulV2GradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
128
>
);
paddle
::
platform
::
complex
<
double
>
>
);
paddle/fluid/operators/matmul_v2_op.cu
浏览文件 @
6c07cd7e
...
...
@@ -21,12 +21,12 @@ REGISTER_OP_CUDA_KERNEL(
matmul_v2
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
float
>
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
double
>
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
plf
::
float16
>
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
plf
::
complex
64
>
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
plf
::
complex
128
>
);
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
plf
::
complex
<
float
>
>
,
ops
::
MatMulV2Kernel
<
plf
::
CUDADeviceContext
,
plf
::
complex
<
double
>
>
);
REGISTER_OP_CUDA_KERNEL
(
matmul_v2_grad
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
float
>
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
double
>
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
plf
::
float16
>
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
plf
::
complex
64
>
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
plf
::
complex
128
>
);
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
plf
::
complex
<
float
>
>
,
ops
::
MatMulV2GradKernel
<
plf
::
CUDADeviceContext
,
plf
::
complex
<
double
>
>
);
paddle/fluid/operators/matmul_v2_op.h
浏览文件 @
6c07cd7e
...
...
@@ -483,19 +483,19 @@ struct ConjHelper {
};
template
<
typename
DeviceContext
>
struct
ConjHelper
<
DeviceContext
,
paddle
::
platform
::
complex
64
>
{
struct
ConjHelper
<
DeviceContext
,
paddle
::
platform
::
complex
<
float
>
>
{
explicit
ConjHelper
(
const
framework
::
ExecutionContext
&
ctx
)
:
ctx_
(
ctx
)
{}
HOSTDEVICE
void
operator
()(
framework
::
Tensor
&
src
,
framework
::
Tensor
&
dst
)
{
dst
.
Resize
(
src
.
dims
());
auto
*
src_data
=
src
.
data
<
paddle
::
platform
::
complex
64
>
();
auto
*
dst_data
=
dst
.
mutable_data
<
paddle
::
platform
::
complex
64
>
(
auto
*
src_data
=
src
.
data
<
paddle
::
platform
::
complex
<
float
>
>
();
auto
*
dst_data
=
dst
.
mutable_data
<
paddle
::
platform
::
complex
<
float
>
>
(
ctx_
.
GetPlace
(),
size_t
(
src
.
numel
()
*
sizeof
(
paddle
::
platform
::
complex
64
)));
size_t
(
src
.
numel
()
*
sizeof
(
paddle
::
platform
::
complex
<
float
>
)));
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx_
.
template
device_context
<
DeviceContext
>(),
src
.
numel
());
math
::
ConjFunctor
<
paddle
::
platform
::
complex
64
>
functor
(
math
::
ConjFunctor
<
paddle
::
platform
::
complex
<
float
>
>
functor
(
src_data
,
src
.
numel
(),
dst_data
);
for_range
(
functor
);
return
;
...
...
@@ -504,19 +504,19 @@ struct ConjHelper<DeviceContext, paddle::platform::complex64> {
};
template
<
typename
DeviceContext
>
struct
ConjHelper
<
DeviceContext
,
paddle
::
platform
::
complex
128
>
{
struct
ConjHelper
<
DeviceContext
,
paddle
::
platform
::
complex
<
double
>
>
{
explicit
ConjHelper
(
const
framework
::
ExecutionContext
&
ctx
)
:
ctx_
(
ctx
)
{}
HOSTDEVICE
void
operator
()(
framework
::
Tensor
&
src
,
framework
::
Tensor
&
dst
)
{
dst
.
Resize
(
src
.
dims
());
auto
*
src_data
=
src
.
data
<
paddle
::
platform
::
complex
128
>
();
auto
*
dst_data
=
dst
.
mutable_data
<
paddle
::
platform
::
complex
128
>
(
auto
*
src_data
=
src
.
data
<
paddle
::
platform
::
complex
<
double
>
>
();
auto
*
dst_data
=
dst
.
mutable_data
<
paddle
::
platform
::
complex
<
double
>
>
(
ctx_
.
GetPlace
(),
size_t
(
src
.
numel
()
*
sizeof
(
paddle
::
platform
::
complex
128
)));
size_t
(
src
.
numel
()
*
sizeof
(
paddle
::
platform
::
complex
<
double
>
)));
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx_
.
template
device_context
<
DeviceContext
>(),
src
.
numel
());
math
::
ConjFunctor
<
paddle
::
platform
::
complex
128
>
functor
(
math
::
ConjFunctor
<
paddle
::
platform
::
complex
<
double
>
>
functor
(
src_data
,
src
.
numel
(),
dst_data
);
for_range
(
functor
);
return
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录