BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 666e6651 (unverified), authored by chentianyu03 on Jan 05, 2021, committed via GitHub on Jan 05, 2021.

change the kron gradient when complex types (#29995)

Parent: a5e422c8
Showing 2 changed files with 210 additions and 0 deletions:

paddle/fluid/operators/kron_op.h                       +125  -0
python/paddle/fluid/tests/unittests/test_kron_op.py     +85  -0
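The substance of the change: for complex64/complex128 element types, the kron backward pass must multiply the incoming gradient by the conjugate of the other operand rather than the operand itself. A minimal NumPy sketch of that rule (my own illustration; the names and shapes are not from this commit):

import numpy as np

# For out = np.kron(a, b), out[ia * b.size + ib] = a[ia] * b[ib], so the
# vector-Jacobian product under the conjugate convention is
#   grad_a[ia] = sum_ib grad_out[ia, ib] * conj(b[ib])   (and symmetrically).
a = np.array([1 + 2j, 3 - 1j])
b = np.array([2 - 1j, 0 + 1j, 1 + 0j])
grad_out = np.ones(a.size * b.size, dtype=np.complex128)

blocks = grad_out.reshape(a.size, b.size)
grad_a = blocks @ np.conj(b)    # reduce over b's positions
grad_b = blocks.T @ np.conj(a)  # reduce over a's positions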
paddle/fluid/operators/kron_op.h (view file @ 666e6651)

@@ -26,6 +26,9 @@ limitations under the License. */
namespace paddle {
namespace operators {

using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;

// Process an element in the output, used with a parallel-for
template <typename T>
struct KronElemFunctor {
  ...

@@ -172,6 +175,128 @@ struct KronGradElemFunctor {
  const int ndims_;
};

template <>
struct KronGradElemFunctor<complex64> {
  KronGradElemFunctor(const complex64* dout, const complex64* A,
                      const complex64* B, complex64* dout_a, complex64* dout_b,
                      const int64_t* stride_dout, const int64_t* stride_a,
                      const int64_t* stride_b, const int64_t* shape_b,
                      const int64_t numel_a, const int64_t numel_b,
                      const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    int64_t index = idx;
    int64_t index_a = 0;
    int64_t index_b = 0;
    for (int i = 0; i < ndims_; i++) {
      auto pos_i = index / stride_dout_[i];
      index = index % stride_dout_[i];
      auto pos_ai = pos_i / shape_b_[i];
      auto pos_bi = pos_i % shape_b_[i];
      index_a += stride_a_[i] * pos_ai;
      index_b += stride_b_[i] * pos_bi;
    }

    if (dout_a_) {
      size_t index_out_a = index_a * numel_b_ + index_b;
      dout_a_[index_out_a] =
          dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag);
    }
    if (dout_b_) {
      size_t index_out_b = index_b * numel_a_ + index_a;
      dout_b_[index_out_b] =
          dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag);
    }
  }

 private:
  const complex64* dout_;
  const complex64* A_;
  const complex64* B_;
  complex64* dout_a_;
  complex64* dout_b_;
  const int64_t* stride_dout_;
  const int64_t* stride_a_;
  const int64_t* stride_b_;
  const int64_t* shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};
template <>
struct KronGradElemFunctor<complex128> {
  KronGradElemFunctor(const complex128* dout, const complex128* A,
                      const complex128* B, complex128* dout_a,
                      complex128* dout_b, const int64_t* stride_dout,
                      const int64_t* stride_a, const int64_t* stride_b,
                      const int64_t* shape_b, const int64_t numel_a,
                      const int64_t numel_b, const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    int64_t index = idx;
    int64_t index_a = 0;
    int64_t index_b = 0;
    for (int i = 0; i < ndims_; i++) {
      auto pos_i = index / stride_dout_[i];
      index = index % stride_dout_[i];
      auto pos_ai = pos_i / shape_b_[i];
      auto pos_bi = pos_i % shape_b_[i];
      index_a += stride_a_[i] * pos_ai;
      index_b += stride_b_[i] * pos_bi;
    }

    if (dout_a_) {
      size_t index_out_a = index_a * numel_b_ + index_b;
      dout_a_[index_out_a] =
          dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag);
    }
    if (dout_b_) {
      size_t index_out_b = index_b * numel_a_ + index_a;
      dout_b_[index_out_b] =
          dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag);
    }
  }

 private:
  const complex128* dout_;
  const complex128* A_;
  const complex128* B_;
  complex128* dout_a_;
  complex128* dout_b_;
  const int64_t* stride_dout_;
  const int64_t* stride_a_;
  const int64_t* stride_b_;
  const int64_t* shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};
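Both specializations walk the output index exactly as the generic functor does; only the final multiply changes, forming the conjugate by negating the `.imag` field. For intuition, here is an illustrative Python re-implementation (not Paddle code) of the per-dimension index split, which works because the kron output nests a full B-block inside each cell of A's grid:

import numpy as np

def split_index(idx, stride_dout, stride_a, stride_b, shape_b, ndims):
    """Decompose a flat grad_out index into flat indices into A and B."""
    index, index_a, index_b = idx, 0, 0
    for i in range(ndims):
        pos_i, index = divmod(index, stride_dout[i])   # position along dim i
        pos_ai, pos_bi = divmod(pos_i, shape_b[i])     # split into A/B parts
        index_a += stride_a[i] * pos_ai
        index_b += stride_b[i] * pos_bi
    return index_a, index_b

# Quick check on a (2, 2) x (3, 3) example with row-major strides:
# out[4, 5] pairs A[1, 1] (flat index 3) with B[1, 2] (flat index 5).
ia, ib = split_index(np.ravel_multi_index((4, 5), (6, 6)),
                     stride_dout=(6, 1), stride_a=(2, 1),
                     stride_b=(3, 1), shape_b=(3, 3), ndims=2)
assert (ia, ib) == (3, 5)

The destination index `index_a * numel_b_ + index_b` lays the products out in a (numel_a, numel_b) scratch buffer, which suggests the surrounding kernel obtains the final gradients by reducing over one axis of that buffer.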
template <typename T>
struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor() {}
  ...
python/paddle/fluid/tests/unittests/test_kron_op.py (view file @ 666e6651)

@@ -102,5 +102,90 @@ class TestKronLayer(unittest.TestCase):
        np.testing.assert_allclose(c, np.kron(a, b))


class TestComplexKronOp(OpTest):
    def setUp(self):
        self.op_type = "kron"
        self.x_shape = np.array([10, 10])
        self.y_shape = np.array([3, 35])
        self.out_shape = self.x_shape * self.y_shape
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}
    def init_base_dtype(self):
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random(self.x_shape).astype(
            self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
        self.y = np.random.random(self.y_shape).astype(
            self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype)
        self.out = np.kron(self.x, self.y)
    def init_grad_input_output(self):
        self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones(
            self.out_shape, self.dtype)
        self.grad_x = self.get_grad_x_by_numpy()
        self.grad_y = self.get_grad_y_by_numpy()
    def get_grad_x_by_numpy(self):
        grad_x = np.zeros(self.x_shape, np.complex)
        for x_i in range(self.x_shape[0]):
            for x_j in range(self.x_shape[1]):
                for i in range(self.y_shape[0]):
                    for j in range(self.y_shape[1]):
                        idx_i = x_i * self.y_shape[0] + i
                        idx_j = x_j * self.y_shape[1] + j
                        grad_x[x_i][x_j] += self.grad_out[idx_i][
                            idx_j] * np.conj(self.y[i][j])
        return grad_x
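The quadruple loop above transcribes the index mapping directly. For reference, a vectorized equivalent (my own sketch, not part of the commit) reshapes grad_out into (x0, y0, x1, y1) blocks and contracts the y axes against conj(y):

import numpy as np

def kron_grad_x(grad_out, x_shape, y):
    # grad_out[x_i * y0 + i, x_j * y1 + j] pairs x[x_i, x_j] with y[i, j],
    # so reshaping exposes the four axes and einsum sums over (i, j).
    (x0, x1), (y0, y1) = x_shape, y.shape
    blocks = grad_out.reshape(x0, y0, x1, y1)
    return np.einsum('aibj,ij->ab', blocks, np.conj(y))

With the constant (1 + 1J) upstream gradient built in init_grad_input_output, every entry of grad_x collapses to the same value, (1 + 1J) * np.conj(self.y).sum(), which makes the loop's output easy to sanity-check by hand.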
    def get_grad_y_by_numpy(self):
        grad_y = np.zeros(self.y_shape, np.complex)
        for y_i in range(self.y_shape[0]):
            for y_j in range(self.y_shape[1]):
                for x_i in range(self.x_shape[0]):
                    for x_j in range(self.x_shape[1]):
                        idx_i = x_i * self.y_shape[0] + y_i
                        idx_j = x_j * self.y_shape[1] + y_j
                        grad_y[y_i][y_j] += self.grad_out[idx_i][
                            idx_j] * np.conj(self.x[x_i][x_j])
        return grad_y
    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set=set("X"),
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set=set('Y'),
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()