Commit 5b5656d0 (unverified)
Authored on Feb 16, 2022 by Feiyu Chan; committed via GitHub on Feb 16, 2022
Parent: 12ca438e

[Pten] move complex_functors.h (#39558)

* move complex_functors.h and update all references to symbols within it

Showing 35 changed files with 318 additions and 309 deletions (+318, -309)

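The pattern applied throughout the diffs below is mechanical: drop the include of "paddle/fluid/operators/math/complex_functors.h", include "paddle/pten/kernels/funcs/complex_functors.h" instead, and qualify the functors and the Real<T> alias with pten::funcs:: rather than math::. A minimal before/after sketch follows (illustrative only; the helper ConjInPlaceSketch is hypothetical, but the ForRange-plus-functor idiom and the functor names are the ones used in the touched kernels):

// Before this commit:
//   #include "paddle/fluid/operators/math/complex_functors.h"
//   math::ConjFunctor<T> functor(x_data, numel, out_data);
//
// After this commit:
#include "paddle/pten/kernels/funcs/complex_functors.h"

template <typename DeviceContext, typename T>
void ConjInPlaceSketch(const DeviceContext& dev_ctx, const T* x_data,
                       int64_t numel, T* out_data) {
  // Same ForRange + functor idiom as the touched kernels; only the
  // namespace of the functor changes.
  platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
  pten::funcs::ConjFunctor<T> functor(x_data, numel, out_data);
  for_range(functor);
}
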
Changed files:

paddle/fluid/operators/CMakeLists.txt               +4    -0
paddle/fluid/operators/angle_op.h                   +6    -81
paddle/fluid/operators/cholesky_solve_op.h          +5    -5
paddle/fluid/operators/complex_op.h                 +1    -1
paddle/fluid/operators/complex_view_op.h            +1    -1
paddle/fluid/operators/cumprod_op.cu                +3    -3
paddle/fluid/operators/cumprod_op.h                 +3    -3
paddle/fluid/operators/determinant_op.h             +2    -2
paddle/fluid/operators/dot_op.h                     +1    -1
paddle/fluid/operators/eig_op.h                     +22   -20
paddle/fluid/operators/eigh_op.h                    +1    -1
paddle/fluid/operators/eigvals_op.h                 +12   -12
paddle/fluid/operators/imag_op.h                    +7    -6
paddle/fluid/operators/lstsq_op.h                   +3    -3
paddle/fluid/operators/lu_op.h                      +3    -2
paddle/fluid/operators/math/eigen_values_vectors.h  +6    -6
paddle/fluid/operators/math/inclusive_scan.h        +2    -2
paddle/fluid/operators/matmul_v2_op.h               +1    -1
paddle/fluid/operators/matrix_rank_op.cu            +3    -3
paddle/fluid/operators/qr_op.cu                     +10   -7
paddle/fluid/operators/qr_op.h                      +13   -10
paddle/fluid/operators/real_op.h                    +7    -6
paddle/fluid/operators/renorm_op.h                  +1    -1
paddle/fluid/operators/spectral_op.cu               +9    -9
paddle/fluid/operators/svd_helper.h                 +10   -9
paddle/fluid/operators/svd_op.h                     +7    -7
paddle/fluid/operators/triangular_solve_op.h        +3    -3
paddle/pten/kernels/cpu/abs_grad_kernel.cc          +1    -1
paddle/pten/kernels/cpu/abs_kernel.cc               +5    -5
paddle/pten/kernels/funcs/complex_functors.h        +146  -60
paddle/pten/kernels/gpu/abs_kernel.cu               +6    -11
paddle/pten/kernels/impl/abs_grad_kernel_impl.h     +4    -5
paddle/pten/kernels/impl/complex_kernel_impl.h      +2    -2
paddle/pten/kernels/impl/dot_grad_kernel_impl.h     +7    -19
paddle/pten/kernels/impl/matmul_kernel_impl.h       +1    -1

paddle/fluid/operators/CMakeLists.txt
 include(operators)

+# solve "math constants not defined" problems caused by the order of inclusion
+# of <cmath> and the definition of macro _USE_MATH_DEFINES
+add_definitions(-D_USE_MATH_DEFINES)
+
 # clean cache and pybind_file content first when rebuild
 unset(GLOB_OP_LIB CACHE)
 unset(OP_LIBRARY CACHE)

paddle/fluid/operators/angle_op.h
@@ -17,7 +17,7 @@
 #define _USE_MATH_DEFINES
 #endif
 #include <cmath>
-#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -26,81 +26,6 @@
 namespace paddle {
 namespace operators {
-namespace math {
-
-template <typename T, typename Enable = void>
-struct AngleFunctor;
-
-// angel function for complex
-template <typename T>
-struct AngleFunctor<T, Complex<T, Real<T>>> {
-  AngleFunctor(const T* input, Real<T>* output, int64_t numel)
-      : input_(input), output_(output), numel_(numel) {}
-
-  HOSTDEVICE void operator()(int64_t idx) const {
-    output_[idx] = arg(input_[idx]);
-  }
-
-  const T* input_;
-  Real<T>* output_;
-  int64_t numel_;
-};
-
-// angel function for real
-template <typename T>
-struct AngleFunctor<T, NoComplex<T, Real<T>>> {
-  AngleFunctor(const T* input, T* output, int64_t numel)
-      : input_(input), output_(output), numel_(numel) {}
-
-  HOSTDEVICE void operator()(int64_t idx) const {
-    output_[idx] = input_[idx] < static_cast<T>(0) ? M_PI : 0;
-  }
-
-  const T* input_;
-  T* output_;
-  int64_t numel_;
-};
-
-template <typename T, typename Enable = void>
-struct AngleGradFunctor;
-
-// angle grad for complex
-template <typename T>
-struct AngleGradFunctor<T, Complex<T, Real<T>>> {
-  AngleGradFunctor(const math::Real<T>* dout, const T* x, T* dx, int64_t numel)
-      : dout_(dout), x_(x), dx_(dx), numel_(numel) {}
-
-  HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == T(0)) {
-      dx_[idx] = T(0);
-    } else {
-      const math::Real<T> r_square =
-          x_[idx].real * x_[idx].real + x_[idx].imag * x_[idx].imag;
-      dx_[idx] = T(-dout_[idx] * x_[idx].imag / r_square,
-                   dout_[idx] * x_[idx].real / r_square);
-    }
-  }
-
-  const math::Real<T>* dout_;
-  const T* x_;
-  T* dx_;
-  int64_t numel_;
-};
-
-// angle grad for real
-template <typename T>
-struct AngleGradFunctor<T, NoComplex<T, Real<T>>> {
-  AngleGradFunctor(const math::Real<T>* dout, const T* x, T* dx, int64_t numel)
-      : dout_(dout), x_(x), dx_(dx), numel_(numel) {}
-
-  HOSTDEVICE void operator()(int64_t idx) const { dx_[idx] = 0; }
-
-  const math::Real<T>* dout_;
-  const T* x_;
-  T* dx_;
-  int64_t numel_;
-};
-}  // namespace math
-
 using Tensor = framework::Tensor;

 template <typename DeviceContext, typename T>
 class AngleKernel : public framework::OpKernel<T> {
@@ -111,12 +36,12 @@ class AngleKernel : public framework::OpKernel<T> {
     auto numel = x->numel();
     auto* x_data = x->data<T>();
-    auto* out_data = out->mutable_data<math::Real<T>>(
-        context.GetPlace(), size_t(x->numel() * sizeof(math::Real<T>)));
+    auto* out_data = out->mutable_data<pten::funcs::Real<T>>(
+        context.GetPlace(), size_t(x->numel() * sizeof(pten::funcs::Real<T>)));

     auto& dev_ctx = context.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    math::AngleFunctor<T> functor(x_data, out_data, numel);
+    pten::funcs::AngleFunctor<T> functor(x_data, out_data, numel);
     for_range(functor);
   }
 };
@@ -132,14 +57,14 @@ class AngleGradKernel : public framework::OpKernel<T> {
         ctx.Output<framework::Tensor>(framework::GradVarName("X"));

     auto numel = d_out->numel();
-    auto* dout_data = d_out->data<math::Real<T>>();
+    auto* dout_data = d_out->data<pten::funcs::Real<T>>();
     auto* x_data = x->data<T>();
     auto* dx_data = d_x->mutable_data<T>(
         ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));

     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    math::AngleGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
+    pten::funcs::AngleGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
     for_range(functor);
   }
 };

paddle/fluid/operators/cholesky_solve_op.h
@@ -64,7 +64,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
   // calculate u's conjugate for complex
   framework::Tensor u_conj(u_bst.type());
   platform::ForRange<DeviceContext> u_for_range(dev_ctx, u_bst.numel());
-  math::ConjFunctor<T> u_functor(
+  pten::funcs::ConjFunctor<T> u_functor(
       u_bst.data<T>(), u_bst.numel(),
       u_conj.mutable_data<T>(u_bst.dims(), dev_ctx.GetPlace()));
   u_for_range(u_functor);
@@ -73,7 +73,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
   // calculate b's conjugate for complex
   framework::Tensor b_conj(b_bst.type());
   platform::ForRange<DeviceContext> b_for_range(dev_ctx, b_bst.numel());
-  math::ConjFunctor<T> b_functor(
+  pten::funcs::ConjFunctor<T> b_functor(
       b_bst.data<T>(), b_bst.numel(),
       b_conj.mutable_data<T>(b_bst.dims(), dev_ctx.GetPlace()));
   b_for_range(b_functor);
@@ -113,7 +113,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
   // calculate out's conjugate for complex
   platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
-  math::ConjFunctor<T> out_functor(
+  pten::funcs::ConjFunctor<T> out_functor(
       out->data<T>(), out->numel(),
       out->mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
   out_for_range(out_functor);
@@ -173,7 +173,7 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
     // calculate out's conjugate for complex
     framework::Tensor out_conj(out->type());
     platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
-    math::ConjFunctor<T> out_functor(
+    pten::funcs::ConjFunctor<T> out_functor(
         out->data<T>(), out->numel(),
        out_conj.mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
     out_for_range(out_functor);
@@ -195,7 +195,7 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
     framework::Tensor commonterm_conj(commonterm.type());
     platform::ForRange<DeviceContext> commonterm_for_range(
         dev_ctx, commonterm.numel());
-    math::ConjFunctor<T> commonterm_functor(
+    pten::funcs::ConjFunctor<T> commonterm_functor(
         commonterm.data<T>(), commonterm.numel(),
         commonterm_conj.mutable_data<T>(commonterm.dims(), dev_ctx.GetPlace()));

paddle/fluid/operators/complex_op.h
@@ -16,8 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/complex.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {

paddle/fluid/operators/complex_view_op.h
@@ -17,9 +17,9 @@
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {

paddle/fluid/operators/cumprod_op.cu
@@ -14,9 +14,9 @@
 #include <thrust/transform.h>
 #include "paddle/fluid/operators/cumprod_op.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/math/inclusive_scan.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -243,12 +243,12 @@ class CumprodGradOpCUDAKernel : public framework::OpKernel<T> {
       platform::ForRange<platform::CUDADeviceContext> for_range_x(dev_ctx,
                                                                   numel);
-      math::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
+      pten::funcs::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
       for_range_x(functor_x);

       platform::ForRange<platform::CUDADeviceContext> for_range_y(dev_ctx,
                                                                   numel);
-      math::ConjFunctor<T> functor_y(y_data, numel, y_data_conj);
+      pten::funcs::ConjFunctor<T> functor_y(y_data, numel, y_data_conj);
       for_range_y(functor_y);
       x_data_deal = x_data_conj;
       y_data_deal = y_data_conj;

paddle/fluid/operators/cumprod_op.h
@@ -18,8 +18,8 @@
 #include <type_traits>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -124,12 +124,12 @@ class CumprodGradOpCPUKernel : public framework::OpKernel<T> {
       platform::ForRange<platform::CPUDeviceContext> for_range_x(dev_ctx,
                                                                  numel);
-      math::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
+      pten::funcs::ConjFunctor<T> functor_x(x_data, numel, x_data_conj);
       for_range_x(functor_x);

       platform::ForRange<platform::CPUDeviceContext> for_range_out(dev_ctx,
                                                                    numel);
-      math::ConjFunctor<T> functor_out(out_data, numel, out_data_conj);
+      pten::funcs::ConjFunctor<T> functor_out(out_data, numel, out_data_conj);
       for_range_out(functor_out);
       x_data_deal = x_data_conj;

paddle/fluid/operators/determinant_op.h
@@ -19,11 +19,11 @@
 #include <cmath>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/math/matrix_inverse.h"
 #include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -395,7 +395,7 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
         size_t(numel * sizeof(T)));
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    math::ConjFunctor<T> functor(inverse_A.data<T>(), numel, conj_data);
+    pten::funcs::ConjFunctor<T> functor(inverse_A.data<T>(), numel, conj_data);
     for_range(functor);

     VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();

paddle/fluid/operators/dot_op.h
@@ -16,8 +16,8 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 // only can include the headers in paddle/pten/api dirs
 #include "paddle/pten/api/lib/utils/tensor_utils.h"

paddle/fluid/operators/eig_op.h
@@ -17,12 +17,12 @@
 #include <math.h>
 #include <algorithm>
 #include <complex>
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/math/lapack_function.h"
 #include "paddle/fluid/operators/math/matrix_solve.h"
 #include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"
 #include "paddle/pten/kernels/funcs/math_function.h"
 #define EPSILON 1e-6
@@ -87,18 +87,19 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info,
   int values_stride = values->dims()[values->dims().size() - 1];

   Tensor rwork;
-  math::Real<T>* rwork_data = nullptr;
+  pten::funcs::Real<T>* rwork_data = nullptr;

   rwork.Resize(framework::make_ddim({lda * 2}));
-  rwork_data = rwork.mutable_data<math::Real<T>>(context.GetPlace());
+  rwork_data = rwork.mutable_data<pten::funcs::Real<T>>(context.GetPlace());

   // call lapackEig once to compute the size of work;
   T computed_work_size;
-  math::lapackEig<T, math::Real<T>>(
+  math::lapackEig<T, pten::funcs::Real<T>>(
       jobvl, jobvr, order, input_data, lda, values_data, lvector_data, ldvl,
       rvector_data, ldvr, &computed_work_size, lwork, rwork_data, &info);

-  lwork = std::max<int>(1, static_cast<int>(math::Real<T>(computed_work_size)));
+  lwork = std::max<int>(
+      1, static_cast<int>(pten::funcs::Real<T>(computed_work_size)));
   Tensor work;
   work.Resize(framework::make_ddim({lwork}));
   T* work_data = work.mutable_data<T>(context.GetPlace());
@@ -108,7 +109,7 @@ void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info,
     T* current_values = &values_data[i * values_stride];
     T* current_rvectors = &rvector_data[i * matrix_stride];

-    math::lapackEig<T, math::Real<T>>(
+    math::lapackEig<T, pten::funcs::Real<T>>(
         jobvl, jobvr, order, current_matrix, lda, current_values, lvector_data,
         ldvl, current_rvectors, ldvr, work_data, lwork, rwork_data, &info);
     PADDLE_ENFORCE_EQ(
@@ -207,26 +208,27 @@ class EigKernel : public framework::OpKernel<T> {
       origin_dim.push_back(last_item * 2);
       framework::DDim big_dim = framework::make_ddim(origin_dim);

-      real_values.mutable_data<math::Real<T>>(big_dim, context.GetPlace());
-      real_vectors.mutable_data<math::Real<T>>(x->dims(), context.GetPlace());
+      real_values.mutable_data<pten::funcs::Real<T>>(big_dim,
+                                                     context.GetPlace());
+      real_vectors.mutable_data<pten::funcs::Real<T>>(x->dims(),
+                                                      context.GetPlace());

-      ApplyEigKernel<DeviceContext, math::Real<T>>(*x, &real_values,
-                                                   &real_vectors, context);
-      auto dito = math::DeviceIndependenceTensorOperations<
-          DeviceContext, math::Real<T>, Tout>(context);
+      ApplyEigKernel<DeviceContext, pten::funcs::Real<T>>(
+          *x, &real_values, &real_vectors, context);
+      auto dito = math::DeviceIndependenceTensorOperations<
+          DeviceContext, pten::funcs::Real<T>, Tout>(context);

       // 1. extract real part & imag part from real_values
       Tensor real_part = dito.Slice(real_values, {-1}, {0}, {order});
       Tensor imag_part = dito.Slice(real_values, {-1}, {order}, {order * 2});

       // 2. construct complex values
-      auto* real_part_data = real_part.data<math::Real<T>>();
-      auto* imag_part_data = imag_part.data<math::Real<T>>();
+      auto* real_part_data = real_part.data<pten::funcs::Real<T>>();
+      auto* imag_part_data = imag_part.data<pten::funcs::Real<T>>();
       int out_values_numel = out_values->numel();
       platform::ForRange<DeviceContext> for_range(
           context.template device_context<DeviceContext>(), out_values_numel);
-      math::RealImagToComplexFunctor<Tout> functor(
+      pten::funcs::RealImagToComplexFunctor<Tout> functor(
           real_part_data, imag_part_data,
           out_values->mutable_data<Tout>(context.GetPlace()), out_values_numel);
       for_range(functor);
@@ -235,7 +237,7 @@ class EigKernel : public framework::OpKernel<T> {
       Tensor real_vector_trans = dito.Transpose(real_vectors);
       Tensor out_vectors_trans;
       out_vectors_trans.mutable_data<Tout>(x->dims(), context.GetPlace());
-      ConstructComplexVectors<math::Real<T>, Tout>(
+      ConstructComplexVectors<pten::funcs::Real<T>, Tout>(
           &out_vectors_trans, *out_values, real_vector_trans, context,
           batch_count, order);
       TransposeTwoAxis<DeviceContext, Tout>(out_vectors_trans, out_vectors,
@@ -271,14 +273,14 @@ void ComputeBackwardForComplexInput(
   // turn diag_unsqueezed into complex
   auto numel = diag_unsqueezed.numel();
   Tensor diag_unsqueezed_complex;
-  auto* data_diag_un = diag_unsqueezed.data<math::Real<Tout>>();
+  auto* data_diag_un = diag_unsqueezed.data<pten::funcs::Real<Tout>>();
   auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data<Tout>(
       diag_unsqueezed.dims(), context.GetPlace(),
       static_cast<size_t>(numel * sizeof(Tout)));
   auto& dev_ctx = context.template device_context<DeviceContext>();
   platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-  math::RealToComplexFunctor<Tout> functor(data_diag_un, data_diag_un_com,
-                                           numel);
+  pten::funcs::RealToComplexFunctor<Tout> functor(data_diag_un,
+                                                  data_diag_un_com, numel);
   for_range(functor);
   // real tensor multiply complex tensor in broadcast manner
   Tensor res1 = dito.RealMulComplex(V, diag_unsqueezed_complex);

paddle/fluid/operators/eigh_op.h
@@ -40,7 +40,7 @@ template <typename DeviceContext, typename T>
 class EighGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    using ValueType = math::Real<T>;
+    using ValueType = pten::funcs::Real<T>;
     auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     x_grad.mutable_data<T>(ctx.GetPlace());
     auto& output_w = *ctx.Input<Tensor>("Eigenvalues");

paddle/fluid/operators/eigvals_op.h
@@ -20,9 +20,9 @@
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/allocation/allocator.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/math/lapack_function.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -48,7 +48,7 @@ struct PaddleComplex<
 template <typename T>
 using PaddleCType = typename PaddleComplex<T>::type;
 template <typename T>
-using Real = typename math::Real<T>;
+using Real = typename pten::funcs::Real<T>;

 static void SpiltBatchSquareMatrix(const Tensor& input,
                                    std::vector<Tensor>* output) {
@@ -118,7 +118,7 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
   platform::ForRange<DeviceContext> for_range(
       ctx.template device_context<DeviceContext>(), n_dim);
-  math::RealImagToComplexFunctor<PaddleCType<T>> functor(
+  pten::funcs::RealImagToComplexFunctor<PaddleCType<T>> functor(
       w_data, w_data + n_dim, output->template data<PaddleCType<T>>(), n_dim);
   for_range(functor);
 }
@@ -143,7 +143,7 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
                         required_work_mem, work_mem));

   int64_t rwork_mem = rwork->memory_size();
-  int64_t required_rwork_mem = (n_dim << 1) * sizeof(Real<T>);
+  int64_t required_rwork_mem = (n_dim << 1) * sizeof(pten::funcs::Real<T>);
   PADDLE_ENFORCE_GE(
       rwork_mem, required_rwork_mem,
       platform::errors::InvalidArgument(
@@ -153,11 +153,11 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
           required_rwork_mem, rwork_mem));

   int info = 0;
-  math::lapackEig<T, Real<T>>(
+  math::lapackEig<T, pten::funcs::Real<T>>(
       'N', 'N', static_cast<int>(n_dim), a.template data<T>(),
       static_cast<int>(n_dim), output->template data<T>(), NULL, 1, NULL, 1,
       work->template data<T>(), static_cast<int>(work_mem / sizeof(T)),
-      rwork->template data<Real<T>>(), &info);
+      rwork->template data<pten::funcs::Real<T>>(), &info);

   std::string name = "framework::platform::dynload::cgeev_";
   if (framework::TransToProtoVarType(input.dtype()) ==
@@ -187,10 +187,10 @@ class EigvalsKernel : public framework::OpKernel<T> {
     // query workspace size
     T qwork;
     int info;
-    math::lapackEig<T, Real<T>>(
-        'N', 'N', static_cast<int>(n_dim), input_matrices[0].template data<T>(),
-        static_cast<int>(n_dim), NULL, NULL, 1, NULL, 1, &qwork, -1,
-        static_cast<Real<T>*>(NULL), &info);
+    math::lapackEig<T, pten::funcs::Real<T>>(
+        'N', 'N', static_cast<int>(n_dim), input_matrices[0].template data<T>(),
+        static_cast<int>(n_dim), NULL, NULL, 1, NULL, 1, &qwork, -1,
+        static_cast<pten::funcs::Real<T>*>(NULL), &info);
     int64_t lwork = static_cast<int64_t>(qwork);

     Tensor work, rwork;
@@ -207,8 +207,8 @@ class EigvalsKernel : public framework::OpKernel<T> {
     }
     if (framework::IsComplexType(
             framework::TransToProtoVarType(input->dtype()))) {
-      rwork.mutable_data<Real<T>>(framework::make_ddim({n_dim << 1}),
-                                  ctx.GetPlace());
+      rwork.mutable_data<pten::funcs::Real<T>>(
+          framework::make_ddim({n_dim << 1}), ctx.GetPlace());
     }

     for (int64_t i = 0; i < n_batch; ++i) {

paddle/fluid/operators/imag_op.h
@@ -16,8 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -31,12 +31,13 @@ class ImagKernel : public framework::OpKernel<T> {
     auto numel = x->numel();
     auto* x_data = x->data<T>();
-    auto* out_data = out->mutable_data<math::Real<T>>(
-        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(math::Real<T>)));
+    auto* out_data = out->mutable_data<pten::funcs::Real<T>>(
+        ctx.GetPlace(),
+        static_cast<size_t>(numel * sizeof(pten::funcs::Real<T>)));

     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    math::ImagFunctor<T> functor(x_data, out_data, numel);
+    pten::funcs::ImagFunctor<T> functor(x_data, out_data, numel);
     for_range(functor);
   }
 };
@@ -51,13 +52,13 @@ class ImagGradKernel : public framework::OpKernel<T> {
         ctx.Output<framework::Tensor>(framework::GradVarName("X"));

     auto numel = d_out->numel();
-    auto* dout_data = d_out->data<math::Real<T>>();
+    auto* dout_data = d_out->data<pten::funcs::Real<T>>();
     auto* dx_data = d_x->mutable_data<T>(
         ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));

     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    math::ImagToComplexFunctor<T> functor(dout_data, dx_data, numel);
+    pten::funcs::ImagToComplexFunctor<T> functor(dout_data, dx_data, numel);
     for_range(functor);
   }
 };

paddle/fluid/operators/lstsq_op.h
@@ -18,7 +18,6 @@
 #include <algorithm>
 #include <complex>
 #include "paddle/fluid/operators/eig_op.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/math/eigen_values_vectors.h"
 #include "paddle/fluid/operators/math/lapack_function.h"
 #include "paddle/fluid/operators/math/matrix_solve.h"
@@ -26,6 +25,7 @@
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/operators/triangular_solve_op.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"
 #include "paddle/pten/kernels/funcs/math_function.h"

 #define EPSILON 1e-6
@@ -46,7 +46,7 @@ template <typename DeviceContext, typename T>
 class LstsqCPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    using ValueType = math::Real<T>;
+    using ValueType = pten::funcs::Real<T>;

     const Tensor& x = *context.Input<Tensor>("X");
     auto y = context.Input<Tensor>("Y");
@@ -169,7 +169,7 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
                            &rwkopt, &info);
     }

-    lwork = std::max<int>(1, static_cast<int>(math::Real<T>(wkopt)));
+    lwork = std::max<int>(1, static_cast<int>(pten::funcs::Real<T>(wkopt)));
     Tensor work;
     work.Resize(framework::make_ddim({lwork}));
     T* work_data = work.mutable_data<T>(context.GetPlace());

paddle/fluid/operators/lu_op.h
@@ -211,8 +211,9 @@ void Tensor_Conj(const DeviceContext& dev_ctx, const framework::Tensor& tensor,
                  framework::Tensor* out) {
   out->Resize(tensor.dims());
   platform::ForRange<DeviceContext> out_for_range(dev_ctx, tensor.numel());
-  math::ConjFunctor<T> out_functor(tensor.data<T>(), tensor.numel(),
-                                   out->mutable_data<T>(dev_ctx.GetPlace()));
+  pten::funcs::ConjFunctor<T> out_functor(
+      tensor.data<T>(), tensor.numel(),
+      out->mutable_data<T>(dev_ctx.GetPlace()));
   out_for_range(out_functor);
 }

paddle/fluid/operators/math/eigen_values_vectors.h
@@ -63,7 +63,7 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, T> {
   void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                   Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                   bool has_vectors) {
-    using ValueType = math::Real<T>;
+    using ValueType = pten::funcs::Real<T>;
     auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());

     auto dito =
@@ -123,9 +123,9 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, T> {
     for (auto i = 0; i < batch_size; i++) {
       auto *value_data = out_value + i * values_stride;
       auto *input_data = input_vector + i * vector_stride;
-      math::lapackEigh<T, Real<T>>(jobz, uplo, n, input_data, lda, value_data,
-                                   work_data, lwork, rwork_data, lrwork,
-                                   iwork_data, liwork, &info);
+      math::lapackEigh<T, pten::funcs::Real<T>>(
+          jobz, uplo, n, input_data, lda, value_data, work_data, lwork,
+          rwork_data, lrwork, iwork_data, liwork, &info);
       CheckEighResult(i, info);
     }
     if (has_vectors) {
@@ -151,7 +151,7 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, T> {
   void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                   Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
                   bool has_vectors) {
-    using ValueType = math::Real<T>;
+    using ValueType = pten::funcs::Real<T>;
     auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());

     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
@@ -233,7 +233,7 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, T> {
     }
   }

-  using ValueType = math::Real<T>;
+  using ValueType = pten::funcs::Real<T>;
   inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz,
                         cublasFillMode_t uplo, int n, const T *A, int lda,
                         const ValueType *W, int *lwork) const;

paddle/fluid/operators/math/inclusive_scan.h
@@ -26,9 +26,9 @@ namespace cub = hipcub;
 #include <thrust/iterator/reverse_iterator.h>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -115,7 +115,7 @@ static __global__ void InclusiveScanInnerDimCUDAKernel(const T *x, T *y,
                                                        size_t num_rows,
                                                        size_t row_size, T init,
                                                        BinaryOp op) {
-  using RealT = math::Real<T>;
+  using RealT = pten::funcs::Real<T>;
   constexpr auto kSharedBufferSize =
       framework::IsComplex<T>::value ? 4 * kThreadNumX : 2 * kThreadNumX;
   __shared__ RealT sbuf[kThreadNumY][kSharedBufferSize];

paddle/fluid/operators/matmul_v2_op.h
@@ -22,8 +22,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/dot_op.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 // only can include the headers in paddle/pten/api dirs
 #include "paddle/pten/api/lib/utils/tensor_utils.h"

paddle/fluid/operators/matrix_rank_op.cu
@@ -18,11 +18,11 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/matrix_rank_op.h"
 #include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/dynload/cusolver.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"
 #include "paddle/pten/kernels/funcs/math_function.h"

 namespace paddle {
@@ -93,8 +93,8 @@ class MatrixRankGPUKernel : public framework::OpKernel<T> {
                        info_ptr);
       platform::ForRange<platform::CUDADeviceContext> for_range(
           dev_ctx, eigenvalue_tensor.numel());
-      math::AbsFunctor<T> functor(eigenvalue_data, eigenvalue_data,
-                                  eigenvalue_tensor.numel());
+      pten::funcs::AbsFunctor<T> functor(eigenvalue_data, eigenvalue_data,
+                                         eigenvalue_tensor.numel());
       for_range(functor);
     } else {
       Tensor U, VH;

paddle/fluid/operators/qr_op.cu
@@ -56,12 +56,13 @@ class QrGPUKernel : public framework::OpKernel<T> {
     int tau_stride = min_mn;

     if (compute_q) {
-      q.mutable_data<math::Real<T>>(
+      q.mutable_data<pten::funcs::Real<T>>(
           context.GetPlace(),
-          size_t(batch_size * m * k * sizeof(math::Real<T>)));
+          size_t(batch_size * m * k * sizeof(pten::funcs::Real<T>)));
     }
-    r.mutable_data<math::Real<T>>(
-        context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real<T>)));
+    r.mutable_data<pten::funcs::Real<T>>(
+        context.GetPlace(),
+        size_t(batch_size * k * n * sizeof(pten::funcs::Real<T>)));

     auto dito =
         math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
@@ -70,8 +71,9 @@ class QrGPUKernel : public framework::OpKernel<T> {
     // Note: allocate temporary tensors because of lacking in-place operatios.
     // Prepare qr
     Tensor qr;
-    qr.mutable_data<math::Real<T>>(
-        context.GetPlace(), size_t(batch_size * m * n * sizeof(math::Real<T>)));
+    qr.mutable_data<pten::funcs::Real<T>>(
+        context.GetPlace(),
+        size_t(batch_size * m * n * sizeof(pten::funcs::Real<T>)));
     // BatchedGeqrf performs computation in-place and 'qr' must be a copy of
     // input
     paddle::framework::TensorCopy(x, context.GetPlace(), &qr);
@@ -124,7 +126,8 @@ class QrGPUKernel : public framework::OpKernel<T> {
         for (int i = 0; i < batch_size; ++i) {
           memory::Copy(dev_ctx.GetPlace(), (new_qr_data + i * new_qr_stride),
                        dev_ctx.GetPlace(), (qr_data + i * qr_stride),
-                       qr_stride * sizeof(math::Real<T>), dev_ctx.stream());
+                       qr_stride * sizeof(pten::funcs::Real<T>),
+                       dev_ctx.stream());
         }
         BatchedOrgqr<platform::CUDADeviceContext, T>(
             dev_ctx, batch_size, m, m, min_mn, new_qr_data, m, tau_data,

paddle/fluid/operators/qr_op.h
@@ -18,9 +18,9 @@
 #include <cstdarg>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -74,17 +74,20 @@ class QrCPUKernel : public framework::OpKernel<T> {
     int q_stride = m * k;
     int r_stride = k * n;

-    auto* x_data = x.data<math::Real<T>>();
+    auto* x_data = x.data<pten::funcs::Real<T>>();
     T* q_data = nullptr;
     if (compute_q) {
-      q_data = q.mutable_data<math::Real<T>>(
+      q_data = q.mutable_data<pten::funcs::Real<T>>(
           context.GetPlace(),
-          size_t(batch_size * m * k * sizeof(math::Real<T>)));
-      memset(q_data, 0, size_t(batch_size * m * k * sizeof(math::Real<T>)));
+          size_t(batch_size * m * k * sizeof(pten::funcs::Real<T>)));
+      memset(q_data, 0,
+             size_t(batch_size * m * k * sizeof(pten::funcs::Real<T>)));
     }
-    auto* r_data = r.mutable_data<math::Real<T>>(
-        context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real<T>)));
-    memset(r_data, 0, size_t(batch_size * k * n * sizeof(math::Real<T>)));
+    auto* r_data = r.mutable_data<pten::funcs::Real<T>>(
+        context.GetPlace(),
+        size_t(batch_size * k * n * sizeof(pten::funcs::Real<T>)));
+    memset(r_data, 0,
+           size_t(batch_size * k * n * sizeof(pten::funcs::Real<T>)));

     // Implement QR by calling Eigen
     for (int i = 0; i < batch_size; ++i) {
@@ -140,7 +143,7 @@ class QrGradKernel : public framework::OpKernel<T> {
     // Use a different name dA instead of dX
     framework::Tensor& dA =
         *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    dA.mutable_data<math::Real<T>>(ctx.GetPlace());
+    dA.mutable_data<pten::funcs::Real<T>>(ctx.GetPlace());
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     pten::funcs::SetConstant<DeviceContext, T>()(dev_ctx, &dA, T(0));
@@ -222,7 +225,7 @@ class QrGradKernel : public framework::OpKernel<T> {
     } else {
       // If m < n for input matrices A, we partition A = [X|Y] and R = [U|V]
       // Calculate dX and dY individually and concatenate them to get dA
-      dA.mutable_data<math::Real<T>>(ctx.GetPlace());
+      dA.mutable_data<pten::funcs::Real<T>>(ctx.GetPlace());

       auto Y = dito.Slice(A, {-1}, {m}, {n});
       auto U = dito.Slice(R, {-1}, {0}, {m});

paddle/fluid/operators/real_op.h
@@ -16,8 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -31,12 +31,13 @@ class RealKernel : public framework::OpKernel<T> {
     auto numel = x->numel();
     auto* x_data = x->data<T>();
-    auto* out_data = out->mutable_data<math::Real<T>>(
-        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(math::Real<T>)));
+    auto* out_data = out->mutable_data<pten::funcs::Real<T>>(
+        ctx.GetPlace(),
+        static_cast<size_t>(numel * sizeof(pten::funcs::Real<T>)));

     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    math::RealFunctor<T> functor(x_data, out_data, numel);
+    pten::funcs::RealFunctor<T> functor(x_data, out_data, numel);
     for_range(functor);
   }
 };
@@ -51,13 +52,13 @@ class RealGradKernel : public framework::OpKernel<T> {
         ctx.Output<framework::Tensor>(framework::GradVarName("X"));

     auto numel = d_out->numel();
-    auto* dout_data = d_out->data<math::Real<T>>();
+    auto* dout_data = d_out->data<pten::funcs::Real<T>>();
     auto* dx_data = d_x->mutable_data<T>(
         ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));

     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    math::RealToComplexFunctor<T> functor(dout_data, dx_data, numel);
+    pten::funcs::RealToComplexFunctor<T> functor(dout_data, dx_data, numel);
     for_range(functor);
   }
 };

paddle/fluid/operators/renorm_op.h
@@ -17,8 +17,8 @@
 #include "math.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;

paddle/fluid/operators/spectral_op.cu
@@ -20,11 +20,11 @@
 #include <vector>

 #include "paddle/fluid/operators/conj_op.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/spectral_helper.h"
 #include "paddle/fluid/operators/spectral_op.h"
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -115,8 +115,8 @@ void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
     framework::Tensor input_conj(input->type());
     input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
     platform::ForRange<DeviceContext> for_range(ctx, input->numel());
-    math::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
-                                  input_conj.data<Ti>());
+    pten::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
+                                         input_conj.data<Ti>());
     for_range(functor);
     exec_cufft_plan_raw(config, input_conj.data(), output->data(), forward);
   } else if (fft_type == FFTTransformType::R2C && !forward) {
@@ -126,8 +126,8 @@ void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
     exec_cufft_plan_raw(config, input->data(), out_conj.data(), forward);

     platform::ForRange<DeviceContext> for_range(ctx, output->numel());
-    math::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
-                                  output->data<To>());
+    pten::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
+                                         output->data<To>());
     for_range(functor);
   } else {
     exec_cufft_plan_raw(config, input->data(), output->data(), forward);
@@ -227,8 +227,8 @@ void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
     framework::Tensor input_conj(input->type());
     input_conj.mutable_data<Ti>(input->dims(), ctx.GetPlace());
     platform::ForRange<DeviceContext> for_range(ctx, input->numel());
-    math::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
-                                  input_conj.data<Ti>());
+    pten::funcs::ConjFunctor<Ti> functor(input->data<Ti>(), input->numel(),
+                                         input_conj.data<Ti>());
     for_range(functor);
     exec_hipfft_plan_raw(config, input_conj.data(), output->data(), forward);
   } else if (fft_type == FFTTransformType::R2C && !forward) {
@@ -238,8 +238,8 @@ void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
     exec_hipfft_plan_raw(config, input->data(), out_conj.data(), forward);

     platform::ForRange<DeviceContext> for_range(ctx, output->numel());
-    math::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
-                                  output->data<To>());
+    pten::funcs::ConjFunctor<To> functor(out_conj.data<To>(), output->numel(),
+                                         output->data<To>());
     for_range(functor);
   } else {
     exec_hipfft_plan_raw(config, input->data(), output->data(), forward);

paddle/fluid/operators/svd_helper.h
@@ -25,9 +25,9 @@
 #include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"
 #include "paddle/pten/kernels/funcs/math_function.h"

 namespace paddle {
@@ -105,7 +105,8 @@ struct RealMulComplexFunctor {
                           "The image part of y must to be 0"
                           "but got [%d]",
                           y.imag));
-    return platform::complex<Real<T>>(x.real * y.real, x.imag * y.real);
+    return platform::complex<pten::funcs::Real<T>>(x.real * y.real,
+                                                   x.imag * y.real);
   }
 };
@@ -390,11 +391,11 @@ struct DeviceIndependenceTensorOperations {
   // batch_diag for CPU only
   Tensor BatchDiag(const Tensor& x, int batch) {
     Tensor out;
-    auto* x_data = x.data<math::Real<T>>();
+    auto* x_data = x.data<pten::funcs::Real<T>>();
     auto numel = x.numel();
-    auto* out_data = out.mutable_data<math::Real<T>>(
+    auto* out_data = out.mutable_data<pten::funcs::Real<T>>(
         x.dims(), context.GetPlace(),
-        static_cast<size_t>(numel * sizeof(math::Real<T>)));
+        static_cast<size_t>(numel * sizeof(pten::funcs::Real<T>)));

     auto x_dims = x.dims();
     int num_dims = x_dims.size();
@@ -654,7 +655,7 @@ struct DeviceIndependenceTensorOperations {
     auto* out_data = out.mutable_data<T>(x.dims(), context.GetPlace());
     auto* x_data = x.data<T>();
     auto for_range = GetForRange(x.numel());
-    math::ConjFunctor<T> functor(x_data, x.numel(), out_data);
+    pten::funcs::ConjFunctor<T> functor(x_data, x.numel(), out_data);
     for_range(functor);
     return out;
   }
@@ -662,12 +663,12 @@ struct DeviceIndependenceTensorOperations {
   Tensor Real(const Tensor& x) {
     Tensor out;
     auto numel = x.numel();
-    auto* out_data = out.mutable_data<math::Real<T>>(
+    auto* out_data = out.mutable_data<pten::funcs::Real<T>>(
         x.dims(), context.GetPlace(),
-        static_cast<size_t>(numel * sizeof(math::Real<T>)));
+        static_cast<size_t>(numel * sizeof(pten::funcs::Real<T>)));
     auto* x_data = x.data<T>();
     auto for_range = GetForRange(numel);
-    math::RealFunctor<T> functor(x_data, out_data, numel);
+    pten::funcs::RealFunctor<T> functor(x_data, out_data, numel);
     for_range(functor);
     return out;
   }

paddle/fluid/operators/svd_op.h
@@ -17,9 +17,9 @@
 #include <cstdarg>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -46,14 +46,14 @@ class SvdCPUKernel : public framework::OpKernel<T> {
     int col_u = full ? rows : k;
     int col_v = full ? cols : k;
     int batches = numel / (rows * cols);
-    auto* U_out = U->mutable_data<math::Real<T>>(
+    auto* U_out = U->mutable_data<pten::funcs::Real<T>>(
         context.GetPlace(),
-        size_t(batches * rows * col_u * sizeof(math::Real<T>)));
-    auto* VH_out = VH->mutable_data<math::Real<T>>(
+        size_t(batches * rows * col_u * sizeof(pten::funcs::Real<T>)));
+    auto* VH_out = VH->mutable_data<pten::funcs::Real<T>>(
         context.GetPlace(),
-        size_t(batches * col_v * cols * sizeof(math::Real<T>)));
-    auto* S_out = S->mutable_data<math::Real<T>>(
-        context.GetPlace(), size_t(batches * k * sizeof(math::Real<T>)));
+        size_t(batches * col_v * cols * sizeof(pten::funcs::Real<T>)));
+    auto* S_out = S->mutable_data<pten::funcs::Real<T>>(
+        context.GetPlace(), size_t(batches * k * sizeof(pten::funcs::Real<T>)));
     /*SVD Use the Eigen Library*/
     math::BatchSvd<T>(x_data, U_out, VH_out, S_out, rows, cols, batches, full);
   }

paddle/fluid/operators/triangular_solve_op.h
@@ -19,10 +19,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/operators/solve_op.h"
 #include "paddle/fluid/operators/tril_triu_op.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace paddle {
 namespace operators {
@@ -152,7 +152,7 @@ class TriangularSolveGradKernel : public framework::OpKernel<T> {
     // calculate x's conjugate for complex
     Tensor x_conj(x->type());
     platform::ForRange<DeviceContext> x_for_range(dev_ctx, x->numel());
-    math::ConjFunctor<T> x_functor(
+    pten::funcs::ConjFunctor<T> x_functor(
         x->data<T>(), x->numel(),
         x_conj.mutable_data<T>(x->dims(), dev_ctx.GetPlace()));
     x_for_range(x_functor);
@@ -179,7 +179,7 @@ class TriangularSolveGradKernel : public framework::OpKernel<T> {
     // calculate out's conjugate for complex
     Tensor out_conj(out->type());
     platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
-    math::ConjFunctor<T> out_functor(
+    pten::funcs::ConjFunctor<T> out_functor(
        out->data<T>(), out->numel(),
        out_conj.mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
     out_for_range(out_functor);

paddle/pten/kernels/cpu/abs_grad_kernel.cc
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/pten/common/complex.h"
 #include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"
 #include "paddle/pten/kernels/impl/abs_grad_kernel_impl.h"

 using pten::dtype::complex;

paddle/pten/kernels/cpu/abs_kernel.cc
@@ -13,11 +13,11 @@
 // limitations under the License.

 #include "paddle/pten/kernels/abs_kernel.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/common/complex.h"
 #include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace pten {
@@ -25,12 +25,12 @@ template <typename T, typename Context>
 void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
   auto numel = x.numel();
   auto* x_data = x.data<T>();
-  ctx.template Alloc<paddle::operators::math::Real<T>>(
-      out, size_t(x.numel() * sizeof(paddle::operators::math::Real<T>)));
-  auto* out_data = out->data<paddle::operators::math::Real<T>>();
+  ctx.template Alloc<pten::funcs::Real<T>>(
+      out, size_t(x.numel() * sizeof(pten::funcs::Real<T>)));
+  auto* out_data = out->data<pten::funcs::Real<T>>();

   paddle::platform::ForRange<Context> for_range(ctx, numel);
-  paddle::operators::math::AbsFunctor<T> functor(x_data, out_data, numel);
+  pten::funcs::AbsFunctor<T> functor(x_data, out_data, numel);
   for_range(functor);
 }

paddle/fluid/operators/math/complex_functors.h → paddle/pten/kernels/funcs/complex_functors.h
@@ -13,15 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#ifndef _USE_MATH_DEFINES
+#define _USE_MATH_DEFINES
+#endif

 #include <cmath>
 #include <type_traits>

-#include "paddle/fluid/platform/complex.h"
+#include "paddle/pten/common/complex.h"
 #include "paddle/pten/core/hostdevice.h"

-namespace paddle {
-namespace operators {
-namespace math {
+namespace pten {
+namespace funcs {

 template <bool B, typename T>
 struct cond {
@@ -64,8 +66,8 @@ using select_t = typename select<Head, Tail...>::type;
 template <typename T>
 using Real =
-    select_t<cond<std::is_same<T, platform::complex<float>>::value, float>,
-             cond<std::is_same<T, platform::complex<double>>::value, double>,
+    select_t<cond<std::is_same<T, pten::dtype::complex<float>>::value, float>,
+             cond<std::is_same<T, pten::dtype::complex<double>>::value, double>,
              T>;

 template <typename T, typename RealT>
@@ -77,13 +79,13 @@ using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
 template <typename T>
 using EnableComplex = typename std::enable_if<
-    std::is_same<T, platform::complex<float>>::value ||
-    std::is_same<T, platform::complex<double>>::value>::type;
+    std::is_same<T, pten::dtype::complex<float>>::value ||
+    std::is_same<T, pten::dtype::complex<double>>::value>::type;

 template <typename T>
 using DisableComplex = typename std::enable_if<
-    !std::is_same<T, platform::complex<float>>::value &&
-    !std::is_same<T, platform::complex<double>>::value>::type;
+    !std::is_same<T, pten::dtype::complex<float>>::value &&
+    !std::is_same<T, pten::dtype::complex<double>>::value>::type;

 template <typename T, typename Enable = void>
 struct RealFunctor;

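Note (illustrative, not part of the diff): with the aliases above, pten::funcs::Real<T> maps a complex element type to its underlying real type and leaves real types unchanged. A hypothetical compile-time check, assuming only the definitions shown in this hunk:

#include <type_traits>
#include "paddle/pten/common/complex.h"
#include "paddle/pten/kernels/funcs/complex_functors.h"

// Real<complex<float>> -> float, Real<complex<double>> -> double,
// Real<double> -> double (pass-through for non-complex types).
static_assert(
    std::is_same<pten::funcs::Real<pten::dtype::complex<float>>, float>::value,
    "");
static_assert(
    std::is_same<pten::funcs::Real<pten::dtype::complex<double>>, double>::value,
    "");
static_assert(std::is_same<pten::funcs::Real<double>, double>::value, "");
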
@@ -154,8 +156,7 @@ struct AbsFunctor<T, NoComplex<T, Real<T>>> {
template
<
typename
T
>
struct
AbsGradFunctor
{
AbsGradFunctor
(
const
math
::
Real
<
T
>*
dout
,
const
T
*
x
,
T
*
output
,
int64_t
numel
)
AbsGradFunctor
(
const
Real
<
T
>*
dout
,
const
T
*
x
,
T
*
output
,
int64_t
numel
)
:
dout_
(
dout
),
x_
(
x
),
output_
(
output
),
numel_
(
numel
)
{}
HOSTDEVICE
void
operator
()(
int64_t
idx
)
const
{
...
...
@@ -166,52 +167,55 @@ struct AbsGradFunctor {
}
}
const
math
::
Real
<
T
>*
dout_
;
const
Real
<
T
>*
dout_
;
const
T
*
x_
;
T
*
output_
;
int64_t
numel_
;
};
template
<
>
struct
AbsGradFunctor
<
paddle
::
platform
::
complex
<
float
>>
{
AbsGradFunctor
(
const
float
*
dout
,
const
paddle
::
platform
::
complex
<
float
>*
x
,
paddle
::
platform
::
complex
<
float
>*
output
,
int64_t
numel
)
struct
AbsGradFunctor
<
pten
::
dtype
::
complex
<
float
>>
{
AbsGradFunctor
(
const
float
*
dout
,
const
pten
::
dtype
::
complex
<
float
>*
x
,
pten
::
dtype
::
complex
<
float
>*
output
,
int64_t
numel
)
:
dout_
(
dout
),
x_
(
x
),
output_
(
output
),
numel_
(
numel
)
{}
HOSTDEVICE
void
operator
()(
int64_t
idx
)
const
{
if
(
x_
[
idx
]
==
p
addle
::
platform
::
complex
<
float
>
(
0
))
{
output_
[
idx
]
=
p
addle
::
platform
::
complex
<
float
>
(
0
);
if
(
x_
[
idx
]
==
p
ten
::
dtype
::
complex
<
float
>
(
0
))
{
output_
[
idx
]
=
p
ten
::
dtype
::
complex
<
float
>
(
0
);
}
else
{
output_
[
idx
]
=
p
addle
::
platform
::
complex
<
float
>
(
dout_
[
idx
])
*
(
x_
[
idx
]
/
p
addle
::
platform
::
complex
<
float
>
(
abs
(
x_
[
idx
])));
output_
[
idx
]
=
p
ten
::
dtype
::
complex
<
float
>
(
dout_
[
idx
])
*
(
x_
[
idx
]
/
p
ten
::
dtype
::
complex
<
float
>
(
abs
(
x_
[
idx
])));
}
}
const
float
*
dout_
;
const
p
addle
::
platform
::
complex
<
float
>*
x_
;
p
addle
::
platform
::
complex
<
float
>*
output_
;
const
p
ten
::
dtype
::
complex
<
float
>*
x_
;
p
ten
::
dtype
::
complex
<
float
>*
output_
;
int64_t
numel_
;
};
 template <>
-struct AbsGradFunctor<paddle::platform::complex<double>> {
-  AbsGradFunctor(const double* dout, const paddle::platform::complex<double>* x,
-                 paddle::platform::complex<double>* output, int64_t numel)
+struct AbsGradFunctor<pten::dtype::complex<double>> {
+  AbsGradFunctor(const double* dout,
+                 const pten::dtype::complex<double>* x,
+                 pten::dtype::complex<double>* output,
+                 int64_t numel)
       : dout_(dout), x_(x), output_(output), numel_(numel) {}

   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex<double>(0)) {
-      output_[idx] = paddle::platform::complex<double>(0);
+    if (x_[idx] == pten::dtype::complex<double>(0)) {
+      output_[idx] = pten::dtype::complex<double>(0);
     } else {
-      output_[idx] = paddle::platform::complex<double>(dout_[idx]) *
-                     (x_[idx] / paddle::platform::complex<double>(abs(x_[idx])));
+      output_[idx] = pten::dtype::complex<double>(dout_[idx]) *
+                     (x_[idx] / pten::dtype::complex<double>(abs(x_[idx])));
     }
   }

   const double* dout_;
-  const paddle::platform::complex<double>* x_;
-  paddle::platform::complex<double>* output_;
+  const pten::dtype::complex<double>* x_;
+  pten::dtype::complex<double>* output_;
   int64_t numel_;
 };
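For readers skimming the diff: the two specializations above implement the complex absolute-value gradient dx = dout * x / |x| (with dx = 0 at x = 0); only the namespace of the complex type changes. Below is a standalone sketch of that arithmetic, using std::complex purely for illustration; it is not part of this commit.

// Standalone illustration of the gradient the complex AbsGradFunctor computes:
// dx = dout * x / |x|, and dx = 0 when x == 0. Not Paddle code.
#include <complex>
#include <cstdio>

int main() {
  std::complex<float> x(3.0f, 4.0f);  // |x| = 5
  float dout = 1.0f;
  std::complex<float> dx =
      (x == std::complex<float>(0))
          ? std::complex<float>(0)
          : std::complex<float>(dout) * (x / std::abs(x));
  std::printf("dx = (%g, %g)\n", dx.real(), dx.imag());  // prints (0.6, 0.8)
  return 0;
}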
...
...
@@ -235,46 +239,48 @@ struct AbsGradGradFunctor {
 };

 template <>
-struct AbsGradGradFunctor<paddle::platform::complex<double>> {
-  AbsGradGradFunctor(const paddle::platform::complex<double>* ddx,
-                     const paddle::platform::complex<double>* x,
-                     paddle::platform::complex<double>* output, int64_t numel)
+struct AbsGradGradFunctor<pten::dtype::complex<double>> {
+  AbsGradGradFunctor(const pten::dtype::complex<double>* ddx,
+                     const pten::dtype::complex<double>* x,
+                     pten::dtype::complex<double>* output,
+                     int64_t numel)
       : ddx_(ddx), x_(x), output_(output), numel_(numel) {}

   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex<double>(0)) {
-      output_[idx] = paddle::platform::complex<double>(0);
+    if (x_[idx] == pten::dtype::complex<double>(0)) {
+      output_[idx] = pten::dtype::complex<double>(0);
     } else {
-      output_[idx] = paddle::platform::complex<double>(ddx_[idx]) * x_[idx] /
-                     paddle::platform::complex<double>(abs(x_[idx]));
+      output_[idx] = pten::dtype::complex<double>(ddx_[idx]) * x_[idx] /
+                     pten::dtype::complex<double>(abs(x_[idx]));
     }
   }

-  const paddle::platform::complex<double>* ddx_;
-  const paddle::platform::complex<double>* x_;
-  paddle::platform::complex<double>* output_;
+  const pten::dtype::complex<double>* ddx_;
+  const pten::dtype::complex<double>* x_;
+  pten::dtype::complex<double>* output_;
   int64_t numel_;
 };
 template <>
-struct AbsGradGradFunctor<paddle::platform::complex<float>> {
-  AbsGradGradFunctor(const paddle::platform::complex<float>* ddx,
-                     const paddle::platform::complex<float>* x,
-                     paddle::platform::complex<float>* output, int64_t numel)
+struct AbsGradGradFunctor<pten::dtype::complex<float>> {
+  AbsGradGradFunctor(const pten::dtype::complex<float>* ddx,
+                     const pten::dtype::complex<float>* x,
+                     pten::dtype::complex<float>* output,
+                     int64_t numel)
       : ddx_(ddx), x_(x), output_(output), numel_(numel) {}

   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex<float>(0)) {
-      output_[idx] = paddle::platform::complex<float>(0);
+    if (x_[idx] == pten::dtype::complex<float>(0)) {
+      output_[idx] = pten::dtype::complex<float>(0);
     } else {
-      output_[idx] = paddle::platform::complex<float>(ddx_[idx]) * x_[idx] /
-                     paddle::platform::complex<float>(abs(x_[idx]));
+      output_[idx] = pten::dtype::complex<float>(ddx_[idx]) * x_[idx] /
+                     pten::dtype::complex<float>(abs(x_[idx]));
     }
   }

-  const paddle::platform::complex<float>* ddx_;
-  const paddle::platform::complex<float>* x_;
-  paddle::platform::complex<float>* output_;
+  const pten::dtype::complex<float>* ddx_;
+  const pten::dtype::complex<float>* x_;
+  pten::dtype::complex<float>* output_;
   int64_t numel_;
 };
 template <typename T, typename Enable = void>
...
...
@@ -318,8 +324,10 @@ struct RealImagToComplexFunctor;
 template <typename T>
 struct RealImagToComplexFunctor<T, Complex<T, Real<T>>> {
-  RealImagToComplexFunctor(const Real<T>* input_real, const Real<T>* input_imag,
-                           T* output, int64_t numel)
+  RealImagToComplexFunctor(const Real<T>* input_real,
+                           const Real<T>* input_imag,
+                           T* output,
+                           int64_t numel)
       : input_real_(input_real), input_imag_(input_imag), output_(output),
...
...
@@ -363,6 +371,84 @@ struct ConjFunctor<T, DisableComplex<T>> {
   T* output_;
 };

-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+template <typename T, typename Enable = void>
+struct AngleFunctor;
+
+// angel function for complex
+template <typename T>
+struct AngleFunctor<T, pten::funcs::Complex<T, pten::funcs::Real<T>>> {
+  AngleFunctor(const T* input, pten::funcs::Real<T>* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx] = arg(input_[idx]);
+  }
+
+  const T* input_;
+  pten::funcs::Real<T>* output_;
+  int64_t numel_;
+};
+
+// angel function for real
+template <typename T>
+struct AngleFunctor<T, pten::funcs::NoComplex<T, pten::funcs::Real<T>>> {
+  AngleFunctor(const T* input, T* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx] = input_[idx] < static_cast<T>(0) ? M_PI : 0;
+  }
+
+  const T* input_;
+  T* output_;
+  int64_t numel_;
+};
+
+template <typename T, typename Enable = void>
+struct AngleGradFunctor;
+
+// angle grad for complex
+template <typename T>
+struct AngleGradFunctor<T, pten::funcs::Complex<T, pten::funcs::Real<T>>> {
+  AngleGradFunctor(const pten::funcs::Real<T>* dout, const T* x, T* dx,
+                   int64_t numel)
+      : dout_(dout), x_(x), dx_(dx), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    if (x_[idx] == T(0)) {
+      dx_[idx] = T(0);
+    } else {
+      const pten::funcs::Real<T> r_square =
+          x_[idx].real * x_[idx].real + x_[idx].imag * x_[idx].imag;
+      dx_[idx] = T(-dout_[idx] * x_[idx].imag / r_square,
+                   dout_[idx] * x_[idx].real / r_square);
+    }
+  }
+
+  const pten::funcs::Real<T>* dout_;
+  const T* x_;
+  T* dx_;
+  int64_t numel_;
+};
+
+// angle grad for real
+template <typename T>
+struct AngleGradFunctor<T, pten::funcs::NoComplex<T, pten::funcs::Real<T>>> {
+  AngleGradFunctor(const pten::funcs::Real<T>* dout, const T* x, T* dx,
+                   int64_t numel)
+      : dout_(dout), x_(x), dx_(dx), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const { dx_[idx] = 0; }
+
+  const pten::funcs::Real<T>* dout_;
+  const T* x_;
+  T* dx_;
+  int64_t numel_;
+};
+
+}  // namespace funcs
+}  // namespace pten
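For orientation, here is a minimal sketch of how the newly moved AngleFunctor would typically be driven; it follows the ForRange-plus-functor pattern that AbsGradKernel and ConjKernel use later in this diff. The free function name is hypothetical and not part of the commit.

// Hypothetical driver; mirrors the ForRange pattern shown below in
// abs_grad_kernel_impl.h / complex_kernel_impl.h.
#include <cstdint>
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/kernels/funcs/complex_functors.h"

template <typename T, typename Context>
void AngleSketch(const Context& ctx, const T* x,
                 pten::funcs::Real<T>* out, int64_t numel) {
  paddle::platform::ForRange<Context> for_range(ctx, numel);
  pten::funcs::AngleFunctor<T> functor(x, out, numel);
  for_range(functor);  // applies functor(idx) for every idx in [0, numel)
}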
paddle/pten/kernels/gpu/abs_kernel.cu
...
...
@@ -14,11 +14,11 @@
 #include <algorithm>
 #include <vector>

-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/kernel_registry.h"
 #include "paddle/pten/kernels/abs_kernel.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"
 #include "paddle/pten/kernels/funcs/elementwise_base.h"

 namespace pten {
...
...
@@ -27,19 +27,14 @@ template <typename T, typename Enable = void>
 struct CudaAbsFunctor;

 template <typename T>
-struct CudaAbsFunctor<
-    T, paddle::operators::math::Complex<T, paddle::operators::math::Real<T>>> {
-  __device__ __forceinline__ paddle::operators::math::Real<T> operator()(
-      const T x) const {
+struct CudaAbsFunctor<T, pten::funcs::Complex<T, pten::funcs::Real<T>>> {
+  __device__ __forceinline__ pten::funcs::Real<T> operator()(const T x) const {
     return abs(x);
   }
 };

 template <typename T>
-struct CudaAbsFunctor<
-    T,
-    paddle::operators::math::NoComplex<T, paddle::operators::math::Real<T>>> {
+struct CudaAbsFunctor<T, pten::funcs::NoComplex<T, pten::funcs::Real<T>>> {
   __device__ __forceinline__ T operator()(const T x) const {
     return std::abs(x);
   }
...
...
@@ -47,12 +42,12 @@ struct CudaAbsFunctor<
 template <typename T, typename Context>
 void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
-  ctx.template Alloc<paddle::operators::math::Real<T>>(out);
+  ctx.template Alloc<pten::funcs::Real<T>>(out);
   std::vector<const DenseTensor*> ins = {&x};
   std::vector<DenseTensor*> outs = {out};
   auto functor = CudaAbsFunctor<T>();
-  funcs::LaunchSameDimsElementwiseCudaKernel<paddle::operators::math::Real<T>>(
+  funcs::LaunchSameDimsElementwiseCudaKernel<pten::funcs::Real<T>>(
       ctx, ins, &outs, functor);
 }
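Note why `out` is allocated as Real<T> rather than T: the absolute value of a complex number is real, so Real<T> is expected to collapse complex<float> to float (and complex<double> to double) while leaving real element types unchanged. A compile-time check of that assumption, for illustration only:

// Illustration only: assumes Real<T> is the alias exercised above, defined in
// paddle/pten/kernels/funcs/complex_functors.h.
#include <type_traits>
#include "paddle/pten/kernels/funcs/complex_functors.h"

static_assert(
    std::is_same<pten::funcs::Real<pten::dtype::complex<float>>, float>::value,
    "abs of a complex<float> tensor yields a float tensor");
static_assert(std::is_same<pten::funcs::Real<float>, float>::value,
              "real inputs keep their element type");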
...
...
paddle/pten/kernels/impl/abs_grad_kernel_impl.h
...
...
@@ -14,9 +14,9 @@
 #pragma once

-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/pten/kernels/abs_grad_kernel.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace pten {
...
...
@@ -26,15 +26,14 @@ void AbsGradKernel(const Context& ctx,
                    const DenseTensor& dout,
                    DenseTensor* dx) {
   auto numel = dout.numel();
-  auto* dout_data = dout.data<paddle::operators::math::Real<T>>();
+  auto* dout_data = dout.data<pten::funcs::Real<T>>();
   auto* x_data = x.data<T>();

   ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
   auto* dx_data = dx->data<T>();

   paddle::platform::ForRange<Context> for_range(ctx, numel);
-  paddle::operators::math::AbsGradFunctor<T> functor(dout_data, x_data,
-                                                     dx_data, numel);
+  pten::funcs::AbsGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
   for_range(functor);
 }
...
...
@@ -50,7 +49,7 @@ void AbsDoubleGradKernel(const Context& ctx,
   auto* ddout_data = ddout->data<T>();

   paddle::platform::ForRange<Context> for_range(ctx, numel);
-  paddle::operators::math::AbsGradGradFunctor<T> functor(
+  pten::funcs::AbsGradGradFunctor<T> functor(
       ddx_data, x_data, ddout_data, numel);
   for_range(functor);
 }
...
...
paddle/pten/kernels/impl/complex_kernel_impl.h
...
...
@@ -15,8 +15,8 @@
 #pragma once

 // See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 namespace pten {
...
...
@@ -29,7 +29,7 @@ void ConjKernel(const Context& dev_ctx,
   auto* out_data = dev_ctx.template Alloc<T>(out);

   paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
-  paddle::operators::math::ConjFunctor<T> functor(x_data, numel, out_data);
+  pten::funcs::ConjFunctor<T> functor(x_data, numel, out_data);
   for_range(functor);
 }
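A serial sketch of what the for_range call above effectively does, assuming ConjFunctor keeps the (input, numel, output) constructor and the per-index operator() used here; the helper name is illustrative only, not part of the commit:

// Illustrative only; serially applies the same element-wise functor that
// for_range drives on the device side.
#include <cstdint>
#include "paddle/pten/kernels/funcs/complex_functors.h"

template <typename T>
void ConjSerialSketch(const T* x, T* out, int64_t numel) {
  pten::funcs::ConjFunctor<T> functor(x, numel, out);
  for (int64_t i = 0; i < numel; ++i) {
    functor(i);  // out[i] = conj(x[i]) for complex T, a plain copy for real T
  }
}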
...
...
paddle/pten/kernels/impl/dot_grad_kernel_impl.h
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/pten/kernels/complex_kernel.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/
fluid/operators/math
/complex_functors.h"
#include "paddle/
pten/kernels/funcs
/complex_functors.h"
namespace
pten
{
...
...
@@ -35,9 +35,7 @@ struct DotGradFunction {
 };

 template <typename DeviceContext, typename T>
-struct DotGradFunction<DeviceContext,
-                       T,
-                       paddle::operators::math::EnableComplex<T>> {
+struct DotGradFunction<DeviceContext, T, pten::funcs::EnableComplex<T>> {
   void operator()(const DeviceContext& ctx,
                   const DenseTensor* tensor_x,
                   const DenseTensor* tensor_y,
...
...
@@ -133,9 +131,7 @@ struct DotGradFunction<DeviceContext,
 };

 template <typename DeviceContext, typename T>
-struct DotGradFunction<DeviceContext,
-                       T,
-                       paddle::operators::math::DisableComplex<T>> {
+struct DotGradFunction<DeviceContext, T, pten::funcs::DisableComplex<T>> {
   void operator()(const DeviceContext& ctx,
                   const DenseTensor* tensor_x,
                   const DenseTensor* tensor_y,
...
...
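The hunks above and below rename the SFINAE tags used for dispatch: EnableComplex<T> is a valid type only for complex element types and DisableComplex<T> only for real ones, so exactly one specialization participates for a given T. A minimal sketch of the mechanism, assuming the pten::funcs aliases behave like their fluid predecessors (the struct name below is hypothetical):

// Hypothetical mirror of the DotGradFunction dispatch; only the selection
// mechanism is illustrated, not the gradient math.
#include "paddle/pten/kernels/funcs/complex_functors.h"

template <typename T, typename Enable = void>
struct ExampleGradFunction;  // primary template, never instantiated directly

template <typename T>
struct ExampleGradFunction<T, pten::funcs::EnableComplex<T>> {
  // chosen when T is pten::dtype::complex<float> or complex<double>
};

template <typename T>
struct ExampleGradFunction<T, pten::funcs::DisableComplex<T>> {
  // chosen for real element types such as float and double
};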
@@ -221,9 +217,7 @@ struct DotDoubleGradFunction {
 };

 template <typename DeviceContext, typename T>
-struct DotDoubleGradFunction<DeviceContext,
-                             T,
-                             paddle::operators::math::EnableComplex<T>> {
+struct DotDoubleGradFunction<DeviceContext, T, pten::funcs::EnableComplex<T>> {
   void operator()(const DeviceContext& ctx,
                   const DenseTensor* tensor_x,
                   const DenseTensor* tensor_y,
...
...
@@ -334,9 +328,7 @@ struct DotDoubleGradFunction<DeviceContext,
 };

 template <typename DeviceContext, typename T>
-struct DotDoubleGradFunction<DeviceContext,
-                             T,
-                             paddle::operators::math::DisableComplex<T>> {
+struct DotDoubleGradFunction<DeviceContext, T, pten::funcs::DisableComplex<T>> {
   void operator()(const DeviceContext& ctx,
                   const DenseTensor* tensor_x,
                   const DenseTensor* tensor_y,
...
...
@@ -461,9 +453,7 @@ struct DotTripleGradFunction {
 // TODO(wuweilong): enable this function when the unittests framewark for multi
 // grad is ok (dtype: complex64 or complex128).
 template <typename DeviceContext, typename T>
-struct DotTripleGradFunction<DeviceContext,
-                             T,
-                             paddle::operators::math::EnableComplex<T>> {
+struct DotTripleGradFunction<DeviceContext, T, pten::funcs::EnableComplex<T>> {
   void operator()(const DeviceContext& ctx,
                   const DenseTensor* in_tensor_x,
                   const DenseTensor* in_tensor_y,
...
@@ -656,9 +646,7 @@ struct DotTripleGradFunction<DeviceContext,
 };

 template <typename DeviceContext, typename T>
-struct DotTripleGradFunction<DeviceContext,
-                             T,
-                             paddle::operators::math::DisableComplex<T>> {
+struct DotTripleGradFunction<DeviceContext, T, pten::funcs::DisableComplex<T>> {
   void operator()(const DeviceContext& ctx,
                   const DenseTensor* in_tensor_x,
                   const DenseTensor* in_tensor_y,
...
...
paddle/pten/kernels/impl/matmul_kernel_impl.h
...
...
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/pten/kernels/funcs/complex_functors.h"

 #include "paddle/pten/core/dense_tensor.h"
...
...