Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
bbca66f2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bbca66f2
编写于
3月 02, 2023
作者:
J
Jiabin Yang
提交者:
GitHub
3月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Prim】Fix slice error and eager comp (#51086)
* fix attrs copy error * fix bert by fix slice error * fix op test
上级
41e5667b
变更
28
隐藏空白更改
内联
并排
Showing
28 changed file
with
138 addition
and
106 deletion
+138
-106
paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
...le/fluid/eager/auto_code_generator/generator/eager_gen.py
+2
-2
paddle/fluid/operators/slice_op.cc
paddle/fluid/operators/slice_op.cc
+15
-8
paddle/fluid/prim/api/composite_backward/composite_backward_api.h
...luid/prim/api/composite_backward/composite_backward_api.h
+11
-5
paddle/fluid/prim/utils/static/static_global_utils.cc
paddle/fluid/prim/utils/static/static_global_utils.cc
+1
-0
paddle/fluid/prim/utils/static/static_global_utils.h
paddle/fluid/prim/utils/static/static_global_utils.h
+7
-0
paddle/fluid/prim/utils/utils.cc
paddle/fluid/prim/utils/utils.cc
+8
-0
paddle/fluid/prim/utils/utils.h
paddle/fluid/prim/utils/utils.h
+2
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-0
python/paddle/fluid/core.py
python/paddle/fluid/core.py
+21
-8
python/paddle/fluid/tests/unittests/prim/model/bert.py
python/paddle/fluid/tests/unittests/prim/model/bert.py
+5
-3
python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py
...e/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py
+1
-4
python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py
.../fluid/tests/unittests/prim/prim/flags/test_prim_flags.py
+3
-0
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_cast_grad.py
...nittests/prim/prim/vjp/eager/test_comp_eager_cast_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py
...ttests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py
+3
-3
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py
...ttests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py
+3
-3
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_matmul_double_grad.py
...prim/prim/vjp/eager/test_comp_eager_matmul_double_grad.py
+6
-6
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py
...ests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_reshape_grad.py
...tests/prim/prim/vjp/eager/test_comp_eager_reshape_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py
...nittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py
...nittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_transpose_grad.py
...sts/prim/prim/vjp/eager/test_comp_eager_transpose_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim_op_test.py
python/paddle/fluid/tests/unittests/prim_op_test.py
+2
-1
python/paddle/fluid/tests/unittests/test_slice_op.py
python/paddle/fluid/tests/unittests/test_slice_op.py
+12
-31
未找到文件。
paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
浏览文件 @
bbca66f2
...
...
@@ -1840,7 +1840,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
if
is_composite_grad_api
and
next_grad_node_creation_str
!=
''
:
next_grad_node_creation_str
=
f
"""
if (!paddle::prim::PrimCommonUtils::Is
Bwd
PrimEnabled()) {{
if (!paddle::prim::PrimCommonUtils::Is
Eager
PrimEnabled()) {{
{
next_grad_node_creation_str
}
}}
"""
...
...
@@ -2260,7 +2260,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif
is_composite_grad_api
:
grad_function_call_str
=
f
"""
if (paddle::prim::PrimCommonUtils::Is
Bwd
PrimEnabled()) {{
if (paddle::prim::PrimCommonUtils::Is
Eager
PrimEnabled()) {{
{
indent
}{
composite_grad_api_namespace
}{
composite_grad_api_name
}{
composite_template_name
}
(
{
composite_grad_api_args_str
}
);
VLOG(4) << "Composite api
{
composite_grad_api_name
}
is called ";
}}else{{
...
...
paddle/fluid/operators/slice_op.cc
浏览文件 @
bbca66f2
...
...
@@ -423,19 +423,25 @@ class SliceCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
auto
dx_ptr
=
this
->
GetOutputPtr
(
&
input_grad
);
std
::
string
dx_name
=
this
->
GetOutputName
(
input_grad
);
auto
axes
=
this
->
Attr
<
std
::
vector
<
int
64_t
>>
(
"axes"
);
auto
starts
=
this
->
Attr
<
std
::
vector
<
int
64_t
>>
(
"starts"
);
auto
ends
=
this
->
Attr
<
std
::
vector
<
int
64_t
>>
(
"ends"
);
auto
infer_flags
=
this
->
Attr
<
std
::
vector
<
int
64_t
>>
(
"infer_flags"
);
auto
decrease_axis
=
this
->
Attr
<
std
::
vector
<
int
64_t
>>
(
"decrease_axis"
);
auto
axes
=
this
->
Attr
<
std
::
vector
<
int
>>
(
"axes"
);
auto
starts
=
this
->
Attr
<
std
::
vector
<
int
>>
(
"starts"
);
auto
ends
=
this
->
Attr
<
std
::
vector
<
int
>>
(
"ends"
);
auto
infer_flags
=
this
->
Attr
<
std
::
vector
<
int
>>
(
"infer_flags"
);
auto
decrease_axis
=
this
->
Attr
<
std
::
vector
<
int
>>
(
"decrease_axis"
);
VLOG
(
6
)
<<
"Runing slice_grad composite func"
;
std
::
vector
<
int64_t
>
new_axes
=
std
::
vector
<
int64_t
>
(
axes
.
begin
(),
axes
.
end
());
std
::
vector
<
int64_t
>
new_infer_flags
=
std
::
vector
<
int64_t
>
(
infer_flags
.
begin
(),
infer_flags
.
end
());
std
::
vector
<
int64_t
>
new_decrease_axis
=
std
::
vector
<
int64_t
>
(
decrease_axis
.
begin
(),
decrease_axis
.
end
());
prim
::
slice_grad
<
prim
::
DescTensor
>
(
input
,
out_grad
,
axes
,
new_
axes
,
paddle
::
experimental
::
IntArray
(
starts
),
paddle
::
experimental
::
IntArray
(
ends
),
infer_flags
,
decrease_axis
,
new_
infer_flags
,
new_
decrease_axis
,
dx_ptr
);
this
->
RecoverOutputName
(
input_grad
,
dx_name
);
}
...
...
@@ -478,6 +484,7 @@ REGISTER_OPERATOR(slice,
ops
::
SliceOpMaker
,
ops
::
SliceOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
SliceOpGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
SliceCompositeGradOpMaker
,
ops
::
SliceOpVarTypeInference
);
REGISTER_OPERATOR
(
slice_grad
,
ops
::
SliceOpGrad
,
...
...
paddle/fluid/prim/api/composite_backward/composite_backward_api.h
浏览文件 @
bbca66f2
...
...
@@ -704,6 +704,7 @@ void slice_grad(const Tensor& input,
if
(
input_grad
)
{
size_t
rank
=
input
.
dims
().
size
();
auto
out_dims
=
out_grad
.
dims
();
std
::
vector
<
int64_t
>
origin_out_shape
;
auto
in_dims
=
input
.
dims
();
auto
decrease_size
=
decrease_axis
.
size
();
...
...
@@ -712,7 +713,7 @@ void slice_grad(const Tensor& input,
// all dims decrease
out_dims
=
phi
::
make_ddim
(
std
::
vector
<
int
>
(
decrease_size
,
1
));
}
else
{
std
::
vector
<
int
>
origin_out_shap
e
(
out_dims
.
size
()
+
decrease_size
,
-
1
);
origin_out_shape
.
resiz
e
(
out_dims
.
size
()
+
decrease_size
,
-
1
);
for
(
size_t
i
=
0
;
i
<
decrease_size
;
++
i
)
{
origin_out_shape
[
decrease_axis
[
i
]]
=
1
;
}
...
...
@@ -734,7 +735,6 @@ void slice_grad(const Tensor& input,
offsets
[
i
]
=
0
;
extents
[
i
]
=
out_dims
[
i
];
}
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
int
axis
=
axes
[
i
];
int64_t
start
=
starts
[
i
]
<
0
?
(
starts
[
i
]
+
in_dims
[
axis
])
:
starts
[
i
];
...
...
@@ -747,9 +747,15 @@ void slice_grad(const Tensor& input,
paddings
.
push_back
(
offsets
[
i
]);
paddings
.
push_back
((
in_dims
[
i
]
-
out_dims
[
i
])
-
offsets
[
i
]);
}
auto
out_tmp
=
pad
<
T
>
(
out_grad
,
paddings
,
0.0
);
set_output
<
T
>
(
out_tmp
,
input_grad
);
if
(
decrease_size
>
0
&&
(
decrease_size
!=
static_cast
<
size_t
>
(
in_dims
.
size
())))
{
auto
out_tmp
=
pad
<
T
>
(
reshape
<
T
>
(
out_grad
,
origin_out_shape
),
paddings
,
0.0
);
set_output
<
T
>
(
out_tmp
,
input_grad
);
}
else
{
auto
out_tmp
=
pad
<
T
>
(
out_grad
,
paddings
,
0.0
);
set_output
<
T
>
(
out_tmp
,
input_grad
);
}
}
}
...
...
paddle/fluid/prim/utils/static/static_global_utils.cc
浏览文件 @
bbca66f2
...
...
@@ -20,5 +20,6 @@ StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
new
StaticCompositeContext
();
thread_local
bool
StaticCompositeContext
::
enable_bwd_prim_
=
false
;
thread_local
bool
StaticCompositeContext
::
enable_fwd_prim_
=
false
;
thread_local
bool
StaticCompositeContext
::
enable_eager_prim_
=
false
;
}
// namespace prim
}
// namespace paddle
paddle/fluid/prim/utils/static/static_global_utils.h
浏览文件 @
bbca66f2
...
...
@@ -65,6 +65,12 @@ class StaticCompositeContext {
bool
IsFwdPrimEnabled
()
{
return
enable_fwd_prim_
;
}
void
SetEagerPrimEnabled
(
bool
enable_prim
)
{
enable_eager_prim_
=
enable_prim
;
}
bool
IsEagerPrimEnabled
()
{
return
enable_eager_prim_
;
}
void
SetAllPrimEnabled
(
bool
enable_prim
)
{
enable_fwd_prim_
=
enable_prim
;
enable_bwd_prim_
=
enable_prim
;
...
...
@@ -102,6 +108,7 @@ class StaticCompositeContext {
std
::
map
<
std
::
string
,
std
::
string
>
target_grad_name_
;
static
thread_local
bool
enable_bwd_prim_
;
static
thread_local
bool
enable_fwd_prim_
;
static
thread_local
bool
enable_eager_prim_
;
static
StaticCompositeContext
*
static_composite_context_
;
DISABLE_COPY_AND_ASSIGN
(
StaticCompositeContext
);
};
...
...
paddle/fluid/prim/utils/utils.cc
浏览文件 @
bbca66f2
...
...
@@ -27,6 +27,14 @@ void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
StaticCompositeContext
::
Instance
().
SetBwdPrimEnabled
(
enable_prim
);
}
bool
PrimCommonUtils
::
IsEagerPrimEnabled
()
{
return
StaticCompositeContext
::
Instance
().
IsEagerPrimEnabled
();
}
void
PrimCommonUtils
::
SetEagerPrimEnabled
(
bool
enable_prim
)
{
StaticCompositeContext
::
Instance
().
SetEagerPrimEnabled
(
enable_prim
);
}
bool
PrimCommonUtils
::
IsFwdPrimEnabled
()
{
return
StaticCompositeContext
::
Instance
().
IsFwdPrimEnabled
();
}
...
...
paddle/fluid/prim/utils/utils.h
浏览文件 @
bbca66f2
...
...
@@ -23,6 +23,8 @@ class PrimCommonUtils {
public:
static
bool
IsBwdPrimEnabled
();
static
void
SetBwdPrimEnabled
(
bool
enabled
);
static
bool
IsEagerPrimEnabled
();
static
void
SetEagerPrimEnabled
(
bool
enabled
);
static
bool
IsFwdPrimEnabled
();
static
void
SetFwdPrimEnabled
(
bool
enabled
);
static
void
SetAllPrimEnabled
(
bool
enabled
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
bbca66f2
...
...
@@ -681,6 +681,10 @@ PYBIND11_MODULE(libpaddle, m) {
&
paddle
::
prim
::
PrimCommonUtils
::
IsFwdPrimEnabled
);
m
.
def
(
"__set_all_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
SetAllPrimEnabled
);
m
.
def
(
"_is_eager_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
IsEagerPrimEnabled
);
m
.
def
(
"__set_eager_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
SetEagerPrimEnabled
);
m
.
def
(
"_set_prim_target_grad_name"
,
&
paddle
::
prim
::
PrimCommonUtils
::
SetTargetGradName
);
m
.
def
(
"set_num_threads"
,
&
platform
::
SetNumThreads
);
...
...
python/paddle/fluid/core.py
浏览文件 @
bbca66f2
...
...
@@ -316,6 +316,8 @@ try:
from
.libpaddle
import
__set_fwd_prim_enabled
from
.libpaddle
import
_is_fwd_prim_enabled
from
.libpaddle
import
__set_all_prim_enabled
from
.libpaddle
import
_is_eager_prim_enabled
from
.libpaddle
import
__set_eager_prim_enabled
from
.libpaddle
import
_set_prim_target_grad_name
# custom devivce
...
...
@@ -475,26 +477,36 @@ def _set_prim_forward_blacklist(ops=None):
def
_set_prim_backward_enabled
(
value
):
__set_bwd_prim_enabled
(
bool
(
value
))
print
(
"backward prim enabled: "
,
bool
(
_is_bwd_prim_enabled
()))
if
os
.
getenv
(
"FLAGS_prim_log"
)
is
"1"
:
print
(
"backward prim enabled: "
,
bool
(
_is_bwd_prim_enabled
()))
def
_set_prim_forward_enabled
(
value
):
__set_fwd_prim_enabled
(
bool
(
value
))
print
(
"forward prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()))
if
os
.
getenv
(
"FLAGS_prim_log"
)
is
"1"
:
print
(
"forward prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()))
def
set_prim_eager_enabled
(
value
):
__set_eager_prim_enabled
(
bool
(
value
))
if
os
.
getenv
(
"FLAGS_prim_log"
)
is
"1"
:
print
(
"eager prim enabled: "
,
bool
(
_is_eager_prim_enabled
()))
def
_set_prim_all_enabled
(
value
):
__set_all_prim_enabled
(
bool
(
value
))
print
(
"all prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()
and
_is_bwd_prim_enabled
()),
)
if
os
.
getenv
(
"FLAGS_prim_log"
)
is
"1"
:
print
(
"all prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()
and
_is_bwd_prim_enabled
()),
)
def
__sync_prim_backward_status
():
flag_value
=
os
.
getenv
(
"FLAGS_prim_backward"
)
if
flag_value
is
None
:
print
(
"backward prim enabled: "
,
bool
(
_is_bwd_prim_enabled
()))
if
os
.
getenv
(
"FLAGS_prim_log"
)
is
"1"
:
print
(
"backward prim enabled: "
,
bool
(
_is_bwd_prim_enabled
()))
else
:
__sync_stat_with_flag
(
"FLAGS_prim_backward"
)
...
...
@@ -502,7 +514,8 @@ def __sync_prim_backward_status():
def
__sync_prim_forward_status
():
flag_value
=
os
.
getenv
(
"FLAGS_prim_forward"
)
if
flag_value
is
None
:
print
(
"forward prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()))
if
os
.
getenv
(
"FLAGS_prim_log"
)
is
1
:
print
(
"forward prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()))
else
:
__sync_stat_with_flag
(
"FLAGS_prim_forward"
)
...
...
python/paddle/fluid/tests/unittests/prim/model/bert.py
浏览文件 @
bbca66f2
...
...
@@ -207,7 +207,7 @@ class BertPooler(nn.Layer):
class
BertModel
(
nn
.
Layer
):
def
__init__
(
self
,
config
:
BertConfig
):
def
__init__
(
self
,
config
:
BertConfig
,
to_static
):
super
(
BertModel
,
self
).
__init__
()
self
.
config
=
config
self
.
pad_token_id
=
config
.
pad_token_id
...
...
@@ -247,6 +247,8 @@ class BertModel(nn.Layer):
self
.
encoder
=
nn
.
TransformerEncoder
(
encoder_layer
,
config
.
num_hidden_layers
)
if
to_static
:
self
.
encoder
=
paddle
.
jit
.
to_static
(
self
.
encoder
)
self
.
pooler
=
BertPooler
(
config
)
# self.apply(self.init_weights)
...
...
@@ -364,10 +366,10 @@ class BertModel(nn.Layer):
class
Bert
(
nn
.
Layer
):
def
__init__
(
self
):
def
__init__
(
self
,
to_static
):
super
(
Bert
,
self
).
__init__
()
config
=
BertConfig
()
self
.
bert
=
BertModel
(
config
)
self
.
bert
=
BertModel
(
config
,
to_static
)
self
.
cls
=
BertPretrainingHeads
(
config
,
embedding_weights
=
self
.
bert
.
embeddings
.
word_embeddings
.
weight
,
...
...
python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py
浏览文件 @
bbca66f2
...
...
@@ -58,7 +58,7 @@ def train(to_static, enable_prim, enable_cinn):
worker_init
=
None
,
)
bert
=
Bert
()
bert
=
Bert
(
to_static
)
criterion
=
BertPretrainingCriterion
()
if
to_static
:
# input_sepc = [
...
...
@@ -72,9 +72,6 @@ def train(to_static, enable_prim, enable_cinn):
build_strategy
=
paddle
.
static
.
BuildStrategy
()
if
enable_cinn
:
build_strategy
.
build_cinn_pass
=
True
bert
=
paddle
.
jit
.
to_static
(
bert
,
input_sepc
,
build_strategy
=
build_strategy
)
optimizer
=
fluid
.
optimizer
.
Adam
(
parameter_list
=
bert
.
parameters
())
...
...
python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py
浏览文件 @
bbca66f2
...
...
@@ -58,6 +58,9 @@ class TestPrimFlags(unittest.TestCase):
core
.
check_and_set_prim_all_enabled
()
self
.
assertFalse
(
core
.
_is_fwd_prim_enabled
())
core
.
set_prim_eager_enabled
(
True
)
self
.
assertTrue
(
core
.
_is_eager_prim_enabled
())
with
self
.
assertRaises
(
TypeError
):
core
.
_test_use_sync
(
"aaaa"
)
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py
浏览文件 @
bbca66f2
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -61,7 +61,7 @@ class TestAddGradComp(unittest.TestCase):
def
test_add_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -72,7 +72,7 @@ class TestAddGradComp(unittest.TestCase):
return
res
[
0
].
numpy
(),
res
[
1
].
numpy
()
def
desired
(
primal0
,
primal1
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -98,7 +98,7 @@ class TestAddGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_cast_grad.py
浏览文件 @
bbca66f2
...
...
@@ -52,7 +52,7 @@ class TestCastGradComp(unittest.TestCase):
cls
.
cotangent
=
cls
.
cotangent
.
astype
(
cls
.
src_dtype
)
def
test_cast_grad_comp
(
self
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
def
actual
(
primal
,
cotangent
):
x
=
paddle
.
to_tensor
(
primal
)
...
...
@@ -78,7 +78,7 @@ class TestCastGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py
浏览文件 @
bbca66f2
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -61,7 +61,7 @@ class TestDivGradComp(unittest.TestCase):
def
test_div_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -72,7 +72,7 @@ class TestDivGradComp(unittest.TestCase):
return
res
[
0
].
numpy
(),
res
[
1
].
numpy
()
def
desired
(
primal0
,
primal1
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -98,7 +98,7 @@ class TestDivGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py
浏览文件 @
bbca66f2
...
...
@@ -32,14 +32,14 @@ from paddle.fluid import core
class
TestExpGradComp
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
cls
.
primal
=
cls
.
primal
.
astype
(
cls
.
dtype
)
if
cls
.
cotangent
is
not
None
:
cls
.
cotangent
=
cls
.
cotangent
.
astype
(
cls
.
dtype
)
@
classmethod
def
tearDownClass
(
cls
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
def
test_exp_grad_comp
(
self
):
def
actual
(
primal
,
cotangent
):
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py
浏览文件 @
bbca66f2
...
...
@@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
def
test_comp
(
self
):
def
func
(
primal
,
cotangent
,
shape
):
...
...
@@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase):
]
def
actual
(
primal
,
cotangent
,
shape
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
return
func
(
primal
,
cotangent
,
shape
)
def
desired
(
primal
,
cotangent
,
shape
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
return
func
(
primal
,
cotangent
,
shape
)
np
.
testing
.
assert_allclose
(
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py
浏览文件 @
bbca66f2
...
...
@@ -75,11 +75,11 @@ class TestGatherGradComp(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
def
test_exp_grad_comp
(
self
):
def
actual
(
primal0
,
index
,
axis
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
primal0
.
dtype
,
stop_gradient
=
False
...
...
@@ -92,7 +92,7 @@ class TestGatherGradComp(unittest.TestCase):
return
res
[
0
].
numpy
()
def
desired
(
primal0
,
index
,
axis
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
primal0
.
dtype
,
stop_gradient
=
False
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_matmul_double_grad.py
浏览文件 @
bbca66f2
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
# vector * vector out.shape = (1)
# matrix * vector out.shape = (2)
...
...
@@ -267,7 +267,7 @@ class TestMatmulDoubleGradComp(unittest.TestCase):
def
test_matmul_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
,
trans_0
,
trans_1
,
dtype_
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
dtype_
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
dtype_
,
stop_gradient
=
False
)
...
...
@@ -287,7 +287,7 @@ class TestMatmulDoubleGradComp(unittest.TestCase):
)
def
desired
(
primal0
,
primal1
,
trans_0
,
trans_1
,
dtype_
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
dtype_
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
dtype_
,
stop_gradient
=
False
)
...
...
@@ -428,7 +428,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
def
test_matmul_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
,
trans_0
,
trans_1
,
dtype_
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
dtype_
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
dtype_
,
stop_gradient
=
False
)
...
...
@@ -465,7 +465,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
)
def
desired
(
primal0
,
primal1
,
trans_0
,
trans_1
,
dtype_
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
dtype_
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
dtype_
,
stop_gradient
=
False
)
...
...
@@ -549,7 +549,7 @@ class TestMatmulTribleGradComp(unittest.TestCase):
atol
=
TOLERANCE
[
d_type
][
'atol'
],
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py
浏览文件 @
bbca66f2
...
...
@@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase):
return
[
g
for
g
in
grads
if
g
is
not
None
]
def
test_comp
(
self
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
actual
=
self
.
vjp
()
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
desired
=
self
.
vjp
()
for
i
,
j
in
zip
(
actual
,
desired
):
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_reshape_grad.py
浏览文件 @
bbca66f2
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -42,7 +42,7 @@ class TestReshapeGradComp(unittest.TestCase):
def
test_reshape_grad_comp
(
self
):
def
actual
(
primal0
,
shape
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
x
.
stop_gradient
=
False
...
...
@@ -51,7 +51,7 @@ class TestReshapeGradComp(unittest.TestCase):
return
res
[
0
].
numpy
()
def
desired
(
primal0
,
shape
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
x
.
stop_gradient
=
False
...
...
@@ -69,7 +69,7 @@ class TestReshapeGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py
浏览文件 @
bbca66f2
...
...
@@ -22,7 +22,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -57,7 +57,7 @@ class TestSqrtGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py
浏览文件 @
bbca66f2
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -61,7 +61,7 @@ class TestSubGradComp(unittest.TestCase):
def
test_sub_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -72,7 +72,7 @@ class TestSubGradComp(unittest.TestCase):
return
res
[
0
].
numpy
(),
res
[
1
].
numpy
()
def
desired
(
primal0
,
primal1
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -98,7 +98,7 @@ class TestSubGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
浏览文件 @
bbca66f2
...
...
@@ -21,7 +21,7 @@ from paddle.fluid import core
def
actual
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
x
=
paddle
.
to_tensor
(
primal
,
dtype
=
'float32'
,
stop_gradient
=
False
)
v
=
paddle
.
to_tensor
(
cotangent
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
sum
(
x
,
axis
=
axis
,
keepdim
=
keep_dim
)
...
...
@@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim):
def
desired
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
x
=
paddle
.
to_tensor
(
primal
,
dtype
=
'float32'
,
stop_gradient
=
False
)
v
=
paddle
.
to_tensor
(
cotangent
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
sum
(
x
,
axis
=
axis
,
keepdim
=
keep_dim
)
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py
浏览文件 @
bbca66f2
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -68,7 +68,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_transpose_grad.py
浏览文件 @
bbca66f2
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -72,7 +72,7 @@ class TestTransposeGradComp(unittest.TestCase):
def
test_transpose_grad_comp
(
self
):
def
actual
(
primal0
,
shape
):
core
.
_set_prim_backward
_enabled
(
True
)
core
.
set_prim_eager
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
x
.
stop_gradient
=
False
...
...
@@ -81,7 +81,7 @@ class TestTransposeGradComp(unittest.TestCase):
return
res
[
0
].
numpy
()
def
desired
(
primal0
,
shape
):
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
x
.
stop_gradient
=
False
...
...
@@ -99,7 +99,7 @@ class TestTransposeGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
_set_prim_backward
_enabled
(
False
)
core
.
set_prim_eager
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim_op_test.py
浏览文件 @
bbca66f2
...
...
@@ -906,7 +906,7 @@ class PrimGradChecker(PrimForwardChecker):
paddle
.
device
.
set_device
(
"gpu:0"
)
atol
=
self
.
rev_comp_atol
rtol
=
self
.
rev_comp_rtol
core
.
_set_prim_backward
_enabled
(
self
.
enable_rev_comp
)
core
.
set_prim_eager
_enabled
(
self
.
enable_rev_comp
)
actual_ret
=
self
.
get_eager_desire
()
# check static forward
if
len
(
actual_ret
)
!=
len
(
self
.
eager_desire
):
...
...
@@ -941,6 +941,7 @@ class PrimGradChecker(PrimForwardChecker):
)
)
raise
RuntimeError
(
msg
)
core
.
set_prim_eager_enabled
(
False
)
def
check_static_comp
(
self
):
paddle
.
enable_static
()
...
...
python/paddle/fluid/tests/unittests/test_slice_op.py
浏览文件 @
bbca66f2
...
...
@@ -213,9 +213,7 @@ class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim):
class
TestSliceOp_starts_ListTensor
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"slice"
self
.
prim_op_type
=
"prim"
self
.
python_api
=
paddle
.
slice
# self.enable_cinn = False
self
.
config
()
starts_tensor
=
[]
...
...
@@ -244,12 +242,10 @@ class TestSliceOp_starts_ListTensor(OpTest):
self
.
starts_infer
=
[
-
1
,
0
,
-
1
]
def
test_check_output
(
self
):
self
.
check_output
(
check_prim
=
True
)
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
(
[
'Input'
],
'Out'
,
max_relative_error
=
0.006
,
check_prim
=
True
)
self
.
check_grad
([
'Input'
],
'Out'
,
max_relative_error
=
0.006
)
# Situation 2: starts(list, have tensor), ends(list, no tensor)
...
...
@@ -257,7 +253,6 @@ class TestSliceOp_starts_ListTensor(OpTest):
class
TestSliceOp_decs_dim_starts_ListTensor
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"slice"
self
.
prim_op_type
=
"prim"
self
.
python_api
=
paddle
.
slice
self
.
config
()
...
...
@@ -290,12 +285,10 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
self
.
starts_infer
=
[
1
,
-
1
,
2
]
def
test_check_output
(
self
):
self
.
check_output
(
check_prim
=
True
)
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
(
[
'Input'
],
'Out'
,
max_relative_error
=
0.006
,
check_prim
=
True
)
self
.
check_grad
([
'Input'
],
'Out'
,
max_relative_error
=
0.006
)
class
TestSliceOp_decs_dim_5_starts_ListTensor
(
...
...
@@ -318,7 +311,6 @@ class TestSliceOp_decs_dim_5_starts_ListTensor(
class
TestSliceOp_decs_dim_starts_OneTensor
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"slice"
self
.
prim_op_type
=
"prim"
self
.
python_api
=
paddle
.
slice
self
.
config
()
self
.
inputs
=
{
...
...
@@ -344,12 +336,10 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
self
.
out
=
self
.
input
[
1
,
0
:
3
,
2
:
4
,
:]
def
test_check_output
(
self
):
self
.
check_output
(
check_prim
=
True
)
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
(
[
'Input'
],
'Out'
,
max_relative_error
=
0.006
,
check_prim
=
True
)
self
.
check_grad
([
'Input'
],
'Out'
,
max_relative_error
=
0.006
)
# Situation 4: starts(tensor), ends(tensor)
...
...
@@ -357,7 +347,6 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
class
TestSliceOp_starts_OneTensor_ends_OneTensor
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"slice"
self
.
prim_op_type
=
"prim"
self
.
python_api
=
paddle
.
slice
self
.
config
()
...
...
@@ -383,12 +372,10 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
self
.
out
=
self
.
input
[
1
:
3
,
0
:
3
,
2
:
4
,
:]
def
test_check_output
(
self
):
self
.
check_output
(
check_prim
=
True
)
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
(
[
'Input'
],
'Out'
,
max_relative_error
=
0.006
,
check_prim
=
True
)
self
.
check_grad
([
'Input'
],
'Out'
,
max_relative_error
=
0.006
)
# Situation 5: starts(tensor), ends(tensor)
...
...
@@ -396,7 +383,6 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
class
TestSliceOp_decs_dim_starts_and_ends_OneTensor
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"slice"
self
.
prim_op_type
=
"prim"
self
.
python_api
=
paddle
.
slice
self
.
config
()
self
.
inputs
=
{
...
...
@@ -423,12 +409,10 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
self
.
out
=
self
.
input
[
1
,
0
,
2
:
4
,
:]
def
test_check_output
(
self
):
self
.
check_output
(
check_prim
=
True
)
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
(
[
'Input'
],
'Out'
,
max_relative_error
=
0.006
,
check_prim
=
True
)
self
.
check_grad
([
'Input'
],
'Out'
,
max_relative_error
=
0.006
)
# Situation 6: starts(tensor), ends(list, have tensor)
...
...
@@ -436,7 +420,6 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
class
TestSliceOp_starts_OneTensor_ends_ListTensor
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"slice"
self
.
prim_op_type
=
"prim"
self
.
python_api
=
paddle
.
slice
self
.
config
()
...
...
@@ -470,12 +453,10 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
self
.
ends_infer
=
[
-
1
,
3
,
4
]
def
test_check_output
(
self
):
self
.
check_output
(
check_prim
=
True
)
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
(
[
'Input'
],
'Out'
,
max_relative_error
=
0.006
,
check_prim
=
True
)
self
.
check_grad
([
'Input'
],
'Out'
,
max_relative_error
=
0.006
)
# Test CUDA float16
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录