Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
23d20e30
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
23d20e30
编写于
1月 20, 2023
作者:
J
Jiabin Yang
提交者:
GitHub
1月 20, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Prim】Refactor prim flags system (#49930)
上级
44855da3
变更
49
显示空白变更内容
内联
并排
Showing
49 changed file
with
339 addition
and
206 deletion
+339
-206
paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
...le/fluid/eager/auto_code_generator/generator/eager_gen.py
+2
-2
paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc
paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc
+11
-1
paddle/fluid/prim/api/manual/backward/composite_backward_api.h
...e/fluid/prim/api/manual/backward/composite_backward_api.h
+1
-1
paddle/fluid/prim/tests/test_eager_prim.cc
paddle/fluid/prim/tests/test_eager_prim.cc
+8
-8
paddle/fluid/prim/tests/test_static_prim.cc
paddle/fluid/prim/tests/test_static_prim.cc
+4
-4
paddle/fluid/prim/utils/static/static_global_utils.cc
paddle/fluid/prim/utils/static/static_global_utils.cc
+2
-1
paddle/fluid/prim/utils/static/static_global_utils.h
paddle/fluid/prim/utils/static/static_global_utils.h
+13
-3
paddle/fluid/prim/utils/utils.cc
paddle/fluid/prim/utils/utils.cc
+16
-4
paddle/fluid/prim/utils/utils.h
paddle/fluid/prim/utils/utils.h
+5
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+12
-3
paddle/phi/api/yaml/legacy_backward.yaml
paddle/phi/api/yaml/legacy_backward.yaml
+3
-3
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+3
-2
python/paddle/fluid/core.py
python/paddle/fluid/core.py
+96
-28
python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py
...d/tests/unittests/composite_ops/test_composite_softmax.py
+3
-0
python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py
...ts/unittests/composite_ops/test_composite_softmax_grad.py
+3
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
+2
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
...fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
+2
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py
...le/fluid/tests/unittests/dygraph_to_static/test_resnet.py
+14
-59
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py
...luid/tests/unittests/dygraph_to_static/test_resnet_amp.py
+2
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py
...ests/unittests/dygraph_to_static/test_resnet_pure_fp16.py
+2
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py
...fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt
python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt
...ddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt
+9
-0
python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py
.../fluid/tests/unittests/prim/prim/flags/test_prim_flags.py
+52
-0
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py
...ttests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py
+3
-3
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py
...ests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py
...nittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
...unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py
...nittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
...ests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
...unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
...ests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py
...ests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py
...s/unittests/prim/prim/vjp/static/test_comp_expand_grad.py
+3
-3
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py
...unittests/prim/prim/vjp/static/test_comp_multiply_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
...sts/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
+3
-3
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
...ests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
+4
-4
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
...ests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
+2
-2
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
...sts/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
+3
-3
python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py
...nittests/prim/test_comp_get_grad_op_desc_prim_disabled.py
+1
-1
python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py
...unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py
+2
-2
python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py
...n/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py
+2
-2
python/paddle/incubate/autograd/primapi.py
python/paddle/incubate/autograd/primapi.py
+2
-1
python/paddle/jit/dy2static/partial_program.py
python/paddle/jit/dy2static/partial_program.py
+4
-7
python/paddle/jit/dy2static/program_translator.py
python/paddle/jit/dy2static/program_translator.py
+4
-5
未找到文件。
paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
浏览文件 @
23d20e30
...
...
@@ -1841,7 +1841,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
if
is_composite_grad_api
and
next_grad_node_creation_str
!=
''
:
next_grad_node_creation_str
=
f
"""
if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
if (!paddle::prim::PrimCommonUtils::Is
Bwd
PrimEnabled()) {{
{
next_grad_node_creation_str
}
}}
"""
...
...
@@ -2261,7 +2261,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif
is_composite_grad_api
:
grad_function_call_str
=
f
"""
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
if (paddle::prim::PrimCommonUtils::Is
Bwd
PrimEnabled()) {{
{
indent
}{
composite_grad_api_namespace
}{
composite_grad_api_name
}{
composite_template_name
}
(
{
composite_grad_api_args_str
}
);
VLOG(4) << "Composite api
{
composite_grad_api_name
}
is called ";
}}else{{
...
...
paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc
浏览文件 @
23d20e30
...
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string.h>
#include <memory>
#include <sstream>
#include <string>
...
...
@@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape,
phi
::
errors
::
InvalidArgument
(
"We only support float32/float16 for full, but we got data type: %s"
,
phi
::
DataTypeToString
(
dtype
)));
if
(
dtype
==
phi
::
DataType
::
FLOAT32
)
{
op
->
SetAttr
(
"value"
,
value
.
to
<
float
>
());
}
else
if
(
dtype
==
phi
::
DataType
::
FLOAT64
)
{
op
->
SetAttr
(
"str_value"
,
std
::
to_string
(
value
.
to
<
double
>
()));
}
else
if
(
dtype
==
phi
::
DataType
::
FLOAT16
)
{
op
->
SetAttr
(
"str_value"
,
std
::
to_string
(
value
.
to
<
float
>
()));
}
else
{
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"We only support float64/float32/float16 for full"
));
}
op
->
SetAttr
(
"dtype"
,
paddle
::
framework
::
TransToProtoVarType
(
dtype
));
op
->
SetOutput
(
"Out"
,
{
std
::
static_pointer_cast
<
prim
::
DescTensor
>
(
out
.
impl
())
->
Name
()});
...
...
paddle/fluid/prim/api/manual/backward/composite_backward_api.h
浏览文件 @
23d20e30
...
...
@@ -192,7 +192,7 @@ void divide_grad(const Tensor& x,
}
// indicate we will compute dy
if
(
dx
)
{
// dx = (1/y) * dout
auto
one_tensor
=
full
<
T
>
(
phi
::
vectorize
(
y
.
dims
()),
1.0
);
auto
one_tensor
=
full
<
T
>
(
phi
::
vectorize
(
y
.
dims
()),
1.0
,
y
.
dtype
()
);
auto
tmp0
=
divide
<
T
>
(
one_tensor
,
y
);
auto
dx_res
=
multiply
<
T
>
(
tmp0
,
out_grad
);
if
(
y
.
dims
()
!=
x
.
dims
())
{
...
...
paddle/fluid/prim/tests/test_eager_prim.cc
浏览文件 @
23d20e30
...
...
@@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) {
paddle
::
experimental
::
Tensor
out0
=
tanh_ad_func
(
tensor0
);
std
::
vector
<
paddle
::
experimental
::
Tensor
>
outs0
=
{
out0
};
// Disable prim
PrimCommonUtils
::
SetPrimEnabled
(
false
);
ASSERT_FALSE
(
PrimCommonUtils
::
IsPrimEnabled
());
PrimCommonUtils
::
Set
Bwd
PrimEnabled
(
false
);
ASSERT_FALSE
(
PrimCommonUtils
::
Is
Bwd
PrimEnabled
());
// 4. Run Backward
egr
::
Backward
(
outs0
,
{},
false
);
paddle
::
experimental
::
Tensor
out1
=
tanh_ad_func
(
tensor1
);
std
::
vector
<
paddle
::
experimental
::
Tensor
>
outs1
=
{
out1
};
// Disable prim
PrimCommonUtils
::
SetPrimEnabled
(
true
);
ASSERT_TRUE
(
PrimCommonUtils
::
IsPrimEnabled
());
PrimCommonUtils
::
Set
Bwd
PrimEnabled
(
true
);
ASSERT_TRUE
(
PrimCommonUtils
::
Is
Bwd
PrimEnabled
());
// 4. Run Backward
::
egr
::
Backward
(
outs1
,
{},
false
);
VLOG
(
7
)
...
...
@@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) {
}
TEST
(
EagerPrim
,
TestFlags
)
{
PrimCommonUtils
::
SetPrimEnabled
(
true
);
ASSERT_TRUE
(
PrimCommonUtils
::
IsPrimEnabled
());
PrimCommonUtils
::
SetPrimEnabled
(
false
);
ASSERT_FALSE
(
PrimCommonUtils
::
IsPrimEnabled
());
PrimCommonUtils
::
Set
Bwd
PrimEnabled
(
true
);
ASSERT_TRUE
(
PrimCommonUtils
::
Is
Bwd
PrimEnabled
());
PrimCommonUtils
::
Set
Bwd
PrimEnabled
(
false
);
ASSERT_FALSE
(
PrimCommonUtils
::
Is
Bwd
PrimEnabled
());
}
}
// namespace prim
...
...
paddle/fluid/prim/tests/test_static_prim.cc
浏览文件 @
23d20e30
...
...
@@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
}
TEST
(
StaticPrim
,
TestFlags
)
{
PrimCommonUtils
::
SetPrimEnabled
(
true
);
ASSERT_TRUE
(
PrimCommonUtils
::
IsPrimEnabled
());
PrimCommonUtils
::
SetPrimEnabled
(
false
);
ASSERT_FALSE
(
PrimCommonUtils
::
IsPrimEnabled
());
PrimCommonUtils
::
Set
Bwd
PrimEnabled
(
true
);
ASSERT_TRUE
(
PrimCommonUtils
::
Is
Bwd
PrimEnabled
());
PrimCommonUtils
::
Set
Bwd
PrimEnabled
(
false
);
ASSERT_FALSE
(
PrimCommonUtils
::
Is
Bwd
PrimEnabled
());
}
}
// namespace prim
...
...
paddle/fluid/prim/utils/static/static_global_utils.cc
浏览文件 @
23d20e30
...
...
@@ -18,6 +18,7 @@ namespace paddle {
namespace
prim
{
StaticCompositeContext
*
StaticCompositeContext
::
static_composite_context_
=
new
StaticCompositeContext
();
thread_local
bool
StaticCompositeContext
::
enable_prim_
=
false
;
thread_local
bool
StaticCompositeContext
::
enable_bwd_prim_
=
false
;
thread_local
bool
StaticCompositeContext
::
enable_fwd_prim_
=
false
;
}
// namespace prim
}
// namespace paddle
paddle/fluid/prim/utils/static/static_global_utils.h
浏览文件 @
23d20e30
...
...
@@ -56,9 +56,18 @@ class StaticCompositeContext {
return
generator_
->
Generate
(
key
);
}
void
Set
PrimEnabled
(
bool
enable_prim
)
{
enable
_prim_
=
enable_prim
;
}
void
Set
BwdPrimEnabled
(
bool
enable_prim
)
{
enable_bwd
_prim_
=
enable_prim
;
}
bool
IsPrimEnabled
()
{
return
enable_prim_
;
}
bool
IsBwdPrimEnabled
()
{
return
enable_bwd_prim_
;
}
void
SetFwdPrimEnabled
(
bool
enable_prim
)
{
enable_fwd_prim_
=
enable_prim
;
}
bool
IsFwdPrimEnabled
()
{
return
enable_fwd_prim_
;
}
void
SetAllPrimEnabled
(
bool
enable_prim
)
{
enable_fwd_prim_
=
enable_prim
;
enable_bwd_prim_
=
enable_prim
;
}
private:
StaticCompositeContext
()
...
...
@@ -66,7 +75,8 @@ class StaticCompositeContext {
framework
::
BlockDesc
*
current_block_desc_
;
std
::
unique_ptr
<
UniqueNameGenerator
>
generator_
;
static
thread_local
bool
enable_prim_
;
static
thread_local
bool
enable_bwd_prim_
;
static
thread_local
bool
enable_fwd_prim_
;
static
StaticCompositeContext
*
static_composite_context_
;
DISABLE_COPY_AND_ASSIGN
(
StaticCompositeContext
);
};
...
...
paddle/fluid/prim/utils/utils.cc
浏览文件 @
23d20e30
...
...
@@ -19,12 +19,24 @@
PADDLE_DEFINE_EXPORTED_bool
(
prim_enabled
,
false
,
"enable_prim or not"
);
namespace
paddle
{
namespace
prim
{
bool
PrimCommonUtils
::
IsPrimEnabled
()
{
return
StaticCompositeContext
::
Instance
().
IsPrimEnabled
();
bool
PrimCommonUtils
::
Is
Bwd
PrimEnabled
()
{
return
StaticCompositeContext
::
Instance
().
Is
Bwd
PrimEnabled
();
}
void
PrimCommonUtils
::
SetPrimEnabled
(
bool
enable_prim
)
{
return
StaticCompositeContext
::
Instance
().
SetPrimEnabled
(
enable_prim
);
void
PrimCommonUtils
::
SetBwdPrimEnabled
(
bool
enable_prim
)
{
return
StaticCompositeContext
::
Instance
().
SetBwdPrimEnabled
(
enable_prim
);
}
bool
PrimCommonUtils
::
IsFwdPrimEnabled
()
{
return
StaticCompositeContext
::
Instance
().
IsFwdPrimEnabled
();
}
void
PrimCommonUtils
::
SetFwdPrimEnabled
(
bool
enable_prim
)
{
return
StaticCompositeContext
::
Instance
().
SetFwdPrimEnabled
(
enable_prim
);
}
void
PrimCommonUtils
::
SetAllPrimEnabled
(
bool
enable_prim
)
{
return
StaticCompositeContext
::
Instance
().
SetAllPrimEnabled
(
enable_prim
);
}
}
// namespace prim
}
// namespace paddle
paddle/fluid/prim/utils/utils.h
浏览文件 @
23d20e30
...
...
@@ -18,8 +18,11 @@ namespace paddle {
namespace
prim
{
class
PrimCommonUtils
{
public:
static
bool
IsPrimEnabled
();
static
void
SetPrimEnabled
(
bool
enabled
);
static
bool
IsBwdPrimEnabled
();
static
void
SetBwdPrimEnabled
(
bool
enabled
);
static
bool
IsFwdPrimEnabled
();
static
void
SetFwdPrimEnabled
(
bool
enabled
);
static
void
SetAllPrimEnabled
(
bool
enabled
);
};
}
// namespace prim
}
// namespace paddle
paddle/fluid/pybind/pybind.cc
浏览文件 @
23d20e30
...
...
@@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) {
return
oss
.
str
();
});
m
.
def
(
"set_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
SetPrimEnabled
);
m
.
def
(
"is_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
IsPrimEnabled
);
m
.
def
(
"__set_bwd_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
SetBwdPrimEnabled
);
m
.
def
(
"_is_bwd_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
IsBwdPrimEnabled
);
m
.
def
(
"__set_fwd_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
SetFwdPrimEnabled
);
m
.
def
(
"_is_fwd_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
IsFwdPrimEnabled
);
m
.
def
(
"__set_all_prim_enabled"
,
&
paddle
::
prim
::
PrimCommonUtils
::
SetAllPrimEnabled
);
m
.
def
(
"set_num_threads"
,
&
platform
::
SetNumThreads
);
m
.
def
(
"disable_signal_handler"
,
&
DisableSignalHandler
);
...
...
@@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle.
// priority of GradCompOpMaker is less than GradCompMaker for better
// performance.
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
grad_op_descs
;
if
(
paddle
::
prim
::
PrimCommonUtils
::
IsPrimEnabled
())
{
if
(
paddle
::
prim
::
PrimCommonUtils
::
Is
Bwd
PrimEnabled
())
{
if
(
grad_comp_op_maker
!=
nullptr
)
{
VLOG
(
3
)
<<
"Runing composite fun for "
<<
op_desc
.
Type
();
grad_op_descs
=
grad_comp_op_maker
(
op_desc
,
no_grad_set
,
&
grad_to_var
,
...
...
paddle/phi/api/yaml/legacy_backward.yaml
浏览文件 @
23d20e30
...
...
@@ -42,7 +42,7 @@
kernel
:
func
:
add_grad
no_need_buffer
:
x, y
composite
:
add_grad(
Tensor x, Tensor y, Tensor out_grad, int
axis)
composite
:
add_grad(
x, y, out_grad,
axis)
backward
:
add_double_grad
inplace
:
(out_grad -> x_grad)
...
...
@@ -390,7 +390,7 @@
param
:
[
x
,
y
]
kernel
:
func
:
divide_grad
composite
:
divide_grad(
Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis =
-1)
composite
:
divide_grad(
x, y, out, out_grad,
-1)
backward
:
divide_double_grad
-
backward_op
:
dropout_grad
...
...
@@ -1319,7 +1319,7 @@
kernel
:
func
:
subtract_grad
no_need_buffer
:
x, y
composite
:
subtract_grad(
Tensor x, Tensor y, Tensor out_grad, int
axis)
composite
:
subtract_grad(
x, y, out_grad,
axis)
backward
:
subtract_double_grad
inplace
:
(out_grad -> x_grad)
...
...
python/paddle/fluid/backward.py
浏览文件 @
23d20e30
...
...
@@ -1493,14 +1493,15 @@ def _append_backward_ops_(
# remove some backward ops
# TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
if
not
core
.
is
_prim_enabled
():
if
not
core
.
_is_bwd
_prim_enabled
():
not_need_ops
=
_find_not_need_ops
(
grad_op_descs
,
ops
,
input_grad_names_set
)
grad_op_descs
=
[
op_desc
for
op_desc
in
grad_op_descs
if
op_desc
not
in
not_need_ops
]
else
:
logging
.
debug
(
"Runing backward composite and disable find_not_need_ops"
)
# append op_desc in grad_op_descs to target_block
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
...
...
python/paddle/fluid/core.py
浏览文件 @
23d20e30
...
...
@@ -17,6 +17,7 @@ import sys
import
os
import
warnings
import
platform
import
logging
has_paddle_dy_lib
=
False
...
...
@@ -305,8 +306,13 @@ try:
from
.libpaddle
import
_Profiler
,
_ProfilerResult
,
_RecordEvent
from
.libpaddle
import
_set_current_stream
from
.libpaddle
import
_get_phi_kernel_name
from
.libpaddle
import
set_prim_enabled
from
.libpaddle
import
is_prim_enabled
# prim controller flags
from
.libpaddle
import
__set_bwd_prim_enabled
from
.libpaddle
import
_is_bwd_prim_enabled
from
.libpaddle
import
__set_fwd_prim_enabled
from
.libpaddle
import
_is_fwd_prim_enabled
from
.libpaddle
import
__set_all_prim_enabled
if
sys
.
platform
!=
'win32'
:
from
.libpaddle
import
_set_process_pids
...
...
@@ -373,36 +379,98 @@ def set_paddle_lib_path():
set_paddle_lib_path
()
# We have 3 FLAGS to judge whether prim is enabled
# FLAGS_prim_forward: Open or close forward prim strategy
# FLAGS_prim_backward: Open or close backward prim strategy
# FLAGS_prim_all: Open or close all prim strategy
#
#
# Priorities:
# if With CINN and Dy2St:
# # # _set_prim_all_enabled > FLAGS_prim_all > check_and_set_prim_all_enabled == _set_prim_backward_enabled == _set_prim_backward_enabled > FLAGS_prim_forward == FLAGS_prim_backward
# else:
# # # _set_prim_all_enabled > FLAGS_prim_all == check_and_set_prim_all_enabled == _set_prim_backward_enabled == _set_prim_backward_enabled > FLAGS_prim_forward == FLAGS_prim_backward
def
__sync_stat_with_flag
(
flag
):
if
flag
is
"FLAGS_prim_forward"
:
flag_value
=
os
.
getenv
(
"FLAGS_prim_forward"
)
assert
flag_value
is
not
None
flag_value
=
flag_value
.
lower
()
if
flag_value
==
"false"
:
__set_fwd_prim_enabled
(
False
)
elif
flag_value
==
"true"
:
__set_fwd_prim_enabled
(
True
)
else
:
raise
TypeError
(
f
"flag
{
flag
}
should be true or false."
)
logging
.
debug
(
"forward prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()))
elif
flag
is
"FLAGS_prim_backward"
:
flag_value
=
os
.
getenv
(
"FLAGS_prim_backward"
)
assert
flag_value
is
not
None
flag_value
=
flag_value
.
lower
()
if
flag_value
==
"false"
:
__set_bwd_prim_enabled
(
False
)
elif
flag_value
==
"true"
:
__set_bwd_prim_enabled
(
True
)
else
:
raise
TypeError
(
f
"flag
{
flag
}
should be true or false."
)
logging
.
debug
(
"backward prim enabled: "
,
bool
(
_is_bwd_prim_enabled
()))
elif
flag
is
"FLAGS_prim_all"
:
flag_value
=
os
.
getenv
(
"FLAGS_prim_all"
)
assert
flag_value
is
not
None
flag_value
=
flag_value
.
lower
()
if
flag_value
==
"false"
:
__set_all_prim_enabled
(
False
)
elif
flag_value
==
"true"
:
__set_all_prim_enabled
(
True
)
else
:
raise
TypeError
(
f
"flag
{
flag
}
should be true or false."
)
logging
.
debug
(
"all prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()
and
_is_bwd_prim_enabled
()),
)
else
:
raise
TypeError
(
f
"We only support FLAGS_prim_forward/FLAGS_prim_backward/FLAGS_prim_all but we got
{
flag
}
."
)
def
set_prim_forward
(
value
):
"""set flag FLAGS_prim_forward."""
flag
=
str
(
value
)
if
flag
.
lower
()
not
in
[
"true"
,
"false"
,
"debug"
]:
raise
TypeError
(
f
"flag
{
flag
}
should be string of bool or 'debug'."
)
os
.
environ
[
"FLAGS_prim_forward"
]
=
flag
return
def
_set_prim_backward_enabled
(
value
):
__set_bwd_prim_enabled
(
bool
(
value
))
logging
.
debug
(
"backward prim enabled: "
,
bool
(
_is_bwd_prim_enabled
()))
def
enable_prim_forward
():
flag
=
os
.
getenv
(
"FLAGS_prim_forward"
,
"true"
).
lower
()
if
flag
==
"false"
:
return
False
if
flag
==
"debug"
:
return
"debug"
return
True
def
_set_prim_forward_enabled
(
value
):
__set_fwd_prim_enabled
(
bool
(
value
))
logging
.
debug
(
"forward prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()))
def
set_prim_backward
(
value
):
"""set flag FLAGS_prim_backward,"""
flag
=
str
(
value
)
if
flag
.
lower
()
not
in
[
"true"
,
"false"
]:
raise
TypeError
(
f
"flag
{
flag
}
should be bool or string of bool."
)
os
.
environ
[
"FLAGS_prim_backward"
]
=
flag
return
def
_set_prim_all_enabled
(
value
):
__set_all_prim_enabled
(
bool
(
value
))
logging
.
debug
(
"all prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()
and
_is_bwd_prim_enabled
()),
)
def
enable_prim_backward
():
flag
=
os
.
getenv
(
"FLAGS_prim_backward"
,
"true"
)
if
flag
.
lower
()
==
"false"
:
return
False
return
True
def
__sync_prim_backward_status
():
flag_value
=
os
.
getenv
(
"FLAGS_prim_backward"
)
if
flag_value
is
None
:
logging
.
debug
(
"backward prim enabled: "
,
bool
(
_is_bwd_prim_enabled
()))
else
:
__sync_stat_with_flag
(
"FLAGS_prim_backward"
)
def
__sync_prim_forward_status
():
flag_value
=
os
.
getenv
(
"FLAGS_prim_forward"
)
if
flag_value
is
None
:
logging
.
debug
(
"forward prim enabled: "
,
bool
(
_is_fwd_prim_enabled
()))
else
:
__sync_stat_with_flag
(
"FLAGS_prim_forward"
)
def
check_and_set_prim_all_enabled
():
flag_value
=
os
.
getenv
(
"FLAGS_prim_all"
)
if
flag_value
is
None
:
__sync_prim_backward_status
()
__sync_prim_forward_status
()
else
:
__sync_stat_with_flag
(
"FLAGS_prim_all"
)
python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py
浏览文件 @
23d20e30
...
...
@@ -19,6 +19,7 @@ from utils import TOLERANCE
import
paddle
import
paddle.nn.functional
as
F
from
paddle.fluid
import
core
def
generate_data
(
shape
,
dtype
=
"float32"
):
...
...
@@ -72,6 +73,7 @@ class TestCompositeSoftmax(unittest.TestCase):
def
cal_composite
(
self
,
inputs
):
paddle
.
enable_static
()
core
.
_set_prim_forward_enabled
(
True
)
startup_program
=
paddle
.
static
.
Program
()
main_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
...
...
@@ -95,6 +97,7 @@ class TestCompositeSoftmax(unittest.TestCase):
exe
.
run
(
startup_program
)
res
=
exe
.
run
(
main_program
,
feed
=
{
'x'
:
inputs
},
fetch_list
=
[
y
])
paddle
.
disable_static
()
core
.
_set_prim_forward_enabled
(
False
)
return
res
def
compare_forward
(
self
):
...
...
python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py
浏览文件 @
23d20e30
...
...
@@ -78,6 +78,7 @@ class TestCompositeSoftmax(unittest.TestCase):
def
cal_composite_grad
(
self
,
inputs
):
paddle
.
enable_static
()
core
.
_set_prim_all_enabled
(
True
)
startup_program
=
paddle
.
static
.
Program
()
main_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
...
...
@@ -108,6 +109,7 @@ class TestCompositeSoftmax(unittest.TestCase):
exe
.
run
(
startup_program
)
res
=
exe
.
run
(
main_program
,
feed
=
{
'x'
:
inputs
},
fetch_list
=
[
z
])
paddle
.
disable_static
()
core
.
_set_prim_all_enabled
(
False
)
return
res
def
compare_backward
(
self
):
...
...
@@ -139,7 +141,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
"test composite softmax and prim backward"
def
setUp
(
self
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
self
.
dtypes
=
[
"float32"
]
self
.
shapes
=
[[
2
,
3
,
4
],
[
2
,
3
]]
self
.
axes
=
[
-
1
,
0
,
1
]
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
浏览文件 @
23d20e30
...
...
@@ -236,11 +236,11 @@ class TestBert(unittest.TestCase):
self
.
verify_predict
()
def
test_train_composite
(
self
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
static_loss
,
static_ppl
=
self
.
train_static
(
self
.
bert_config
,
self
.
data_reader
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
dygraph_loss
,
dygraph_ppl
=
self
.
train_dygraph
(
self
.
bert_config
,
self
.
data_reader
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
浏览文件 @
23d20e30
...
...
@@ -47,7 +47,6 @@ class TestPrimForward(unittest.TestCase):
"""
def
setUp
(
self
):
core
.
set_prim_backward
(
False
)
paddle
.
seed
(
2022
)
self
.
x
=
paddle
.
randn
([
2
,
4
])
self
.
x
.
stop_gradient
=
False
...
...
@@ -58,6 +57,7 @@ class TestPrimForward(unittest.TestCase):
sgd
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.1
,
parameters
=
net
.
parameters
()
)
core
.
_set_prim_forward_enabled
(
use_prim
)
if
use_prim
:
net
=
apply_to_static
(
net
,
use_prim
)
...
...
@@ -103,12 +103,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self
.
x
.
stop_gradient
=
False
def
train
(
self
,
use_prim
):
core
.
set_prim_backward
(
True
)
paddle
.
seed
(
2022
)
net
=
PrimeNet
()
sgd
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.1
,
parameters
=
net
.
parameters
()
)
core
.
_set_prim_all_enabled
(
use_prim
)
if
use_prim
:
net
=
apply_to_static
(
net
,
use_prim
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py
浏览文件 @
23d20e30
...
...
@@ -427,10 +427,10 @@ class TestResnet(unittest.TestCase):
)
self
.
verify_predict
()
def
test_resnet_composite
(
self
):
core
.
set_prim
_enabled
(
True
)
def
test_resnet_composite
_backward
(
self
):
core
.
_set_prim_backward
_enabled
(
True
)
static_loss
=
self
.
train
(
to_static
=
True
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
dygraph_loss
=
self
.
train
(
to_static
=
True
)
np
.
testing
.
assert_allclose
(
static_loss
,
...
...
@@ -440,65 +440,13 @@ class TestResnet(unittest.TestCase):
static_loss
,
dygraph_loss
),
)
core
.
set_prim_enabled
(
False
)
def
test_in_static_mode_mkldnn
(
self
):
fluid
.
set_flags
({
'FLAGS_use_mkldnn'
:
True
})
try
:
if
paddle
.
fluid
.
core
.
is_compiled_with_mkldnn
():
self
.
resnet_helper
.
train
(
to_static
=
True
)
finally
:
fluid
.
set_flags
({
'FLAGS_use_mkldnn'
:
False
})
class
TestResnetPrim
(
unittest
.
TestCase
):
"test prim forward + prim backward + to_static"
def
setUp
(
self
):
self
.
resnet_helper
=
ResNetHelper
()
def
train
(
self
,
to_static
):
paddle
.
jit
.
enable_to_static
(
to_static
)
return
self
.
resnet_helper
.
train
(
to_static
)
def
verify_predict
(
self
):
image
=
np
.
random
.
random
([
1
,
3
,
224
,
224
]).
astype
(
'float32'
)
dy_pre
=
self
.
resnet_helper
.
predict_dygraph
(
image
)
st_pre
=
self
.
resnet_helper
.
predict_static
(
image
)
dy_jit_pre
=
self
.
resnet_helper
.
predict_dygraph_jit
(
image
)
predictor_pre
=
self
.
resnet_helper
.
predict_analysis_inference
(
image
)
np
.
testing
.
assert_allclose
(
dy_pre
,
st_pre
,
rtol
=
1e-05
,
err_msg
=
'dy_pre:
\n
{}
\n
, st_pre:
\n
{}.'
.
format
(
dy_pre
,
st_pre
),
)
np
.
testing
.
assert_allclose
(
dy_jit_pre
,
st_pre
,
rtol
=
1e-05
,
err_msg
=
'dy_jit_pre:
\n
{}
\n
, st_pre:
\n
{}.'
.
format
(
dy_jit_pre
,
st_pre
),
)
np
.
testing
.
assert_allclose
(
predictor_pre
,
st_pre
,
rtol
=
1e-05
,
err_msg
=
'predictor_pre:
\n
{}
\n
, st_pre:
\n
{}.'
.
format
(
predictor_pre
,
st_pre
),
)
def
test_resnet_composite
(
self
):
def
test_resnet_composite_forward_backward
(
self
):
plat
=
platform
.
system
()
if
plat
==
"Linux"
:
print
(
"=================== origin resnet ==================="
)
core
.
set_prim_enabled
(
False
)
core
.
_set_prim_all_enabled
(
True
)
static_loss
=
self
.
train
(
to_static
=
True
)
print
(
"======= resnet with prim forward and backward ======="
)
core
.
set_prim_enabled
(
True
)
core
.
set_prim_forward
(
"debug"
)
core
.
_set_prim_all_enabled
(
False
)
dygraph_loss
=
self
.
train
(
to_static
=
True
)
np
.
testing
.
assert_allclose
(
static_loss
,
...
...
@@ -508,10 +456,17 @@ class TestResnetPrim(unittest.TestCase):
static_loss
,
dygraph_loss
),
)
core
.
set_prim_enabled
(
False
)
else
:
pass
def
test_in_static_mode_mkldnn
(
self
):
fluid
.
set_flags
({
'FLAGS_use_mkldnn'
:
True
})
try
:
if
paddle
.
fluid
.
core
.
is_compiled_with_mkldnn
():
self
.
resnet_helper
.
train
(
to_static
=
True
)
finally
:
fluid
.
set_flags
({
'FLAGS_use_mkldnn'
:
False
})
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py
浏览文件 @
23d20e30
...
...
@@ -130,9 +130,9 @@ class TestResnet(unittest.TestCase):
)
def
test_resnet_composite
(
self
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
static_loss
=
self
.
train
(
to_static
=
True
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
dygraph_loss
=
self
.
train
(
to_static
=
False
)
np
.
testing
.
assert_allclose
(
static_loss
,
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py
浏览文件 @
23d20e30
...
...
@@ -137,9 +137,9 @@ class TestResnet(unittest.TestCase):
def
test_resnet_composite
(
self
):
if
fluid
.
is_compiled_with_cuda
():
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
static_loss
=
self
.
train
(
to_static
=
True
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
dygraph_loss
=
self
.
train
(
to_static
=
False
)
# NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here.
np
.
testing
.
assert_allclose
(
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py
浏览文件 @
23d20e30
...
...
@@ -426,9 +426,9 @@ class TestResnet(unittest.TestCase):
self
.
verify_predict
()
def
test_resnet_composite
(
self
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
static_loss
=
self
.
train
(
to_static
=
True
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
dygraph_loss
=
self
.
train
(
to_static
=
False
)
np
.
testing
.
assert_allclose
(
static_loss
,
...
...
python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt
浏览文件 @
23d20e30
...
...
@@ -9,3 +9,4 @@ foreach(TEST_OP ${TEST_OPS})
endforeach
()
add_subdirectory
(
vjp
)
add_subdirectory
(
flags
)
python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt
0 → 100644
浏览文件 @
23d20e30
file
(
GLOB TEST_OPS
RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"test_*.py"
)
string
(
REPLACE
".py"
""
TEST_OPS
"
${
TEST_OPS
}
"
)
foreach
(
TEST_OP
${
TEST_OPS
}
)
py_test_modules
(
${
TEST_OP
}
MODULES
${
TEST_OP
}
ENVS
${
GC_ENVS
}
)
endforeach
()
python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py
0 → 100644
浏览文件 @
23d20e30
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
unittest
from
paddle.fluid
import
core
class
TestPrimFlags
(
unittest
.
TestCase
):
def
test_prim_flags
(
self
):
self
.
assertFalse
(
core
.
_is_bwd_prim_enabled
())
self
.
assertFalse
(
core
.
_is_fwd_prim_enabled
())
os
.
environ
[
'FLAGS_prim_backward'
]
=
"True"
core
.
check_and_set_prim_all_enabled
()
self
.
assertTrue
(
core
.
_is_bwd_prim_enabled
())
os
.
environ
[
'FLAGS_prim_forward'
]
=
"True"
core
.
check_and_set_prim_all_enabled
()
self
.
assertTrue
(
core
.
_is_fwd_prim_enabled
())
os
.
environ
[
'FLAGS_prim_all'
]
=
"False"
core
.
check_and_set_prim_all_enabled
()
self
.
assertFalse
(
core
.
_is_bwd_prim_enabled
())
self
.
assertFalse
(
core
.
_is_fwd_prim_enabled
())
os
.
environ
[
'FLAGS_prim_all'
]
=
"True"
core
.
check_and_set_prim_all_enabled
()
self
.
assertTrue
(
core
.
_is_bwd_prim_enabled
())
self
.
assertTrue
(
core
.
_is_fwd_prim_enabled
())
del
os
.
environ
[
'FLAGS_prim_all'
]
os
.
environ
[
'FLAGS_prim_backward'
]
=
"False"
core
.
check_and_set_prim_all_enabled
()
self
.
assertFalse
(
core
.
_is_bwd_prim_enabled
())
os
.
environ
[
'FLAGS_prim_forward'
]
=
"False"
core
.
check_and_set_prim_all_enabled
()
self
.
assertFalse
(
core
.
_is_fwd_prim_enabled
())
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py
浏览文件 @
23d20e30
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
def
test_tanh_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
return
res
[
0
].
numpy
(),
res
[
1
].
numpy
()
def
desired
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py
浏览文件 @
23d20e30
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
def
test_tanh_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
return
res
[
0
].
numpy
(),
res
[
1
].
numpy
()
def
desired
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py
浏览文件 @
23d20e30
...
...
@@ -32,14 +32,14 @@ from paddle.fluid import core
class
TestExpGradComp
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
cls
.
primal
=
cls
.
primal
.
astype
(
cls
.
dtype
)
if
cls
.
cotangent
is
not
None
:
cls
.
cotangent
=
cls
.
cotangent
.
astype
(
cls
.
dtype
)
@
classmethod
def
tearDownClass
(
cls
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
def
test_exp_grad_comp
(
self
):
def
actual
(
primal
,
cotangent
):
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py
浏览文件 @
23d20e30
...
...
@@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
def
test_comp
(
self
):
def
func
(
primal
,
cotangent
,
shape
):
...
...
@@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase):
]
def
actual
(
primal
,
cotangent
,
shape
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
return
func
(
primal
,
cotangent
,
shape
)
def
desired
(
primal
,
cotangent
,
shape
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
return
func
(
primal
,
cotangent
,
shape
)
np
.
testing
.
assert_allclose
(
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py
浏览文件 @
23d20e30
...
...
@@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase):
return
[
g
for
g
in
grads
if
g
is
not
None
]
def
test_comp
(
self
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
actual
=
self
.
vjp
()
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
desired
=
self
.
vjp
()
for
i
,
j
in
zip
(
actual
,
desired
):
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py
浏览文件 @
23d20e30
...
...
@@ -22,7 +22,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -63,7 +63,7 @@ class TestSqrtGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py
浏览文件 @
23d20e30
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
def
test_tanh_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
return
res
[
0
].
numpy
(),
res
[
1
].
numpy
()
def
desired
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal0
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
to_tensor
(
primal1
,
dtype
=
'float32'
,
stop_gradient
=
False
)
...
...
@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py
浏览文件 @
23d20e30
...
...
@@ -21,7 +21,7 @@ from paddle.fluid import core
def
actual
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
x
=
paddle
.
to_tensor
(
primal
,
dtype
=
'float32'
,
stop_gradient
=
False
)
v
=
paddle
.
to_tensor
(
cotangent
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
sum
(
x
,
axis
=
axis
,
keepdim
=
keep_dim
)
...
...
@@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim):
def
desired
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
x
=
paddle
.
to_tensor
(
primal
,
dtype
=
'float32'
,
stop_gradient
=
False
)
v
=
paddle
.
to_tensor
(
cotangent
,
dtype
=
'float32'
,
stop_gradient
=
False
)
y
=
paddle
.
sum
(
x
,
axis
=
axis
,
keepdim
=
keep_dim
)
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py
浏览文件 @
23d20e30
...
...
@@ -20,7 +20,7 @@ import parameterized as param
import
paddle
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
@
param
.
parameterized_class
(
...
...
@@ -74,7 +74,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
浏览文件 @
23d20e30
...
...
@@ -81,7 +81,7 @@ class TestAddGradComp(unittest.TestCase):
self
.
x
.
stop_gradient
=
False
self
.
y
.
stop_gradient
=
False
net
=
PrimeNet
()
core
.
set_prim
_enabled
(
use_prim
)
core
.
_set_prim_backward
_enabled
(
use_prim
)
net
=
apply_to_static
(
net
,
use_cinn
)
out
=
net
(
self
.
x
,
self
.
y
)
res
=
paddle
.
autograd
.
grad
(
out
,
[
self
.
x
,
self
.
y
])
...
...
@@ -104,7 +104,7 @@ class TestAddGradComp(unittest.TestCase):
def
test_tanh_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal0'
,
primal0
.
shape
,
primal0
.
dtype
)
...
...
@@ -126,7 +126,7 @@ class TestAddGradComp(unittest.TestCase):
return
out
[
0
],
out
[
1
]
def
desired
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
...
...
@@ -167,7 +167,7 @@ class TestAddGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
浏览文件 @
23d20e30
...
...
@@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase):
self
.
x
.
stop_gradient
=
False
self
.
y
.
stop_gradient
=
False
net
=
PrimeNet
()
core
.
set_prim
_enabled
(
use_prim
)
core
.
_set_prim_backward
_enabled
(
use_prim
)
net
=
apply_to_static
(
net
,
use_cinn
)
out
=
net
(
self
.
x
,
self
.
y
)
res
=
paddle
.
autograd
.
grad
(
out
,
[
self
.
x
,
self
.
y
])
...
...
@@ -107,7 +107,7 @@ class TestDivGradComp(unittest.TestCase):
paddle
.
enable_static
()
def
actual
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal0'
,
primal0
.
shape
,
primal0
.
dtype
)
...
...
@@ -130,7 +130,7 @@ class TestDivGradComp(unittest.TestCase):
return
out
[
0
],
out
[
1
]
def
desired
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
...
...
@@ -172,7 +172,7 @@ class TestDivGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
paddle
.
disable_static
()
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
浏览文件 @
23d20e30
...
...
@@ -81,7 +81,7 @@ class TestDivGradComp(unittest.TestCase):
self
.
x
.
stop_gradient
=
False
self
.
y
.
stop_gradient
=
False
net
=
PrimeNet
()
core
.
set_prim
_enabled
(
use_prim
)
core
.
_set_prim_backward
_enabled
(
use_prim
)
net
=
apply_to_static
(
net
,
use_cinn
)
out
=
net
(
self
.
x
,
self
.
y
)
res
=
paddle
.
autograd
.
grad
(
out
,
[
self
.
x
,
self
.
y
])
...
...
@@ -104,7 +104,7 @@ class TestDivGradComp(unittest.TestCase):
def
test_tanh_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal0'
,
primal0
.
shape
,
primal0
.
dtype
)
...
...
@@ -126,7 +126,7 @@ class TestDivGradComp(unittest.TestCase):
return
out
[
0
],
out
[
1
]
def
desired
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
...
...
@@ -167,7 +167,7 @@ class TestDivGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py
浏览文件 @
23d20e30
...
...
@@ -33,14 +33,14 @@ from paddle.fluid import core
class
TestExpGradComp
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
cls
.
primal
=
cls
.
primal
.
astype
(
cls
.
dtype
)
if
cls
.
cotangent
is
not
None
:
cls
.
cotangent
=
cls
.
cotangent
.
astype
(
cls
.
dtype
)
@
classmethod
def
tearDownClass
(
cls
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
def
setUp
(
self
):
paddle
.
enable_static
()
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py
浏览文件 @
23d20e30
...
...
@@ -71,7 +71,7 @@ class TestExpandGradComp(unittest.TestCase):
@
classmethod
def
tearDownClass
(
cls
):
paddle
.
disable_static
()
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
def
test_comp
(
self
):
def
func
(
primal
,
cotangent
,
shape
):
...
...
@@ -93,11 +93,11 @@ class TestExpandGradComp(unittest.TestCase):
)[
0
]
def
actual
(
primal
,
cotangent
,
shape
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
return
func
(
primal
,
cotangent
,
shape
)
def
desired
(
primal
,
cotangent
,
shape
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
return
func
(
primal
,
cotangent
,
shape
)
np
.
testing
.
assert_allclose
(
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py
浏览文件 @
23d20e30
...
...
@@ -108,10 +108,10 @@ class TestMultiplyGradComp(unittest.TestCase):
def
test_comp
(
self
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
actual
=
self
.
vjp
()
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
desired
=
self
.
vjp
()
self
.
assertEqual
(
len
(
actual
),
len
(
desired
))
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
浏览文件 @
23d20e30
...
...
@@ -16,7 +16,7 @@ import unittest
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
import
autograd
import
autograd.numpy
...
...
@@ -60,7 +60,7 @@ class TestSqrtGradComp(unittest.TestCase):
self
.
x
=
paddle
.
randn
([
2
,
4
])
self
.
x
.
stop_gradient
=
False
net
=
PrimeNet
()
core
.
set_prim
_enabled
(
use_prim
)
core
.
_set_prim_backward
_enabled
(
use_prim
)
net
=
apply_to_static
(
net
,
use_cinn
)
out
=
net
(
self
.
x
)
res
=
paddle
.
autograd
.
grad
(
out
,
[
self
.
x
])
...
...
@@ -109,7 +109,7 @@ class TestSqrtGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
浏览文件 @
23d20e30
...
...
@@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase):
self
.
x
.
stop_gradient
=
False
self
.
y
.
stop_gradient
=
False
net
=
PrimeNet
()
core
.
set_prim
_enabled
(
use_prim
)
core
.
_set_prim_backward
_enabled
(
use_prim
)
net
=
apply_to_static
(
net
,
use_cinn
)
out
=
net
(
self
.
x
,
self
.
y
)
res
=
paddle
.
autograd
.
grad
(
out
,
[
self
.
x
,
self
.
y
])
...
...
@@ -105,7 +105,7 @@ class TestDivGradComp(unittest.TestCase):
def
test_tanh_grad_comp
(
self
):
def
actual
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal0'
,
primal0
.
shape
,
primal0
.
dtype
)
...
...
@@ -127,7 +127,7 @@ class TestDivGradComp(unittest.TestCase):
return
out
[
0
],
out
[
1
]
def
desired
(
primal0
,
primal1
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
...
...
@@ -168,7 +168,7 @@ class TestDivGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
浏览文件 @
23d20e30
...
...
@@ -21,7 +21,7 @@ from paddle.fluid import core
def
actual
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal'
,
primal
.
shape
,
primal
.
dtype
)
...
...
@@ -40,7 +40,7 @@ def actual(primal, cotangent, axis, keep_dim):
def
desired
(
primal
,
cotangent
,
axis
,
keep_dim
):
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal'
,
primal
.
shape
,
primal
.
dtype
)
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
浏览文件 @
23d20e30
...
...
@@ -16,7 +16,7 @@ import unittest
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
import
autograd
import
autograd.numpy
...
...
@@ -60,7 +60,7 @@ class TestTanhGradComp(unittest.TestCase):
self
.
x
=
paddle
.
randn
([
2
,
4
])
self
.
x
.
stop_gradient
=
False
net
=
PrimeNet
()
core
.
set_prim
_enabled
(
use_prim
)
core
.
_set_prim_backward
_enabled
(
use_prim
)
net
=
apply_to_static
(
net
,
use_cinn
)
out
=
net
(
self
.
x
)
res
=
paddle
.
autograd
.
grad
(
out
,
[
self
.
x
])
...
...
@@ -109,7 +109,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol
=
1e-6
,
atol
=
0
,
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py
浏览文件 @
23d20e30
...
...
@@ -17,7 +17,7 @@ import unittest
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
import
parameterized
as
param
...
...
python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py
浏览文件 @
23d20e30
...
...
@@ -17,7 +17,7 @@ import unittest
from
paddle.fluid
import
core
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
import
parameterized
as
param
...
...
@@ -77,7 +77,7 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
)
print
(
actual
)
self
.
assertEquals
(
actual
,
self
.
desired_ops
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py
浏览文件 @
23d20e30
...
...
@@ -135,9 +135,9 @@ class TestResnet50Accuracy(unittest.TestCase):
loop_num
=
10
feed
=
self
.
generate_random_data
(
loop_num
)
core
.
set_prim
_enabled
(
True
)
core
.
_set_prim_backward
_enabled
(
True
)
loss_c
=
self
.
train
(
place
,
loop_num
,
feed
,
use_cinn
=
True
)
core
.
set_prim
_enabled
(
False
)
core
.
_set_prim_backward
_enabled
(
False
)
loss_p
=
self
.
train
(
place
,
loop_num
,
feed
,
use_cinn
=
True
)
print
(
"Losses of Composite + CINN:"
)
print
(
loss_c
)
...
...
python/paddle/incubate/autograd/primapi.py
浏览文件 @
23d20e30
...
...
@@ -218,7 +218,7 @@ def grad(outputs, inputs, grad_outputs=None):
@
framework
.
static_only
def
to_prim
(
blocks
):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
if
not
core
.
enable_prim_forwar
d
():
if
not
core
.
_is_fwd_prim_enable
d
():
return
if
isinstance
(
blocks
,
paddle
.
fluid
.
framework
.
Block
):
logging
.
info
(
"Atomize composite op to primitive ops begin."
)
...
...
@@ -235,5 +235,6 @@ def to_prim(blocks):
f
"Expect block or sequence of blocks, but got
{
type
(
blocks
)
}
."
)
with
framework
.
program_guard
(
main_program
):
print
(
"Running lowering for forward..."
)
primx
.
_lower_composite
(
blocks
)
return
python/paddle/jit/dy2static/partial_program.py
浏览文件 @
23d20e30
...
...
@@ -571,12 +571,9 @@ class PartialProgramLayer:
targets
.
append
(
program
.
global_block
().
var
(
out
.
name
))
if
targets
:
enable_prim
=
self
.
_build_strategy
.
build_cinn_pass
if
enable_prim
and
core
.
enable_prim_backward
():
core
.
set_prim_enabled
(
True
)
backward
.
gradients
(
targets
=
targets
,
inputs
=
[])
core
.
set_prim_enabled
(
False
)
else
:
if
self
.
_build_strategy
.
build_cinn_pass
:
# TODO(Jiabin): Change this to True if we need this to be default option
core
.
check_and_set_prim_all_enabled
()
backward
.
gradients
(
targets
=
targets
,
inputs
=
[])
start_idx
=
len
(
main_program
.
block
(
0
).
ops
)
+
2
*
len
(
...
...
python/paddle/jit/dy2static/program_translator.py
浏览文件 @
23d20e30
...
...
@@ -1092,8 +1092,9 @@ class ProgramCache:
def
_build_once
(
self
,
cache_key
):
# TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim
enable_prim
=
cache_key
.
kwargs
[
'build_strategy'
].
build_cinn_pass
if
enable_prim
and
core
.
enable_prim_backward
():
core
.
set_prim_enabled
(
True
)
if
enable_prim
:
# TODO(Jiabin): Change this to True if we need this to be default option
core
.
check_and_set_prim_all_enabled
()
concrete_program
=
ConcreteProgram
.
from_func_spec
(
func_spec
=
cache_key
.
function_spec
,
...
...
@@ -1103,9 +1104,7 @@ class ProgramCache:
**
cache_key
.
kwargs
)
if
enable_prim
or
core
.
enable_prim_forward
()
==
"debug"
:
concrete_program
.
_to_prim
()
core
.
set_prim_enabled
(
False
)
return
concrete_program
,
partial_program_from
(
concrete_program
)
def
__getitem__
(
self
,
item
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录