Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
292f3f77
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
292f3f77
编写于
1月 16, 2023
作者:
X
xiaoguoguo626807
提交者:
GitHub
1月 16, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【prim】vjp for reduce sum (#49736)
上级
e70af91d
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
363 addition
and
15 deletion
+363
-15
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+31
-0
paddle/fluid/prim/api/manual/backward/composite_backward_api.h
...e/fluid/prim/api/manual/backward/composite_backward_api.h
+47
-3
paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc
paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc
+10
-0
paddle/fluid/prim/api/manual/prim_api/prim_api.h
paddle/fluid/prim/api/manual/prim_api/prim_api.h
+9
-2
paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc
paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc
+35
-0
paddle/phi/api/yaml/legacy_backward.yaml
paddle/phi/api/yaml/legacy_backward.yaml
+1
-0
python/paddle/fluid/tests/unittests/prim/CMakeLists.txt
python/paddle/fluid/tests/unittests/prim/CMakeLists.txt
+0
-1
python/paddle/fluid/tests/unittests/prim/comp/CMakeLists.txt
python/paddle/fluid/tests/unittests/prim/comp/CMakeLists.txt
+0
-9
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
.../fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
+104
-0
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
...fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
...ests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
+124
-0
未找到文件。
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
浏览文件 @
292f3f77
...
...
@@ -17,6 +17,7 @@
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
...
...
@@ -63,6 +64,35 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
class
ReduceSumCompositeGradOpMaker
:
public
prim
::
GradCompositeOpMakerBase
{
public:
using
prim
::
GradCompositeOpMakerBase
::
GradCompositeOpMakerBase
;
void
Apply
()
override
{
// get inputs
paddle
::
experimental
::
Tensor
x
=
this
->
GetSingleForwardInput
(
"X"
);
paddle
::
experimental
::
Tensor
out_grad
=
this
->
GetSingleOutputGrad
(
"Out"
);
// get attr
std
::
vector
<
int
>
axis
=
this
->
Attr
<
std
::
vector
<
int
>>
(
"dim"
);
bool
keep_dim
=
this
->
Attr
<
bool
>
(
"keep_dim"
);
bool
reduce_all
=
this
->
Attr
<
bool
>
(
"reduce_all"
);
// get output
paddle
::
experimental
::
Tensor
x_grad_t
=
this
->
GetSingleInputGrad
(
"X"
);
// get output ptr
paddle
::
experimental
::
Tensor
*
x_grad
=
this
->
GetOutputPtr
(
&
x_grad_t
);
// get output orginal name
std
::
string
x_grad_name
=
this
->
GetOutputName
(
x_grad_t
);
// call composite backward func
prim
::
sum_grad
<
prim
::
DescTensor
>
(
x
,
out_grad
,
axis
,
keep_dim
,
reduce_all
,
x_grad
);
// recover output name
this
->
RecoverOutputName
(
x_grad_t
,
x_grad_name
);
}
};
template
<
typename
T
>
class
ReduceSumDoubleOpGradMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
...
...
@@ -114,6 +144,7 @@ REGISTER_OPERATOR(reduce_sum,
ops
::
ReduceSumVarTypeInference
,
ops
::
ReduceSumOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
ReduceSumOpGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
ReduceSumCompositeGradOpMaker
,
ReduceSumInferShapeFunctor
);
REGISTER_OPERATOR
(
reduce_sum_grad
,
ops
::
ReduceGradOp
,
...
...
paddle/fluid/prim/api/manual/backward/composite_backward_api.h
浏览文件 @
292f3f77
...
...
@@ -15,9 +15,15 @@
#pragma once
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"
namespace
paddle
{
namespace
prim
{
using
Tensor
=
paddle
::
experimental
::
Tensor
;
using
IntArray
=
paddle
::
experimental
::
IntArrayBase
<
paddle
::
experimental
::
Tensor
>
;
// using IntArray = paddle::experimental::IntArray;
// This function should have as same signature as phi, which defined in
// paddle/phi/api/backward/backward_api.h
template
<
typename
T
>
...
...
@@ -94,6 +100,44 @@ void add_grad(const Tensor& x,
}
}
template
<
typename
T
>
void
sum_grad
(
const
Tensor
&
x
,
const
Tensor
&
out_grad
,
const
IntArray
&
axis
,
bool
keepdim
,
bool
reduce_all
,
Tensor
*
x_grad
)
{
if
(
!
x_grad
)
{
return
;
}
std
::
vector
<
int
>
x_dim
=
phi
::
vectorize
<
int
>
(
x
.
dims
());
int64_t
axis_size
=
axis
.
size
();
int64_t
x_dim_size
=
x_dim
.
size
();
reduce_all
=
false
;
if
(
reduce_all
||
axis_size
==
0
||
axis_size
==
x_dim_size
)
{
reduce_all
=
true
;
}
else
{
reduce_all
=
false
;
}
auto
x_grad_tmp
=
Tensor
();
if
(
!
keepdim
)
{
auto
axis_
=
std
::
vector
<
int64_t
>
();
if
(
reduce_all
)
{
for
(
int64_t
i
=
1
;
i
<
x_dim_size
;
i
++
)
{
axis_
.
push_back
(
i
);
}
}
else
{
axis_
=
axis
.
GetData
();
}
auto
out_grad_
=
unsqueeze
<
T
>
(
out_grad
,
axis_
);
x_grad_tmp
=
expand
<
T
>
(
out_grad_
,
x_dim
);
}
else
{
x_grad_tmp
=
expand
<
T
>
(
out_grad
,
x_dim
);
}
x_grad
->
set_impl
(
x_grad_tmp
.
impl
());
}
template
<
typename
T
>
void
divide_grad
(
const
Tensor
&
x
,
const
Tensor
&
y
,
...
...
paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc
浏览文件 @
292f3f77
...
...
@@ -36,6 +36,16 @@ Tensor multiply<Tensor>(const Tensor& x, const Tensor& y) {
return
::
multiply_ad_func
(
x
,
y
);
}
template
<
>
Tensor
expand
<
Tensor
>
(
const
Tensor
&
x
,
const
IntArray
&
shape
)
{
return
::
expand_ad_func
(
x
,
shape
);
}
template
<
>
Tensor
unsqueeze
<
Tensor
>
(
const
Tensor
&
x
,
const
IntArray
&
axis
)
{
return
::
unsqueeze_ad_func
(
x
,
axis
);
}
template
<
>
Tensor
divide
<
Tensor
>
(
const
Tensor
&
x
,
const
Tensor
&
y
)
{
return
::
divide_ad_func
(
x
,
y
);
...
...
paddle/fluid/prim/api/manual/prim_api/prim_api.h
浏览文件 @
292f3f77
...
...
@@ -21,18 +21,25 @@ namespace prim {
using
Tensor
=
paddle
::
experimental
::
Tensor
;
using
IntArray
=
paddle
::
experimental
::
IntArray
;
using
Scalar
=
paddle
::
experimental
::
Scalar
;
template
<
typename
T
>
Tensor
pow
(
const
Tensor
&
x
,
const
paddle
::
experimental
::
Scalar
&
y
);
Tensor
pow
(
const
Tensor
&
x
,
const
Scalar
&
y
);
template
<
typename
T
>
Tensor
scale
(
const
Tensor
&
X
,
const
paddle
::
experimental
::
Scalar
&
scale
,
const
Scalar
&
scale
,
float
bias
,
bool
bias_after_scale
);
template
<
typename
T
>
Tensor
multiply
(
const
Tensor
&
x
,
const
Tensor
&
y
);
template
<
typename
T
>
Tensor
expand
(
const
Tensor
&
x
,
const
IntArray
&
shape
);
template
<
typename
T
>
Tensor
unsqueeze
(
const
Tensor
&
x
,
const
IntArray
&
axis
);
template
<
typename
T
>
Tensor
divide
(
const
Tensor
&
x
,
const
Tensor
&
y
);
...
...
paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc
浏览文件 @
292f3f77
...
...
@@ -94,6 +94,23 @@ Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
return
out
;
}
template
<
>
Tensor
expand
<
DescTensor
>
(
const
Tensor
&
x
,
const
IntArray
&
shape
)
{
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"expand_v2"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
std
::
vector
<
int
>
new_shape
(
shape
.
GetData
().
begin
(),
shape
.
GetData
().
end
());
op
->
SetAttr
(
"shape"
,
new_shape
);
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
return
out
;
}
template
<
>
Tensor
divide
<
DescTensor
>
(
const
Tensor
&
x
,
const
Tensor
&
y
)
{
// Grad infershape
...
...
@@ -113,6 +130,23 @@ Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) {
return
out
;
}
template
<
>
Tensor
unsqueeze
<
DescTensor
>
(
const
Tensor
&
x
,
const
IntArray
&
axis
)
{
Tensor
out
=
empty
<
DescTensor
>
({},
phi
::
DataType
::
FLOAT32
,
paddle
::
Place
());
framework
::
BlockDesc
*
block
=
StaticCompositeContext
::
Instance
().
GetBlock
();
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"unsqueeze2"
);
op
->
SetInput
(
"X"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
x
.
impl
())
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
std
::
vector
<
int
>
new_shape
(
axis
.
GetData
().
begin
(),
axis
.
GetData
().
end
());
op
->
SetAttr
(
"axes"
,
new_shape
);
op
->
CheckAttrs
();
op
->
InferVarType
(
block
);
return
out
;
}
template
<
>
Tensor
full
<
DescTensor
>
(
paddle
::
experimental
::
IntArray
shape
,
paddle
::
experimental
::
Scalar
value
,
...
...
@@ -141,6 +175,7 @@ Tensor full<DescTensor>(paddle::experimental::IntArray shape,
op
->
InferShape
(
*
block
);
return
out
;
}
template
<
>
Tensor
sum
<
DescTensor
>
(
Tensor
x
,
paddle
::
experimental
::
IntArray
axis
,
...
...
paddle/phi/api/yaml/legacy_backward.yaml
浏览文件 @
292f3f77
...
...
@@ -1356,6 +1356,7 @@
param
:
[
x
]
kernel
:
func
:
sum_grad
composite
:
sum_grad(x, out_grad, axis, keepdim, reduce_all, x_grad)
no_need_buffer
:
x
backward
:
sum_double_grad
...
...
python/paddle/fluid/tests/unittests/prim/CMakeLists.txt
浏览文件 @
292f3f77
...
...
@@ -8,5 +8,4 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules
(
${
TEST_OP
}
MODULES
${
TEST_OP
}
ENVS
${
GC_ENVS
}
)
endforeach
()
add_subdirectory
(
comp
)
add_subdirectory
(
prim
)
python/paddle/fluid/tests/unittests/prim/comp/CMakeLists.txt
已删除
100644 → 0
浏览文件 @
e70af91d
file
(
GLOB TEST_OPS
RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"test_*.py"
)
string
(
REPLACE
".py"
""
TEST_OPS
"
${
TEST_OPS
}
"
)
foreach
(
TEST_OP
${
TEST_OPS
}
)
py_test_modules
(
${
TEST_OP
}
MODULES
${
TEST_OP
}
ENVS
${
GC_ENVS
}
)
endforeach
()
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt
浏览文件 @
292f3f77
...
...
@@ -11,5 +11,6 @@ endforeach()
set_tests_properties
(
test_comp_eager_tanh_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_eager_div_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_eager_sum_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_eager_add_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_eager_sub_grad PROPERTIES TIMEOUT 60
)
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
0 → 100644
浏览文件 @
292f3f77
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
from
paddle.fluid
import
core
def
actual
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
set_prim_enabled
(
False
)
x
=
paddle
.
to_tensor
(
primal
,
dtype
=
'float32'
,
stop_gradient
=
False
)
v
=
paddle
.
to_tensor
(
cotangent
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
sum
(
x
,
axis
=
axis
,
keepdim
=
keep_dim
)
x_cotangent
=
paddle
.
grad
(
y
,
x
,
v
,
create_graph
=
True
,
retain_graph
=
True
)
return
x_cotangent
[
0
]
def
desired
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
set_prim_enabled
(
True
)
x
=
paddle
.
to_tensor
(
primal
,
dtype
=
'float32'
,
stop_gradient
=
False
)
v
=
paddle
.
to_tensor
(
cotangent
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
sum
(
x
,
axis
=
axis
,
keepdim
=
keep_dim
)
x_cotangent
=
paddle
.
grad
(
y
,
x
,
v
,
create_graph
=
True
,
retain_graph
=
True
)
return
x_cotangent
[
0
]
class
TestSumGradComp
(
unittest
.
TestCase
):
def
test_sum_grad_comp_1
(
self
):
self
.
primal
=
np
.
random
.
rand
(
10
,
10
)
self
.
cotangent
=
np
.
random
.
rand
(
1
)
paddle
.
disable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
[],
False
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
[],
False
),
rtol
=
1e-6
,
atol
=
0
,
)
def
test_sum_grad_comp_2
(
self
):
self
.
primal
=
np
.
random
.
rand
(
4
,
3
,
2
)
self
.
cotangent
=
np
.
random
.
rand
(
4
,
2
)
paddle
.
disable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
1
,
False
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
1
,
False
),
rtol
=
1e-6
,
atol
=
0
,
)
def
test_sum_grad_comp_3
(
self
):
self
.
primal
=
np
.
random
.
rand
(
4
,
3
,
2
)
self
.
cotangent
=
np
.
random
.
rand
(
4
,
1
,
2
)
paddle
.
disable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
1
,
True
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
1
,
True
),
rtol
=
1e-6
,
atol
=
0
,
)
def
test_sum_grad_comp_4
(
self
):
self
.
primal
=
np
.
random
.
rand
(
4
,
3
,
2
,
5
)
self
.
cotangent
=
np
.
random
.
rand
(
4
,
1
,
2
,
1
)
paddle
.
disable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
[
1
,
3
],
True
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
[
1
,
3
],
True
),
rtol
=
1e-6
,
atol
=
0
,
)
def
test_sum_grad_comp_5
(
self
):
self
.
primal
=
np
.
random
.
rand
(
4
,
3
,
2
,
5
)
self
.
cotangent
=
np
.
random
.
rand
(
4
,
2
)
paddle
.
disable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
[
1
,
3
],
False
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
[
1
,
3
],
False
),
rtol
=
1e-6
,
atol
=
0
,
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt
浏览文件 @
292f3f77
...
...
@@ -11,6 +11,7 @@ endforeach()
set_tests_properties
(
test_comp_tanh_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_div_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_sum_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_add_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_sub_grad PROPERTIES TIMEOUT 60
)
set_tests_properties
(
test_comp_add_tanh_grad PROPERTIES TIMEOUT 60
)
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
0 → 100644
浏览文件 @
292f3f77
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
from
paddle.fluid
import
core
def
actual
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
set_prim_enabled
(
False
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal'
,
primal
.
shape
,
primal
.
dtype
)
x
.
stop_gradient
=
False
v
=
paddle
.
static
.
data
(
'cotangent'
,
cotangent
.
shape
,
cotangent
.
dtype
)
y
=
paddle
.
sum
(
x
,
axis
=
axis
,
keepdim
=
keep_dim
)
x_cotangent
=
paddle
.
static
.
gradients
(
y
,
x
,
None
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
result
=
exe
.
run
(
program
=
mp
,
feed
=
{
'primal'
:
primal
,
'cotangent'
:
cotangent
},
fetch_list
=
[
x_cotangent
],
)[
0
]
return
result
def
desired
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
set_prim_enabled
(
True
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal'
,
primal
.
shape
,
primal
.
dtype
)
x
.
stop_gradient
=
False
v
=
paddle
.
static
.
data
(
'cotangent'
,
cotangent
.
shape
,
cotangent
.
dtype
)
y
=
paddle
.
sum
(
x
,
axis
=
axis
,
keepdim
=
keep_dim
)
x_cotangent
=
paddle
.
static
.
gradients
(
y
,
x
,
None
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
result
=
exe
.
run
(
program
=
mp
,
feed
=
{
'primal'
:
primal
,
'cotangent'
:
cotangent
},
fetch_list
=
[
x_cotangent
],
)[
0
]
return
result
class
TestSumGradComp
(
unittest
.
TestCase
):
def
test_sum_grad_comp_1
(
self
):
self
.
primal
=
np
.
random
.
rand
(
10
,
10
)
self
.
cotangent
=
np
.
random
.
rand
(
1
,
1
)
paddle
.
enable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
[],
True
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
[],
True
),
rtol
=
1e-6
,
atol
=
0
,
)
def
test_sum_grad_comp_2
(
self
):
self
.
primal
=
np
.
random
.
rand
(
4
,
3
,
2
)
self
.
cotangent
=
np
.
random
.
rand
(
4
,
2
)
paddle
.
enable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
1
,
False
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
1
,
False
),
rtol
=
1e-6
,
atol
=
0
,
)
def
test_sum_grad_comp_3
(
self
):
self
.
primal
=
np
.
random
.
rand
(
4
,
3
,
2
)
self
.
cotangent
=
np
.
random
.
rand
(
4
,
1
,
2
)
paddle
.
enable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
1
,
True
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
1
,
True
),
rtol
=
1e-6
,
atol
=
0
,
)
def
test_sum_grad_comp_4
(
self
):
self
.
primal
=
np
.
random
.
rand
(
4
,
3
,
2
,
5
)
self
.
cotangent
=
np
.
random
.
rand
(
4
,
1
,
2
,
1
)
paddle
.
enable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
[
1
,
3
],
True
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
[
1
,
3
],
True
),
rtol
=
1e-6
,
atol
=
0
,
)
def
test_sum_grad_comp_5
(
self
):
self
.
primal
=
np
.
random
.
rand
(
4
,
3
,
2
,
5
)
self
.
cotangent
=
np
.
random
.
rand
(
4
,
2
)
paddle
.
enable_static
()
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
,
[
1
,
3
],
False
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
,
[
1
,
3
],
False
),
rtol
=
1e-6
,
atol
=
0
,
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录