Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
31e874b1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
31e874b1
编写于
12月 18, 2021
作者:
F
Feiyu Chan
提交者:
GitHub
12月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add complex op (#37918)
* add complex op and `paddle.complex`.
上级
a3bd6fc0
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
555 addition
and
64 deletion
+555
-64
paddle/fluid/operators/complex_op.cc
paddle/fluid/operators/complex_op.cc
+144
-0
paddle/fluid/operators/complex_op.cu
paddle/fluid/operators/complex_op.cu
+27
-0
paddle/fluid/operators/complex_op.h
paddle/fluid/operators/complex_op.h
+111
-0
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+68
-63
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/tests/unittests/test_complex_op.py
python/paddle/fluid/tests/unittests/test_complex_op.py
+156
-0
python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
...luid/tests/unittests/white_list/no_grad_set_white_list.py
+2
-1
python/paddle/tensor/__init__.py
python/paddle/tensor/__init__.py
+1
-0
python/paddle/tensor/creation.py
python/paddle/tensor/creation.py
+44
-0
未找到文件。
paddle/fluid/operators/complex_op.cc
0 → 100644
浏览文件 @
31e874b1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/complex_op.h"
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.cc"
namespace
paddle
{
namespace
operators
{
class
ComplexOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
protected:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), real part of complex_op"
);
AddInput
(
"Y"
,
"(Tensor), image part of complex_op"
);
AddOutput
(
"Out"
,
"(Tensor), output of complex_op"
);
AddComment
(
R"DOC(
Complex Operator.
Return a complex tensor given the real and image tensors.
)DOC"
);
}
};
template
<
typename
T
>
class
ComplexGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"complex_grad"
);
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
// op->SetInput("Out", this->Output("Out"));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
op
->
SetAttrMap
(
this
->
Attrs
());
}
};
class
ComplexOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"complex"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Y"
),
"Input"
,
"Y"
,
"complex"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"complex"
);
if
(
ctx
->
GetInputDim
(
"X"
)
==
ctx
->
GetInputDim
(
"Y"
))
{
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
// NOTE(chenfeiyu): lod & broadcasting is intrinsically contradictory
// so tensors with lod are not supported here
}
else
{
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
int
max_dim
=
std
::
max
(
x_dims
.
size
(),
y_dims
.
size
());
// start align axis
int
axis
=
std
::
abs
(
x_dims
.
size
()
-
y_dims
.
size
());
std
::
vector
<
int
>
x_dims_array
(
max_dim
);
std
::
vector
<
int
>
y_dims_array
(
max_dim
);
std
::
vector
<
int
>
out_dims_array
(
max_dim
);
details
::
GetBroadcastDimsArrays
(
x_dims
,
y_dims
,
x_dims_array
.
data
(),
y_dims_array
.
data
(),
out_dims_array
.
data
(),
max_dim
,
axis
);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
out_dims_array
));
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
());
}
};
class
ComplexGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"complex_grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Y"
),
"Input"
,
"Y"
,
"kron_complex_gradgrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"complex_grad"
);
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
ctx
->
ShareDim
(
"X"
,
/*->*/
x_grad_name
);
}
auto
y_grad_name
=
framework
::
GradVarName
(
"Y"
);
if
(
ctx
->
HasOutput
(
y_grad_name
))
{
ctx
->
ShareDim
(
"Y"
,
/*->*/
y_grad_name
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
out_grad_name
=
framework
::
GradVarName
(
"Out"
);
auto
computation_dtype
=
framework
::
ToRealType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
out_grad_name
));
return
framework
::
OpKernelType
(
computation_dtype
,
ctx
.
GetPlace
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
complex
,
ops
::
ComplexOp
,
ops
::
ComplexOpMaker
,
ops
::
ComplexGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
ComplexGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
complex_grad
,
ops
::
ComplexGradOp
);
REGISTER_OP_CPU_KERNEL
(
complex
,
ops
::
ComplexKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ComplexKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
complex_grad
,
ops
::
ComplexGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ComplexGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/complex_op.cu
0 → 100644
浏览文件 @
31e874b1
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/complex_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
complex
,
ops
::
ComplexKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ComplexKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
complex_grad
,
ops
::
ComplexGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ComplexGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/complex_op.h
0 → 100644
浏览文件 @
31e874b1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/complex.h"
namespace
paddle
{
namespace
operators
{
// functors to use with ElementwiseComputeEx
template
<
typename
T
>
struct
RealAndImagToComplexFunctor
{
inline
HOSTDEVICE
platform
::
complex
<
T
>
operator
()(
const
T
&
x
,
const
T
&
y
)
{
return
platform
::
complex
<
T
>
(
x
,
y
);
}
};
template
<
typename
T
>
struct
ImagAndRealToComplexFunctor
{
inline
HOSTDEVICE
platform
::
complex
<
T
>
operator
()(
const
T
&
y
,
const
T
&
x
)
{
return
platform
::
complex
<
T
>
(
x
,
y
);
}
};
template
<
typename
T
>
struct
ComplexGradForRealFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
x
,
const
T
y
,
const
platform
::
complex
<
T
>
out
,
const
platform
::
complex
<
T
>
dout
)
{
return
dout
.
real
;
}
};
template
<
typename
T
>
struct
ComplexGradForImagFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
x
,
const
T
y
,
const
platform
::
complex
<
T
>
out
,
const
platform
::
complex
<
T
>
dout
)
{
return
dout
.
imag
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ComplexKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
auto
*
y
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
using
C
=
platform
::
complex
<
T
>
;
z
->
mutable_data
<
C
>
(
ctx
.
GetPlace
());
// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
// facility functions
#if defined(__NVCC__) || defined(__HIPCC__)
ElementwiseComputeEx
<
RealAndImagToComplexFunctor
<
T
>
,
DeviceContext
,
T
,
C
>
(
ctx
,
x
,
y
,
/*axis*/
-
1
,
RealAndImagToComplexFunctor
<
T
>
(),
z
);
#else
auto
x_dims
=
x
->
dims
();
auto
y_dims
=
y
->
dims
();
if
(
x_dims
.
size
()
>=
y_dims
.
size
())
{
ElementwiseComputeEx
<
RealAndImagToComplexFunctor
<
T
>
,
DeviceContext
,
T
,
C
>
(
ctx
,
x
,
y
,
/*axis*/
-
1
,
RealAndImagToComplexFunctor
<
T
>
(),
z
);
}
else
{
ElementwiseComputeEx
<
ImagAndRealToComplexFunctor
<
T
>
,
DeviceContext
,
T
,
C
>
(
ctx
,
x
,
y
,
/*axis*/
-
1
,
ImagAndRealToComplexFunctor
<
T
>
(),
z
);
}
#endif
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ComplexGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
Tensor
=
framework
::
Tensor
;
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
using
C
=
platform
::
complex
<
T
>
;
// skip out in a hacky way
auto
*
out
=
dout
;
ElemwiseGradCompute
<
DeviceContext
,
T
,
ComplexGradForRealFunctor
<
T
>
,
ComplexGradForImagFunctor
<
T
>
,
C
>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
/*axis*/
-
1
,
dx
,
dy
,
ComplexGradForRealFunctor
<
T
>
(),
ComplexGradForImagFunctor
<
T
>
());
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
31e874b1
...
...
@@ -169,7 +169,7 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
is_xsize_larger
);
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
void
CommonGradBroadcastCPU
(
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
,
const
framework
::
Tensor
&
out
,
const
framework
::
Tensor
&
dout
,
...
...
@@ -179,8 +179,8 @@ void CommonGradBroadcastCPU(
std
::
vector
<
int
>
index_array
(
max_dim
,
0
);
const
T
*
x_data
=
x
.
data
<
T
>
();
const
T
*
y_data
=
y
.
data
<
T
>
();
const
T
*
out_data
=
out
.
data
<
T
>
();
const
T
*
dout_data
=
dout
.
data
<
T
>
();
const
T
out
*
out_data
=
out
.
data
<
Tout
>
();
const
T
out
*
dout_data
=
dout
.
data
<
Tout
>
();
T
*
dx_data
=
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
dy_data
=
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
dx_data
!=
nullptr
)
{
...
...
@@ -240,9 +240,9 @@ inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs,
}
#if defined(__NVCC__) || defined(__HIPCC__)
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
Tout
*
dout
,
int
h
,
int
w
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
j
=
blockIdx
.
x
;
int
i
=
threadIdx
.
x
;
...
...
@@ -291,9 +291,9 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
// suppose use 2D block is fast because more parallel
// and memory coalesced
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
static
__global__
void
FastElemwiseGradBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
Tout
*
dout
,
int
h
,
int
w
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
__shared__
T
sdata
[
BLOCK_Y
][
BLOCK_X
+
1
];
...
...
@@ -369,12 +369,12 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
}
}
template
<
typename
T
,
typename
DX_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
Tout
=
T
>
__global__
void
CommonGradBroadcastCUDAKernel
(
const
int
*
x_strides_array
,
const
int
*
y_strides_array
,
const
int
*
out_dims_array
,
const
int
*
y_strides_order
,
const
int
*
y_dims_order
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
T
*
dx
,
int
out_size
,
int
max_dim
,
int
thread_num
,
const
int
*
y_dims_order
,
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
T
out
*
dout
,
T
*
dx
,
int
out_size
,
int
max_dim
,
int
thread_num
,
DX_OP
dx_op
)
{
T
val
(
0
);
int
i
=
blockIdx
.
x
;
...
...
@@ -408,9 +408,9 @@ __global__ void CommonGradBroadcastCUDAKernel(
}
}
template
<
typename
T
,
typename
DY_OP
>
template
<
typename
T
,
typename
DY_OP
,
typename
Tout
=
T
>
static
__global__
void
CommonGradBroadcast1CUDAKernelHeight
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
Tout
*
dout
,
int
h
,
int
w
,
DY_OP
dy_op
,
T
*
dy
,
int
x_h
,
int
x_w
,
bool
is_y
)
{
int
j
=
blockIdx
.
x
;
int
i
=
threadIdx
.
x
;
...
...
@@ -454,9 +454,9 @@ static __global__ void CommonGradBroadcast1CUDAKernelHeight(
}
}
template
<
typename
T
,
typename
DY_OP
>
template
<
typename
T
,
typename
DY_OP
,
typename
Tout
=
T
>
static
__global__
void
FastCommonGradBroadcastCUDAKernelHeight
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
Tout
*
dout
,
int
h
,
int
w
,
DY_OP
dy_op
,
T
*
dy
,
int
x_h
,
int
x_w
,
bool
is_y
)
{
__shared__
T
sdata
[
BLOCK_Y
][
BLOCK_X
+
1
];
...
...
@@ -528,9 +528,9 @@ static __global__ void FastCommonGradBroadcastCUDAKernelHeight(
}
}
template
<
typename
T
,
typename
DY_OP
,
typename
DX_OP
>
template
<
typename
T
,
typename
DY_OP
,
typename
DX_OP
,
typename
Tout
=
T
>
static
__global__
void
FastCommonGradBroadcastAllCUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
Tout
*
dout
,
int
pre
,
int
n
,
int
post
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
tid
=
threadIdx
.
x
;
int
bid
=
blockIdx
.
x
;
...
...
@@ -581,9 +581,9 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
}
}
template
<
typename
T
,
typename
OP
>
template
<
typename
T
,
typename
OP
,
typename
Tout
=
T
>
static
__global__
void
FastCommonGradBroadcastOneCUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
Tout
*
dout
,
int
pre
,
int
n
,
int
post
,
int
y_pre
,
int
y_n
,
int
y_post
,
bool
is_xsize
,
OP
op
,
T
*
dd
)
{
int
tid
=
threadIdx
.
x
;
int
bid
=
blockIdx
.
x
;
...
...
@@ -669,7 +669,7 @@ static inline bool CheckContiguousDims(const std::vector<int> &broadcast_pos) {
return
true
;
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
void
CommonGradBroadcastCUDA
(
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
,
const
framework
::
Tensor
&
out
,
const
framework
::
Tensor
&
dout
,
...
...
@@ -680,8 +680,8 @@ void CommonGradBroadcastCUDA(
auto
cplace
=
platform
::
CPUPlace
();
const
T
*
x_data
=
x
.
data
<
T
>
();
const
T
*
y_data
=
y
.
data
<
T
>
();
const
T
*
out_data
=
out
.
data
<
T
>
();
const
T
*
dout_data
=
dout
.
data
<
T
>
();
const
T
out
*
out_data
=
out
.
data
<
Tout
>
();
const
T
out
*
dout_data
=
dout
.
data
<
Tout
>
();
T
*
dx_data
=
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
dy_data
=
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -1045,7 +1045,7 @@ void CommonGradBroadcastCUDA(
memory
::
Copy
(
gplace
,
x_dims_order_gpu
,
cplace
,
x_dims_order
.
data
(),
bytes
,
ctx
.
stream
());
CommonGradBroadcastCUDAKernel
<
T
,
DX_OP
><<<
x_blocks
,
x_block_size
,
0
,
ctx
.
stream
()
>>>
(
T
,
DX_OP
,
Tout
><<<
x_blocks
,
x_block_size
,
0
,
ctx
.
stream
()
>>>
(
x_strides_array_gpu
,
y_strides_array_gpu
,
out_dims_array_gpu
,
x_strides_order_gpu
,
x_dims_order_gpu
,
x_data
,
y_data
,
out_data
,
dout_data
,
dx_data
,
out_size
,
max_dim
,
x_threads
,
dx_op
);
...
...
@@ -1062,7 +1062,7 @@ void CommonGradBroadcastCUDA(
memory
::
Copy
(
gplace
,
y_dims_order_gpu
,
cplace
,
y_dims_order
.
data
(),
bytes
,
ctx
.
stream
());
CommonGradBroadcastCUDAKernel
<
T
,
DY_OP
><<<
y_blocks
,
y_block_size
,
0
,
ctx
.
stream
()
>>>
(
T
,
DY_OP
,
Tout
><<<
y_blocks
,
y_block_size
,
0
,
ctx
.
stream
()
>>>
(
x_strides_array_gpu
,
y_strides_array_gpu
,
out_dims_array_gpu
,
y_strides_order_gpu
,
y_dims_order_gpu
,
x_data
,
y_data
,
out_data
,
dout_data
,
dy_data
,
out_size
,
max_dim
,
y_threads
,
dy_op
);
...
...
@@ -1138,12 +1138,12 @@ class TransformFunctor {
bool
is_xsize_larger_
;
};
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
struct
ElemwiseGradNoBroadcast
{
const
T
*
x_
;
const
T
*
y_
;
const
T
*
out_
;
const
T
*
dout_
;
const
T
out
*
out_
;
const
T
out
*
dout_
;
HOSTDEVICE
void
operator
()(
size_t
i
)
{
if
(
dx_
!=
nullptr
)
{
...
...
@@ -1160,9 +1160,9 @@ struct ElemwiseGradNoBroadcast {
T
*
dy_
;
};
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast1CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
static
void
ElemwiseGradBroadcast1CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
T
out
*
dout
,
int
h
,
int
w
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
if
(
is_xsize_larger
)
{
...
...
@@ -1206,11 +1206,12 @@ static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
#if defined(__NVCC__) || defined(__HIPCC__)
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
static
void
ElemwiseGradBroadcast1CUDA
(
gpuStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
const
T
*
y
,
const
Tout
*
out
,
const
Tout
*
dout
,
int
h
,
int
w
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
// For small case use 1D block
constexpr
int
half_walf
=
16
;
if
(
w
<
half_walf
||
h
<
half_walf
)
{
...
...
@@ -1229,11 +1230,11 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, const T *x,
#endif
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast2CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
static
void
ElemwiseGradBroadcast2CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
T
out
*
dout
,
int
pre
,
int
n
,
int
post
,
bool
is_xsize_larger
,
D
X_OP
dx_op
,
D
Y_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
if
(
is_xsize_larger
)
{
for
(
int
i
=
0
;
i
<
pre
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
...
...
@@ -1278,9 +1279,9 @@ static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
}
#if defined(__NVCC__) || defined(__HIPCC__)
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
static
__global__
void
ElemwiseGradBroadcast2CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
const
T
*
x
,
const
T
*
y
,
const
T
out
*
out
,
const
Tout
*
dout
,
int
pre
,
int
n
,
int
post
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
tid
=
threadIdx
.
x
;
int
j
=
blockIdx
.
x
;
...
...
@@ -1345,12 +1346,12 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
}
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
static
void
ElemwiseGradBroadcast2CUDA
(
gpuStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
d
out
,
int
pre
,
int
n
,
int
post
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
const
T
*
y
,
const
T
out
*
out
,
const
Tout
*
dout
,
int
pre
,
int
n
,
int
post
,
bool
is_xsize_larger
,
D
X_OP
dx_op
,
D
Y_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
pre
*
post
);
int
gird_size
=
n
;
ElemwiseGradBroadcast2CUDAKernel
<<<
gird_size
,
block_size
,
0
,
stream
>>>
(
...
...
@@ -1359,7 +1360,8 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, const T *x,
#endif
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
void
CommonElementwiseBroadcastBackward
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
framework
::
Tensor
&
x
,
...
...
@@ -1387,14 +1389,14 @@ void CommonElementwiseBroadcastBackward(
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#if defined(__NVCC__) || defined(__HIPCC__)
CommonGradBroadcastCUDA
<
T
,
DX_OP
,
DY_OP
>
(
CommonGradBroadcastCUDA
<
T
,
DX_OP
,
DY_OP
,
Tout
>
(
x
,
y
,
out
,
dout
,
dx
,
dy
,
x_dims_array
.
data
(),
y_dims_array
.
data
(),
out_dims_array
.
data
(),
max_dim
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
dx_op
,
dy_op
);
#endif
}
else
{
CommonGradBroadcastCPU
<
T
,
DX_OP
,
DY_OP
>
(
CommonGradBroadcastCPU
<
T
,
DX_OP
,
DY_OP
,
Tout
>
(
x
,
y
,
out
,
dout
,
dx
,
dy
,
x_dims_array
.
data
(),
y_dims_array
.
data
(),
out_dims_array
.
data
(),
max_dim
,
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
dx_op
,
...
...
@@ -1402,7 +1404,8 @@ void CommonElementwiseBroadcastBackward(
}
}
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
void
ElemwiseGradComputeNoBroadcast
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
DDim
&
x_dim
,
const
framework
::
DDim
&
y_dim
,
const
framework
::
Tensor
&
x
,
...
...
@@ -1417,13 +1420,14 @@ void ElemwiseGradComputeNoBroadcast(
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
device_context
<
DeviceContext
>
(),
N
);
#endif // !_WIN32
for_range
(
ElemwiseGradNoBroadcast
<
T
,
DX_OP
,
DY_OP
>
{
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
dx_op
,
dy
_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
for_range
(
ElemwiseGradNoBroadcast
<
T
,
DX_OP
,
DY_OP
,
Tout
>
{
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
out
>
(),
dout
.
data
<
Tout
>
(),
dx
_op
,
d
y_op
,
d
x
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())});
}
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
void
ElemwiseGradComputeWithBroadcast
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
framework
::
Tensor
&
x
,
...
...
@@ -1463,7 +1467,7 @@ void ElemwiseGradComputeWithBroadcast(
}
// special case for common backward implementation.
if
(
is_run_common_broadcast
)
{
CommonElementwiseBroadcastBackward
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
>
(
CommonElementwiseBroadcastBackward
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
,
Tout
>
(
ctx
,
x_dims
,
y_dims
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
dx_op
,
dy_op
);
return
;
}
...
...
@@ -1472,14 +1476,14 @@ void ElemwiseGradComputeWithBroadcast(
#if defined(__NVCC__) || defined(__HIPCC__)
ElemwiseGradBroadcast1CUDA
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
(),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
is_xsize_larger
,
dx_op
,
dy_op
,
y
.
data
<
T
>
(),
out
.
data
<
T
out
>
(),
dout
.
data
<
Tout
>
(),
pre
,
n
,
is_xsize_larger
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
#endif
}
else
{
ElemwiseGradBroadcast1CPU
(
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
out
>
(),
dout
.
data
<
Tout
>
(),
pre
,
n
,
is_xsize_larger
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
...
...
@@ -1489,15 +1493,15 @@ void ElemwiseGradComputeWithBroadcast(
#if defined(__NVCC__) || defined(__HIPCC__)
ElemwiseGradBroadcast2CUDA
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
(),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
post
,
y
.
data
<
T
>
(),
out
.
data
<
T
out
>
(),
dout
.
data
<
Tout
>
(),
pre
,
n
,
post
,
is_xsize_larger
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
#endif
}
else
{
ElemwiseGradBroadcast2CPU
(
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
post
,
is_xsize_larger
,
dx_op
,
dy_op
,
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
out
>
(),
dout
.
data
<
Tout
>
(),
pre
,
n
,
post
,
is_xsize_larger
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
...
...
@@ -1521,7 +1525,8 @@ void CommonElementwiseBroadcastForward(
axis
,
is_xsize_larger
);
}
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
Tout
=
T
>
void
ElemwiseGradCompute
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
,
const
framework
::
Tensor
&
out
,
...
...
@@ -1531,10 +1536,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
const
framework
::
DDim
&
x_dim
=
x
.
dims
();
const
framework
::
DDim
&
y_dim
=
y
.
dims
();
if
(
x
.
dims
()
==
y
.
dims
())
{
ElemwiseGradComputeNoBroadcast
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
>
(
ElemwiseGradComputeNoBroadcast
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
,
Tout
>
(
ctx
,
x_dim
,
y_dim
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
dx_op
,
dy_op
);
}
else
{
ElemwiseGradComputeWithBroadcast
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
>
(
ElemwiseGradComputeWithBroadcast
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
,
Tout
>
(
ctx
,
x_dim
,
y_dim
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
dx_op
,
dy_op
);
}
}
...
...
python/paddle/__init__.py
浏览文件 @
31e874b1
...
...
@@ -88,6 +88,7 @@ from .tensor.creation import meshgrid # noqa: F401
from
.tensor.creation
import
empty
# noqa: F401
from
.tensor.creation
import
empty_like
# noqa: F401
from
.tensor.creation
import
assign
# noqa: F401
from
.tensor.creation
import
complex
# noqa: F401
from
.tensor.linalg
import
matmul
# noqa: F401
from
.tensor.linalg
import
dot
# noqa: F401
from
.tensor.linalg
import
norm
# noqa: F401
...
...
@@ -446,6 +447,7 @@ __all__ = [ # noqa
'shape'
,
'real'
,
'imag'
,
'complex'
,
'reciprocal'
,
'rand'
,
'less_equal'
,
...
...
python/paddle/fluid/tests/unittests/test_complex_op.py
0 → 100644
浏览文件 @
31e874b1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle
from
paddle.fluid
import
dygraph
from
paddle
import
static
paddle
.
enable_static
()
def
ref_complex
(
x
,
y
):
return
x
+
1j
*
y
def
ref_complex_grad
(
x
,
y
,
dout
):
out
=
x
+
1j
*
y
out_rank
=
out
.
ndim
delta_rank_x
=
out_rank
-
x
.
ndim
delta_rank_y
=
out_rank
-
y
.
ndim
dx_reduce_axes
=
[]
dy_reduce_axes
=
[]
for
i
in
range
(
out_rank
):
if
i
<
delta_rank_x
or
dout
.
shape
[
i
]
>
x
.
shape
[
i
-
delta_rank_x
]:
dx_reduce_axes
.
append
(
i
)
if
i
<
delta_rank_y
or
dout
.
shape
[
i
]
>
y
.
shape
[
i
-
delta_rank_y
]:
dy_reduce_axes
.
append
(
i
)
dx
=
np
.
sum
(
dout
.
real
,
axis
=
tuple
(
dx_reduce_axes
)).
reshape
(
x
.
shape
)
dy
=
np
.
sum
(
dout
.
imag
,
axis
=
tuple
(
dy_reduce_axes
)).
reshape
(
y
.
shape
)
return
(
dx
,
dy
)
class
TestComplexOp
(
OpTest
):
def
init_spec
(
self
):
self
.
x_shape
=
[
10
,
10
]
self
.
y_shape
=
[
10
,
10
]
self
.
dtype
=
"float64"
def
setUp
(
self
):
self
.
op_type
=
"complex"
self
.
init_spec
()
x
=
np
.
random
.
randn
(
*
self
.
x_shape
).
astype
(
self
.
dtype
)
y
=
np
.
random
.
randn
(
*
self
.
y_shape
).
astype
(
self
.
dtype
)
out_ref
=
ref_complex
(
x
,
y
)
self
.
out_grad
=
np
.
random
.
randn
(
*
self
.
x_shape
).
astype
(
self
.
dtype
)
\
+
1j
*
np
.
random
.
randn
(
*
self
.
y_shape
).
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
out_ref
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
dout
=
self
.
out_grad
dx
,
dy
=
ref_complex_grad
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
],
self
.
out_grad
)
self
.
check_grad
(
[
'X'
,
'Y'
],
'Out'
,
user_defined_grads
=
[
dx
,
dy
],
user_defined_grad_outputs
=
[
dout
])
def
test_check_grad_ignore_x
(
self
):
dout
=
self
.
out_grad
dx
,
dy
=
ref_complex_grad
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
],
self
.
out_grad
)
self
.
assertTupleEqual
(
dx
.
shape
,
tuple
(
self
.
x_shape
))
self
.
assertTupleEqual
(
dy
.
shape
,
tuple
(
self
.
y_shape
))
self
.
check_grad
(
[
'Y'
],
'Out'
,
no_grad_set
=
set
(
'X'
),
user_defined_grads
=
[
dy
],
user_defined_grad_outputs
=
[
dout
])
def
test_check_grad_ignore_y
(
self
):
dout
=
self
.
out_grad
dx
,
dy
=
ref_complex_grad
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
],
self
.
out_grad
)
self
.
check_grad
(
[
'X'
],
'Out'
,
no_grad_set
=
set
(
'Y'
),
user_defined_grads
=
[
dx
],
user_defined_grad_outputs
=
[
dout
])
class
TestComplexOpBroadcast1
(
TestComplexOp
):
def
init_spec
(
self
):
self
.
x_shape
=
[
10
,
3
,
1
,
4
]
self
.
y_shape
=
[
100
,
1
]
self
.
dtype
=
"float64"
class
TestComplexOpBroadcast2
(
TestComplexOp
):
def
init_spec
(
self
):
self
.
x_shape
=
[
100
,
1
]
self
.
y_shape
=
[
10
,
3
,
1
,
4
]
self
.
dtype
=
"float32"
class
TestComplexOpBroadcast3
(
TestComplexOp
):
def
init_spec
(
self
):
self
.
x_shape
=
[
1
,
100
]
self
.
y_shape
=
[
100
]
self
.
dtype
=
"float32"
class
TestComplexAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
x
=
np
.
random
.
randn
(
10
,
10
)
self
.
y
=
np
.
random
.
randn
(
10
,
10
)
self
.
out
=
ref_complex
(
self
.
x
,
self
.
y
)
def
test_dygraph
(
self
):
with
dygraph
.
guard
():
x
=
paddle
.
to_tensor
(
self
.
x
)
y
=
paddle
.
to_tensor
(
self
.
y
)
out_np
=
paddle
.
complex
(
x
,
y
).
numpy
()
self
.
assertTrue
(
np
.
allclose
(
self
.
out
,
out_np
))
def
test_static
(
self
):
mp
,
sp
=
static
.
Program
(),
static
.
Program
()
with
static
.
program_guard
(
mp
,
sp
):
x
=
static
.
data
(
"x"
,
shape
=
[
10
,
10
],
dtype
=
"float64"
)
y
=
static
.
data
(
"y"
,
shape
=
[
10
,
10
],
dtype
=
"float64"
)
out
=
paddle
.
complex
(
x
,
y
)
exe
=
static
.
Executor
()
exe
.
run
(
sp
)
[
out_np
]
=
exe
.
run
(
mp
,
feed
=
{
"x"
:
self
.
x
,
"y"
:
self
.
y
},
fetch_list
=
[
out
])
self
.
assertTrue
(
np
.
allclose
(
self
.
out
,
out_np
))
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
浏览文件 @
31e874b1
...
...
@@ -68,6 +68,7 @@ NEED_TO_FIX_OP_LIST = [
'rank_loss'
,
'sequence_conv'
,
'smooth_l1_loss'
,
'spectral_norm'
'spectral_norm'
,
'complex'
,
]
# yapf: enable
python/paddle/tensor/__init__.py
浏览文件 @
31e874b1
...
...
@@ -33,6 +33,7 @@ from .creation import tril # noqa: F401
from
.creation
import
meshgrid
# noqa: F401
from
.creation
import
empty
# noqa: F401
from
.creation
import
empty_like
# noqa: F401
from
.creation
import
complex
# noqa: F401
from
.linalg
import
matmul
# noqa: F401
from
.linalg
import
dot
# noqa: F401
from
.linalg
import
norm
# noqa: F401
...
...
python/paddle/tensor/creation.py
浏览文件 @
31e874b1
...
...
@@ -27,6 +27,7 @@ from ..fluid.layers import core
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.data_feeder
import
check_variable_and_dtype
,
check_type
,
check_dtype
,
convert_dtype
from
..fluid.framework
import
convert_np_dtype_to_dtype_
,
in_dygraph_mode
,
_varbase_creator
,
device_guard
,
OpProtoHolder
from
paddle.tensor.attribute
import
_complex_to_real_dtype
,
_real_to_complex_dtype
# TODO: define functions to get create a tensor
from
..fluid.layers
import
linspace
# noqa: F401
import
paddle
...
...
@@ -1250,3 +1251,46 @@ def _memcpy(input, place=None, output=None):
outputs
=
{
'Out'
:
[
output
]},
attrs
=
attrs
)
return
output
def
complex
(
real
,
imag
,
name
=
None
):
"""Return a compelx tensor given the real and image component.
Args:
real (Tensor): The real component. The data type should be 'float32' or 'float64'.
imag (Tensor): The image component. The data type should be the same as ``real``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The output tensor. The data type is 'complex64' or 'complex128', with the same precision as ``real`` and ``imag``.
**Note**:
``paddle.complex`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Examples:
.. code-block:: python
import paddle
x = paddle.arange(2, dtype=paddle.float32).unsqueeze(-1)
y = paddle.arange(3, dtype=paddle.float32)
z = paddle.complex(x, y)
print(z.numpy())
# [[0.+0.j 0.+1.j 0.+2.j]
# [1.+0.j 1.+1.j 1.+2.j]]
"""
if
in_dygraph_mode
():
return
paddle
.
_C_ops
.
complex
(
real
,
imag
)
check_variable_and_dtype
(
real
,
'real'
,
[
'float32'
,
'float64'
],
'complex'
)
check_variable_and_dtype
(
imag
,
'imag'
,
[
'float32'
,
'float64'
],
'complex'
)
op_type
=
"complex"
helper
=
LayerHelper
(
op_type
,
**
locals
())
inputs
=
{
"X"
:
real
,
"Y"
:
imag
}
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
_real_to_complex_dtype
(
real
.
dtype
))
outputs
=
{
"Out"
:
out
}
attrs
=
{}
helper
.
append_op
(
type
=
op_type
,
inputs
=
inputs
,
attrs
=
attrs
,
outputs
=
outputs
)
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录