BaiXuePrincess / Paddle · forked from PaddlePaddle / Paddle
Commit 23d20e30 (unverified)
【Prim】Refactor prim flags system (#49930)
Author: Jiabin Yang · committed via GitHub · Jan 20, 2023
Parent: 44855da3

Showing 49 changed files with 339 additions and 206 deletions (+339 −206)
Changed files:

  paddle/fluid/eager/auto_code_generator/generator/eager_gen.py  +2 −2
  paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc  +11 −1
  paddle/fluid/prim/api/manual/backward/composite_backward_api.h  +1 −1
  paddle/fluid/prim/tests/test_eager_prim.cc  +8 −8
  paddle/fluid/prim/tests/test_static_prim.cc  +4 −4
  paddle/fluid/prim/utils/static/static_global_utils.cc  +2 −1
  paddle/fluid/prim/utils/static/static_global_utils.h  +13 −3
  paddle/fluid/prim/utils/utils.cc  +16 −4
  paddle/fluid/prim/utils/utils.h  +5 −2
  paddle/fluid/pybind/pybind.cc  +12 −3
  paddle/phi/api/yaml/legacy_backward.yaml  +3 −3
  python/paddle/fluid/backward.py  +3 −2
  python/paddle/fluid/core.py  +96 −28
  python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py  +3 −0
  python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py  +3 −1
  python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py  +2 −2
  python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py  +2 −2
  python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py  +14 −59
  python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py  +2 −2
  python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py  +2 −2
  python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt  +1 −0
  python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt  +9 −0
  python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py  +52 −0
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py  +4 −4
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py  +4 −4
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py  +3 −3
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py  +4 −4
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py  +4 −4
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py  +4 −4
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py  +4 −4
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py  +3 −3
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py  +3 −3
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py  +4 −4
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py  +2 −2
  python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py  +3 −3
  python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py  +1 −1
  python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py  +2 −2
  python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py  +2 −2
  python/paddle/incubate/autograd/primapi.py  +2 −1
  python/paddle/jit/dy2static/partial_program.py  +4 −7
  python/paddle/jit/dy2static/program_translator.py  +4 −5
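The heart of the refactor is visible in the file list: the single prim switch becomes separate forward and backward switches, wired from StaticCompositeContext in C++ up through pybind and core.py. A minimal sketch of the resulting Python-side control surface, pieced together from the hunks below (a summary for orientation, not official API documentation):

    from paddle.fluid import core

    # Backward prim: composite (primitive-op) gradient rules.
    core._set_prim_backward_enabled(True)
    assert core._is_bwd_prim_enabled()

    # Forward prim: composite forward rules.
    core._set_prim_forward_enabled(True)
    assert core._is_fwd_prim_enabled()

    # Flip both at once, or sync from the FLAGS_prim_* environment variables.
    core._set_prim_all_enabled(False)
    core.check_and_set_prim_all_enabled()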
paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

@@ -1841,7 +1841,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         if is_composite_grad_api and next_grad_node_creation_str != '':
             next_grad_node_creation_str = f"""
- if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+ if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
   {next_grad_node_creation_str}
  }}
  """

@@ -2261,7 +2261,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         # TODO(Ruting): use composite only when we don't have backward kernel in the future.
         elif is_composite_grad_api:
             grad_function_call_str = f"""
- if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+ if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
   {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
   VLOG(4) << "Composite api {composite_grad_api_name} is called ";
 }}else{{
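Both hunks edit f-string templates, so the IsBwdPrimEnabled guard is baked into the generated C++ source at code-generation time rather than checked by the generator itself. A stripped-down, hypothetical sketch of the templating pattern (creation_body stands in for next_grad_node_creation_str):

    # Hypothetical miniature of the eager_gen.py templating above.
    creation_body = "grad_node = std::make_shared<TanhGradNode>();"
    guarded = f"""
      if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
        {creation_body}
      }}
    """
    print(guarded)  # the guard text lands verbatim in the generated .cc file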
paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <string.h>
 #include <memory>
 #include <sstream>
 #include <string>

@@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape,
       phi::errors::InvalidArgument(
           "We only support float32/float16 for full, but we got data type: %s",
           phi::DataTypeToString(dtype)));
-  op->SetAttr("value", value.to<float>());
+  if (dtype == phi::DataType::FLOAT32) {
+    op->SetAttr("value", value.to<float>());
+  } else if (dtype == phi::DataType::FLOAT64) {
+    op->SetAttr("str_value", std::to_string(value.to<double>()));
+  } else if (dtype == phi::DataType::FLOAT16) {
+    op->SetAttr("str_value", std::to_string(value.to<float>()));
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "We only support float64/float32/float16 for full"));
+  }
   op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
   op->SetOutput(
       "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
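The new dispatch stores float32 fill values through the numeric value attribute, but routes float64 and float16 through the string str_value attribute so the value is not squeezed through a float on the way to the op. A rough Python mirror of that decision logic (a hypothetical helper, not a Paddle API):

    def set_fill_value_attrs(attrs, dtype, value):
        # Illustrative mirror of the C++ branch in full<DescTensor> above.
        if dtype == 'float32':
            attrs['value'] = float(value)            # plain numeric attribute
        elif dtype in ('float64', 'float16'):
            attrs['str_value'] = str(float(value))   # string form avoids a float cast
        else:
            raise NotImplementedError(
                'We only support float64/float32/float16 for full')

    attrs = {}
    set_fill_value_attrs(attrs, 'float64', 1.0)
    assert attrs == {'str_value': '1.0'}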
paddle/fluid/prim/api/manual/backward/composite_backward_api.h

@@ -192,7 +192,7 @@ void divide_grad(const Tensor& x,
   }  // indicate we will compute dy
   if (dx) {
     // dx = (1/y) * dout
-    auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
+    auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
     auto tmp0 = divide<T>(one_tensor, y);
     auto dx_res = multiply<T>(tmp0, out_grad);
     if (y.dims() != x.dims()) {
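The composite rule computes dx = (1/y) * dout; the one-line fix makes the intermediate ones tensor inherit y's dtype instead of the default. A minimal numpy sketch of the same math (an illustration, not the Paddle kernel):

    import numpy as np

    def divide_grad_dx(y, out_grad):
        # dx for z = x / y is (1/y) * dout; the ones tensor follows y's
        # dtype, matching the C++ change above.
        one_tensor = np.full(y.shape, 1.0, dtype=y.dtype)
        return (one_tensor / y) * out_grad

    y = np.array([2.0, 4.0], dtype=np.float16)
    assert divide_grad_dx(y, np.ones_like(y)).dtype == np.float16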
paddle/fluid/prim/tests/test_eager_prim.cc

@@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) {
   paddle::experimental::Tensor out0 = tanh_ad_func(tensor0);
   std::vector<paddle::experimental::Tensor> outs0 = {out0};
   // Disable prim
-  PrimCommonUtils::SetPrimEnabled(false);
-  ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(false);
+  ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
   // 4. Run Backward
   egr::Backward(outs0, {}, false);
   paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
   std::vector<paddle::experimental::Tensor> outs1 = {out1};
   // Enable prim
-  PrimCommonUtils::SetPrimEnabled(true);
-  ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(true);
+  ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
   // 4. Run Backward
   ::egr::Backward(outs1, {}, false);
   VLOG(7)

@@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) {
 }

 TEST(EagerPrim, TestFlags) {
-  PrimCommonUtils::SetPrimEnabled(true);
-  ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
-  PrimCommonUtils::SetPrimEnabled(false);
-  ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(true);
+  ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(false);
+  ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
 }

 }  // namespace prim
paddle/fluid/prim/tests/test_static_prim.cc

@@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
 }

 TEST(StaticPrim, TestFlags) {
-  PrimCommonUtils::SetPrimEnabled(true);
-  ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
-  PrimCommonUtils::SetPrimEnabled(false);
-  ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(true);
+  ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(false);
+  ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
 }

 }  // namespace prim
paddle/fluid/prim/utils/static/static_global_utils.cc

@@ -18,6 +18,7 @@ namespace paddle {
 namespace prim {
 StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
     new StaticCompositeContext();
-thread_local bool StaticCompositeContext::enable_prim_ = false;
+thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
+thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
 }  // namespace prim
 }  // namespace paddle
paddle/fluid/prim/utils/static/static_global_utils.h

@@ -56,9 +56,18 @@ class StaticCompositeContext {
     return generator_->Generate(key);
   }

-  void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; }
+  void SetBwdPrimEnabled(bool enable_prim) { enable_bwd_prim_ = enable_prim; }

-  bool IsPrimEnabled() { return enable_prim_; }
+  bool IsBwdPrimEnabled() { return enable_bwd_prim_; }
+
+  void SetFwdPrimEnabled(bool enable_prim) { enable_fwd_prim_ = enable_prim; }
+
+  bool IsFwdPrimEnabled() { return enable_fwd_prim_; }
+
+  void SetAllPrimEnabled(bool enable_prim) {
+    enable_fwd_prim_ = enable_prim;
+    enable_bwd_prim_ = enable_prim;
+  }

  private:
   StaticCompositeContext()

@@ -66,7 +75,8 @@ class StaticCompositeContext {
   framework::BlockDesc* current_block_desc_;
   std::unique_ptr<UniqueNameGenerator> generator_;
-  static thread_local bool enable_prim_;
+  static thread_local bool enable_bwd_prim_;
+  static thread_local bool enable_fwd_prim_;
   static StaticCompositeContext* static_composite_context_;
   DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
 };
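Both switches are thread-local, so threads can toggle prim independently while sharing the singleton context. A rough Python analogue of that design using threading.local (an illustration of the pattern, not Paddle code):

    import threading

    class _PrimFlags(threading.local):
        # Per-thread forward/backward prim switches, both off by default,
        # mirroring enable_fwd_prim_/enable_bwd_prim_ above.
        def __init__(self):
            self.enable_fwd_prim = False
            self.enable_bwd_prim = False

    _flags = _PrimFlags()

    def set_all_prim_enabled(enable):
        # Mirrors StaticCompositeContext::SetAllPrimEnabled: flip both.
        _flags.enable_fwd_prim = enable
        _flags.enable_bwd_prim = enable

    set_all_prim_enabled(True)
    assert _flags.enable_fwd_prim and _flags.enable_bwd_prim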
paddle/fluid/prim/utils/utils.cc

@@ -19,12 +19,24 @@
 PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
 namespace paddle {
 namespace prim {
-bool PrimCommonUtils::IsPrimEnabled() {
-  return StaticCompositeContext::Instance().IsPrimEnabled();
+bool PrimCommonUtils::IsBwdPrimEnabled() {
+  return StaticCompositeContext::Instance().IsBwdPrimEnabled();
 }

-void PrimCommonUtils::SetPrimEnabled(bool enable_prim) {
-  return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim);
+void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
+  return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
+}
+
+bool PrimCommonUtils::IsFwdPrimEnabled() {
+  return StaticCompositeContext::Instance().IsFwdPrimEnabled();
+}
+
+void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
+  return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
+}
+
+void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
+  return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
 }
 }  // namespace prim
 }  // namespace paddle
paddle/fluid/prim/utils/utils.h

@@ -18,8 +18,11 @@ namespace paddle {
 namespace prim {
 class PrimCommonUtils {
  public:
-  static bool IsPrimEnabled();
-  static void SetPrimEnabled(bool enabled);
+  static bool IsBwdPrimEnabled();
+  static void SetBwdPrimEnabled(bool enabled);
+  static bool IsFwdPrimEnabled();
+  static void SetFwdPrimEnabled(bool enabled);
+  static void SetAllPrimEnabled(bool enabled);
 };
 }  // namespace prim
 }  // namespace paddle
paddle/fluid/pybind/pybind.cc

@@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) {
     return oss.str();
   });
-  m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled);
-  m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled);
+  m.def("__set_bwd_prim_enabled",
+        &paddle::prim::PrimCommonUtils::SetBwdPrimEnabled);
+  m.def("_is_bwd_prim_enabled",
+        &paddle::prim::PrimCommonUtils::IsBwdPrimEnabled);
+  m.def("__set_fwd_prim_enabled",
+        &paddle::prim::PrimCommonUtils::SetFwdPrimEnabled);
+  m.def("_is_fwd_prim_enabled",
+        &paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
+  m.def("__set_all_prim_enabled",
+        &paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
   m.def("set_num_threads", &platform::SetNumThreads);
   m.def("disable_signal_handler", &DisableSignalHandler);

@@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle.
   // priority of GradCompOpMaker is less than GradCompMaker for better
   // performance.
   std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
-  if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
+  if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
     if (grad_comp_op_maker != nullptr) {
+      VLOG(3) << "Running composite fun for " << op_desc.Type();
       grad_op_descs = grad_comp_op_maker(op_desc,
                                          no_grad_set,
                                          &grad_to_var,
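Note the naming split in the bindings: the raw setters get a double leading underscore, which keeps them out of star-imports and marks them as internal (core.py wraps them as _set_prim_*_enabled below), while the read-only queries keep a single underscore and are called directly by the new tests:

    from paddle.fluid import core

    # Queries exposed straight from pybind.cc; setters go through the
    # core._set_prim_*_enabled wrappers instead.
    print(core._is_fwd_prim_enabled(), core._is_bwd_prim_enabled())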
paddle/phi/api/yaml/legacy_backward.yaml

@@ -42,7 +42,7 @@
   kernel :
     func : add_grad
   no_need_buffer : x, y
-  composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
+  composite : add_grad(x, y, out_grad, axis)
   backward : add_double_grad
   inplace : (out_grad -> x_grad)

@@ -390,7 +390,7 @@
     param : [x, y]
   kernel :
     func : divide_grad
-  composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1)
+  composite : divide_grad(x, y, out, out_grad, -1)
   backward : divide_double_grad

 - backward_op : dropout_grad

@@ -1319,7 +1319,7 @@
   kernel :
     func : subtract_grad
   no_need_buffer : x, y
-  composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
+  composite : subtract_grad(x, y, out_grad, axis)
   backward : subtract_double_grad
   inplace : (out_grad -> x_grad)
python/paddle/fluid/backward.py

@@ -1493,14 +1493,15 @@ def _append_backward_ops_(
         # remove some backward ops
         # TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
-        if not core.is_prim_enabled():
+        if not core._is_bwd_prim_enabled():
             not_need_ops = _find_not_need_ops(
                 grad_op_descs, ops, input_grad_names_set
             )
             grad_op_descs = [
                 op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
             ]
+        else:
+            logging.debug("Running backward composite and disable find_not_need_ops")

     # append op_desc in grad_op_descs to target_block
     op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
python/paddle/fluid/core.py

@@ -17,6 +17,7 @@ import sys
 import os
 import warnings
 import platform
+import logging

 has_paddle_dy_lib = False

@@ -305,8 +306,13 @@ try:
         from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
         from .libpaddle import _set_current_stream
         from .libpaddle import _get_phi_kernel_name
-        from .libpaddle import set_prim_enabled
-        from .libpaddle import is_prim_enabled
+        # prim controller flags
+        from .libpaddle import __set_bwd_prim_enabled
+        from .libpaddle import _is_bwd_prim_enabled
+        from .libpaddle import __set_fwd_prim_enabled
+        from .libpaddle import _is_fwd_prim_enabled
+        from .libpaddle import __set_all_prim_enabled
         if sys.platform != 'win32':
             from .libpaddle import _set_process_pids

@@ -373,36 +379,98 @@ def set_paddle_lib_path():
 set_paddle_lib_path()

-def set_prim_forward(value):
-    """set flag FLAGS_prim_forward."""
-    flag = str(value)
-    if flag.lower() not in ["true", "false", "debug"]:
-        raise TypeError(f"flag {flag} should be string of bool or 'debug'.")
-    os.environ["FLAGS_prim_forward"] = flag
-    return
-
-def enable_prim_forward():
-    flag = os.getenv("FLAGS_prim_forward", "true").lower()
-    if flag == "false":
-        return False
-    if flag == "debug":
-        return "debug"
-    return True
-
-def set_prim_backward(value):
-    """set flag FLAGS_prim_backward."""
-    flag = str(value)
-    if flag.lower() not in ["true", "false"]:
-        raise TypeError(f"flag {flag} should be bool or string of bool.")
-    os.environ["FLAGS_prim_backward"] = flag
-    return
-
-def enable_prim_backward():
-    flag = os.getenv("FLAGS_prim_backward", "true")
-    if flag.lower() == "false":
-        return False
-    return True
+# We have 3 FLAGS to judge whether prim is enabled:
+# FLAGS_prim_forward: open or close the forward prim strategy
+# FLAGS_prim_backward: open or close the backward prim strategy
+# FLAGS_prim_all: open or close both at once
+#
+# Priorities:
+# if with CINN and Dy2St:
+#   _set_prim_all_enabled > FLAGS_prim_all > check_and_set_prim_all_enabled
+#     == _set_prim_forward_enabled == _set_prim_backward_enabled
+#     > FLAGS_prim_forward == FLAGS_prim_backward
+# else:
+#   _set_prim_all_enabled > FLAGS_prim_all == check_and_set_prim_all_enabled
+#     == _set_prim_forward_enabled == _set_prim_backward_enabled
+#     > FLAGS_prim_forward == FLAGS_prim_backward
+
+def __sync_stat_with_flag(flag):
+    if flag == "FLAGS_prim_forward":
+        flag_value = os.getenv("FLAGS_prim_forward")
+        assert flag_value is not None
+        flag_value = flag_value.lower()
+        if flag_value == "false":
+            __set_fwd_prim_enabled(False)
+        elif flag_value == "true":
+            __set_fwd_prim_enabled(True)
+        else:
+            raise TypeError(f"flag {flag} should be true or false.")
+        logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled()))
+    elif flag == "FLAGS_prim_backward":
+        flag_value = os.getenv("FLAGS_prim_backward")
+        assert flag_value is not None
+        flag_value = flag_value.lower()
+        if flag_value == "false":
+            __set_bwd_prim_enabled(False)
+        elif flag_value == "true":
+            __set_bwd_prim_enabled(True)
+        else:
+            raise TypeError(f"flag {flag} should be true or false.")
+        logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled()))
+    elif flag == "FLAGS_prim_all":
+        flag_value = os.getenv("FLAGS_prim_all")
+        assert flag_value is not None
+        flag_value = flag_value.lower()
+        if flag_value == "false":
+            __set_all_prim_enabled(False)
+        elif flag_value == "true":
+            __set_all_prim_enabled(True)
+        else:
+            raise TypeError(f"flag {flag} should be true or false.")
+        logging.debug(
+            "all prim enabled: %s",
+            bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
+        )
+    else:
+        raise TypeError(
+            f"We only support FLAGS_prim_forward/FLAGS_prim_backward/"
+            f"FLAGS_prim_all but we got {flag}."
+        )
+
+def _set_prim_backward_enabled(value):
+    __set_bwd_prim_enabled(bool(value))
+    logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled()))
+
+def _set_prim_forward_enabled(value):
+    __set_fwd_prim_enabled(bool(value))
+    logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled()))
+
+def _set_prim_all_enabled(value):
+    __set_all_prim_enabled(bool(value))
+    logging.debug(
+        "all prim enabled: %s",
+        bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
+    )
+
+def __sync_prim_backward_status():
+    flag_value = os.getenv("FLAGS_prim_backward")
+    if flag_value is None:
+        logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled()))
+    else:
+        __sync_stat_with_flag("FLAGS_prim_backward")
+
+def __sync_prim_forward_status():
+    flag_value = os.getenv("FLAGS_prim_forward")
+    if flag_value is None:
+        logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled()))
+    else:
+        __sync_stat_with_flag("FLAGS_prim_forward")
+
+def check_and_set_prim_all_enabled():
+    flag_value = os.getenv("FLAGS_prim_all")
+    if flag_value is None:
+        __sync_prim_backward_status()
+        __sync_prim_forward_status()
+    else:
+        __sync_stat_with_flag("FLAGS_prim_all")
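The sync logic gives the in-process setters the last word and lets FLAGS_prim_all override the per-direction variables when it is set. A worked example of the intended behavior, following the new unit test below (assumes a build exposing these helpers):

    import os
    from paddle.fluid import core

    os.environ['FLAGS_prim_backward'] = "True"
    core.check_and_set_prim_all_enabled()
    assert core._is_bwd_prim_enabled()       # picked up from the env var

    os.environ['FLAGS_prim_all'] = "False"
    core.check_and_set_prim_all_enabled()
    assert not core._is_bwd_prim_enabled()   # FLAGS_prim_all overrides both
    assert not core._is_fwd_prim_enabled()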
python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py

@@ -19,6 +19,7 @@ from utils import TOLERANCE
 import paddle
 import paddle.nn.functional as F
+from paddle.fluid import core

 def generate_data(shape, dtype="float32"):

@@ -72,6 +73,7 @@ class TestCompositeSoftmax(unittest.TestCase):
     def cal_composite(self, inputs):
         paddle.enable_static()
+        core._set_prim_forward_enabled(True)
         startup_program = paddle.static.Program()
         main_program = paddle.static.Program()
         with paddle.static.program_guard(main_program, startup_program):

@@ -95,6 +97,7 @@ class TestCompositeSoftmax(unittest.TestCase):
         exe.run(startup_program)
         res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
         paddle.disable_static()
+        core._set_prim_forward_enabled(False)
         return res

     def compare_forward(self):
python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py

@@ -78,6 +78,7 @@ class TestCompositeSoftmax(unittest.TestCase):
     def cal_composite_grad(self, inputs):
         paddle.enable_static()
+        core._set_prim_all_enabled(True)
         startup_program = paddle.static.Program()
         main_program = paddle.static.Program()
         with paddle.static.program_guard(main_program, startup_program):

@@ -108,6 +109,7 @@ class TestCompositeSoftmax(unittest.TestCase):
         exe.run(startup_program)
         res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
         paddle.disable_static()
+        core._set_prim_all_enabled(False)
         return res

     def compare_backward(self):

@@ -139,7 +141,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
     "test composite softmax and prim backward"

     def setUp(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         self.dtypes = ["float32"]
         self.shapes = [[2, 3, 4], [2, 3]]
         self.axes = [-1, 0, 1]
python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py

@@ -236,11 +236,11 @@ class TestBert(unittest.TestCase):
         self.verify_predict()

     def test_train_composite(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         static_loss, static_ppl = self.train_static(
             self.bert_config, self.data_reader
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
         dygraph_loss, dygraph_ppl = self.train_dygraph(
             self.bert_config, self.data_reader
         )
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py

@@ -47,7 +47,6 @@ class TestPrimForward(unittest.TestCase):
     """

     def setUp(self):
-        core.set_prim_backward(False)
         paddle.seed(2022)
         self.x = paddle.randn([2, 4])
         self.x.stop_gradient = False

@@ -58,6 +57,7 @@ class TestPrimForward(unittest.TestCase):
         sgd = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=net.parameters()
         )
+        core._set_prim_forward_enabled(use_prim)
         if use_prim:
             net = apply_to_static(net, use_prim)

@@ -103,12 +103,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         self.x.stop_gradient = False

     def train(self, use_prim):
-        core.set_prim_backward(True)
         paddle.seed(2022)
         net = PrimeNet()
         sgd = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=net.parameters()
         )
+        core._set_prim_all_enabled(use_prim)
         if use_prim:
             net = apply_to_static(net, use_prim)
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py

@@ -427,10 +427,10 @@ class TestResnet(unittest.TestCase):
         )
         self.verify_predict()

-    def test_resnet_composite(self):
-        core.set_prim_enabled(True)
+    def test_resnet_composite_backward(self):
+        core._set_prim_backward_enabled(True)
         static_loss = self.train(to_static=True)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
         dygraph_loss = self.train(to_static=True)
         np.testing.assert_allclose(
             static_loss,

@@ -440,65 +440,13 @@ class TestResnet(unittest.TestCase):
                 static_loss, dygraph_loss
             ),
         )
-        core.set_prim_enabled(False)

-    def test_in_static_mode_mkldnn(self):
-        fluid.set_flags({'FLAGS_use_mkldnn': True})
-        try:
-            if paddle.fluid.core.is_compiled_with_mkldnn():
-                self.resnet_helper.train(to_static=True)
-        finally:
-            fluid.set_flags({'FLAGS_use_mkldnn': False})
-
-
-class TestResnetPrim(unittest.TestCase):
-    "test prim forward + prim backward + to_static"
-
-    def setUp(self):
-        self.resnet_helper = ResNetHelper()
-
-    def train(self, to_static):
-        paddle.jit.enable_to_static(to_static)
-        return self.resnet_helper.train(to_static)
-
-    def verify_predict(self):
-        image = np.random.random([1, 3, 224, 224]).astype('float32')
-        dy_pre = self.resnet_helper.predict_dygraph(image)
-        st_pre = self.resnet_helper.predict_static(image)
-        dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
-        predictor_pre = self.resnet_helper.predict_analysis_inference(image)
-        np.testing.assert_allclose(
-            dy_pre,
-            st_pre,
-            rtol=1e-05,
-            err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre),
-        )
-        np.testing.assert_allclose(
-            dy_jit_pre,
-            st_pre,
-            rtol=1e-05,
-            err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format(dy_jit_pre, st_pre),
-        )
-        np.testing.assert_allclose(
-            predictor_pre,
-            st_pre,
-            rtol=1e-05,
-            err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format(predictor_pre, st_pre),
-        )
-
-    def test_resnet_composite(self):
+    def test_resnet_composite_forward_backward(self):
         plat = platform.system()
         if plat == "Linux":
-            print("=================== origin resnet ===================")
-            core.set_prim_enabled(False)
+            core._set_prim_all_enabled(True)
             static_loss = self.train(to_static=True)
-            print("======= resnet with prim forward and backward =======")
-            core.set_prim_enabled(True)
-            core.set_prim_forward("debug")
+            core._set_prim_all_enabled(False)
             dygraph_loss = self.train(to_static=True)
             np.testing.assert_allclose(
                 static_loss,

@@ -508,10 +456,17 @@ class TestResnet(unittest.TestCase):
                     static_loss, dygraph_loss
                 ),
             )
-            core.set_prim_enabled(False)
         else:
             pass

+    def test_in_static_mode_mkldnn(self):
+        fluid.set_flags({'FLAGS_use_mkldnn': True})
+        try:
+            if paddle.fluid.core.is_compiled_with_mkldnn():
+                self.resnet_helper.train(to_static=True)
+        finally:
+            fluid.set_flags({'FLAGS_use_mkldnn': False})
+

 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py

@@ -130,9 +130,9 @@ class TestResnet(unittest.TestCase):
         )

     def test_resnet_composite(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         static_loss = self.train(to_static=True)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
         dygraph_loss = self.train(to_static=False)
         np.testing.assert_allclose(
             static_loss,
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py

@@ -137,9 +137,9 @@ class TestResnet(unittest.TestCase):
     def test_resnet_composite(self):
         if fluid.is_compiled_with_cuda():
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
             static_loss = self.train(to_static=True)
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
             dygraph_loss = self.train(to_static=False)
         # NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here.
         np.testing.assert_allclose(
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py

@@ -426,9 +426,9 @@ class TestResnet(unittest.TestCase):
         self.verify_predict()

     def test_resnet_composite(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         static_loss = self.train(to_static=True)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
         dygraph_loss = self.train(to_static=False)
         np.testing.assert_allclose(
             static_loss,
python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt

@@ -9,3 +9,4 @@ foreach(TEST_OP ${TEST_OPS})
 endforeach()

 add_subdirectory(vjp)
+add_subdirectory(flags)
python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt (new file, mode 100644)

+file(
+  GLOB TEST_OPS
+  RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
+  "test_*.py")
+string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
+
+foreach(TEST_OP ${TEST_OPS})
+  py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
+endforeach()
python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py (new file, mode 100644)

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from paddle.fluid import core


class TestPrimFlags(unittest.TestCase):
    def test_prim_flags(self):
        self.assertFalse(core._is_bwd_prim_enabled())
        self.assertFalse(core._is_fwd_prim_enabled())

        os.environ['FLAGS_prim_backward'] = "True"
        core.check_and_set_prim_all_enabled()
        self.assertTrue(core._is_bwd_prim_enabled())
        os.environ['FLAGS_prim_forward'] = "True"
        core.check_and_set_prim_all_enabled()
        self.assertTrue(core._is_fwd_prim_enabled())

        os.environ['FLAGS_prim_all'] = "False"
        core.check_and_set_prim_all_enabled()
        self.assertFalse(core._is_bwd_prim_enabled())
        self.assertFalse(core._is_fwd_prim_enabled())

        os.environ['FLAGS_prim_all'] = "True"
        core.check_and_set_prim_all_enabled()
        self.assertTrue(core._is_bwd_prim_enabled())
        self.assertTrue(core._is_fwd_prim_enabled())

        del os.environ['FLAGS_prim_all']
        os.environ['FLAGS_prim_backward'] = "False"
        core.check_and_set_prim_all_enabled()
        self.assertFalse(core._is_bwd_prim_enabled())
        os.environ['FLAGS_prim_forward'] = "False"
        core.check_and_set_prim_all_enabled()
        self.assertFalse(core._is_fwd_prim_enabled())


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py

@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
     def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)

@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
             return res[0].numpy(), res[1].numpy()

         def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
             paddle.disable_static()
             x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
             y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)

@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py

Same four hunks as test_comp_eager_add_grad.py above (@@ -20,7 @@, @@ -67,7 @@, @@ -78,7 @@, @@ -104,7 @@): each core.set_prim_enabled(...) call becomes core._set_prim_backward_enabled(...).
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py

@@ -32,14 +32,14 @@ from paddle.fluid import core
 class TestExpGradComp(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         cls.primal = cls.primal.astype(cls.dtype)
         if cls.cotangent is not None:
             cls.cotangent = cls.cotangent.astype(cls.dtype)

     @classmethod
     def tearDownClass(cls):
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

     def test_exp_grad_comp(self):
         def actual(primal, cotangent):
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py

@@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

     def test_comp(self):
         def func(primal, cotangent, shape):

@@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase):
         ]

         def actual(primal, cotangent, shape):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
             return func(primal, cotangent, shape)

         def desired(primal, cotangent, shape):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
             return func(primal, cotangent, shape)

         np.testing.assert_allclose(
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py

@@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase):
         return [g for g in grads if g is not None]

     def test_comp(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         actual = self.vjp()
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
         desired = self.vjp()

         for i, j in zip(actual, desired):
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py

@@ -22,7 +22,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

@@ -63,7 +63,7 @@ class TestSqrtGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py

Same four hunks as test_comp_eager_add_grad.py above (@@ -20,7 @@, @@ -67,7 @@, @@ -78,7 @@, @@ -104,7 @@): each core.set_prim_enabled(...) call becomes core._set_prim_backward_enabled(...).
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py

@@ -21,7 +21,7 @@ from paddle.fluid import core
 def actual(primal, cotangent, axis, keep_dim):
-    core.set_prim_enabled(False)
+    core._set_prim_backward_enabled(False)
     x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
     v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
     y = paddle.sum(x, axis=axis, keepdim=keep_dim)

@@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim):
 def desired(primal, cotangent, axis, keep_dim):
-    core.set_prim_enabled(True)
+    core._set_prim_backward_enabled(True)
     x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
     v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
     y = paddle.sum(x, axis=axis, keepdim=keep_dim)
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py

@@ -20,7 +20,7 @@ import parameterized as param
 import paddle
 from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

@@ -74,7 +74,7 @@ class TestTanhGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
浏览文件 @
23d20e30
...
@@ -81,7 +81,7 @@ class TestAddGradComp(unittest.TestCase):
         self.x.stop_gradient = False
         self.y.stop_gradient = False
         net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
         net = apply_to_static(net, use_cinn)
         out = net(self.x, self.y)
         res = paddle.autograd.grad(out, [self.x, self.y])
...
@@ -104,7 +104,7 @@ class TestAddGradComp(unittest.TestCase):
     def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
...
@@ -126,7 +126,7 @@ class TestAddGradComp(unittest.TestCase):
             return out[0], out[1]

         def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 x = paddle.static.data(
...
@@ -167,7 +167,7 @@ class TestAddGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

 if __name__ == '__main__':
...
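The static-graph variants of these tests wrap the same flag toggle around a full Program build and Executor run. A sketch of that structure follows, under the same assumption about the renamed flag; `add_vjp` and the input shapes are illustrative rather than part of the commit:

    import numpy as np
    import paddle
    from paddle.fluid import core

    paddle.enable_static()

    def add_vjp(primal0, primal1, use_prim):
        # Decide whether backward composite rules apply before tracing.
        core._set_prim_backward_enabled(use_prim)
        mp, sp = paddle.static.Program(), paddle.static.Program()
        with paddle.static.program_guard(mp, sp):
            x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
            y = paddle.static.data('primal1', primal1.shape, primal1.dtype)
            x.stop_gradient = False
            y.stop_gradient = False
            out = paddle.add(x, y)
            grads = paddle.static.gradients([out], [x, y])
        exe = paddle.static.Executor()
        exe.run(sp)
        return exe.run(
            mp,
            feed={'primal0': primal0, 'primal1': primal1},
            fetch_list=grads,
        )

    a = np.random.rand(2, 3).astype('float32')
    b = np.random.rand(2, 3).astype('float32')
    for actual, desired in zip(add_vjp(a, b, True), add_vjp(a, b, False)):
        np.testing.assert_allclose(actual, desired, rtol=1e-6, atol=0)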
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
View file @ 23d20e30
...
@@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase):
         self.x.stop_gradient = False
         self.y.stop_gradient = False
         net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
         net = apply_to_static(net, use_cinn)
         out = net(self.x, self.y)
         res = paddle.autograd.grad(out, [self.x, self.y])
...
@@ -107,7 +107,7 @@ class TestDivGradComp(unittest.TestCase):
         paddle.enable_static()

         def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
...
@@ -130,7 +130,7 @@ class TestDivGradComp(unittest.TestCase):
             return out[0], out[1]

         def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 x = paddle.static.data(
...
@@ -172,7 +172,7 @@ class TestDivGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
         paddle.disable_static()
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
View file @ 23d20e30
...
@@ -81,7 +81,7 @@ class TestDivGradComp(unittest.TestCase):
         self.x.stop_gradient = False
         self.y.stop_gradient = False
         net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
         net = apply_to_static(net, use_cinn)
         out = net(self.x, self.y)
         res = paddle.autograd.grad(out, [self.x, self.y])
...
@@ -104,7 +104,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
...
@@ -126,7 +126,7 @@ class TestDivGradComp(unittest.TestCase):
             return out[0], out[1]

         def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 x = paddle.static.data(
...
@@ -167,7 +167,7 @@ class TestDivGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

 if __name__ == '__main__':
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py
View file @ 23d20e30
...
@@ -33,14 +33,14 @@ from paddle.fluid import core
 class TestExpGradComp(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         cls.primal = cls.primal.astype(cls.dtype)
         if cls.cotangent is not None:
             cls.cotangent = cls.cotangent.astype(cls.dtype)

     @classmethod
     def tearDownClass(cls):
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

     def setUp(self):
         paddle.enable_static()
...
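Where a whole test class should run under one flag setting, as above, the flag is flipped in class-level fixtures so that state cannot leak from one test class into the next. A minimal sketch of that convention; the class name is illustrative:

    import unittest
    from paddle.fluid import core

    class PrimBackwardFixture(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            # Every test in this class runs with the composite backward on.
            core._set_prim_backward_enabled(True)

        @classmethod
        def tearDownClass(cls):
            # Restore the default so later test classes see the flag off.
            core._set_prim_backward_enabled(False)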
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py
View file @ 23d20e30
...
@@ -71,7 +71,7 @@ class TestExpandGradComp(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
         paddle.disable_static()
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

     def test_comp(self):
         def func(primal, cotangent, shape):
...
@@ -93,11 +93,11 @@ class TestExpandGradComp(unittest.TestCase):
             )[0]

         def actual(primal, cotangent, shape):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
             return func(primal, cotangent, shape)

         def desired(primal, cotangent, shape):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
             return func(primal, cotangent, shape)

         np.testing.assert_allclose(
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py
View file @ 23d20e30
...
@@ -108,10 +108,10 @@ class TestMultiplyGradComp(unittest.TestCase):
     def test_comp(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         actual = self.vjp()
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
         desired = self.vjp()

         self.assertEqual(len(actual), len(desired))
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
View file @ 23d20e30
...
@@ -16,7 +16,7 @@ import unittest
 from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

 import autograd
 import autograd.numpy
...
@@ -60,7 +60,7 @@ class TestSqrtGradComp(unittest.TestCase):
         self.x = paddle.randn([2, 4])
         self.x.stop_gradient = False
         net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
         net = apply_to_static(net, use_cinn)
         out = net(self.x)
         res = paddle.autograd.grad(out, [self.x])
...
@@ -109,7 +109,7 @@ class TestSqrtGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

 if __name__ == '__main__':
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
View file @ 23d20e30
...
@@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase):
         self.x.stop_gradient = False
         self.y.stop_gradient = False
         net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
         net = apply_to_static(net, use_cinn)
         out = net(self.x, self.y)
         res = paddle.autograd.grad(out, [self.x, self.y])
...
@@ -105,7 +105,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
...
@@ -127,7 +127,7 @@ class TestDivGradComp(unittest.TestCase):
             return out[0], out[1]

         def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
                 x = paddle.static.data(
...
@@ -168,7 +168,7 @@ class TestDivGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

 if __name__ == '__main__':
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py
View file @ 23d20e30
...
@@ -21,7 +21,7 @@ from paddle.fluid import core
 def actual(primal, cotangent, axis, keep_dim):
-    core.set_prim_enabled(False)
+    core._set_prim_backward_enabled(False)
     mp, sp = paddle.static.Program(), paddle.static.Program()
     with paddle.static.program_guard(mp, sp):
         x = paddle.static.data('primal', primal.shape, primal.dtype)
...
@@ -40,7 +40,7 @@ def actual(primal, cotangent, axis, keep_dim):
 def desired(primal, cotangent, axis, keep_dim):
-    core.set_prim_enabled(True)
+    core._set_prim_backward_enabled(True)
     mp, sp = paddle.static.Program(), paddle.static.Program()
     with paddle.static.program_guard(mp, sp):
         x = paddle.static.data('primal', primal.shape, primal.dtype)
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
View file @ 23d20e30
...
@@ -16,7 +16,7 @@ import unittest
 from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

 import autograd
 import autograd.numpy
...
@@ -60,7 +60,7 @@ class TestTanhGradComp(unittest.TestCase):
         self.x = paddle.randn([2, 4])
         self.x.stop_gradient = False
         net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
         net = apply_to_static(net, use_cinn)
         out = net(self.x)
         res = paddle.autograd.grad(out, [self.x])
...
@@ -109,7 +109,7 @@ class TestTanhGradComp(unittest.TestCase):
             rtol=1e-6,
             atol=0,
         )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

 if __name__ == '__main__':
...
python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py
View file @ 23d20e30
...
@@ -17,7 +17,7 @@ import unittest
 from paddle.fluid import core

-core.set_prim_enabled(False)
+core._set_prim_backward_enabled(False)

 import parameterized as param
...
python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py
View file @ 23d20e30
...
@@ -17,7 +17,7 @@ import unittest
 from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

 import parameterized as param
...
@@ -77,7 +77,7 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
         )
         print(actual)
         self.assertEquals(actual, self.desired_ops)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

 if __name__ == '__main__':
...
python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py
View file @ 23d20e30
...
@@ -135,9 +135,9 @@ class TestResnet50Accuracy(unittest.TestCase):
         loop_num = 10
         feed = self.generate_random_data(loop_num)
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
         loss_c = self.train(place, loop_num, feed, use_cinn=True)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
         loss_p = self.train(place, loop_num, feed, use_cinn=True)
         print("Losses of Composite + CINN:")
         print(loss_c)
...
python/paddle/incubate/autograd/primapi.py
View file @ 23d20e30
...
@@ -218,7 +218,7 @@ def grad(outputs, inputs, grad_outputs=None):
 @framework.static_only
 def to_prim(blocks):
     """Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
-    if not core.enable_prim_forward():
+    if not core._is_fwd_prim_enabled():
         return
     if isinstance(blocks, paddle.fluid.framework.Block):
         logging.info("Atomize composite op to primitive ops begin.")
...
@@ -235,5 +235,6 @@ def to_prim(blocks):
             f"Expect block or sequence of blocks, but got {type(blocks)}."
         )
     with framework.program_guard(main_program):
+        print("Running lowering for forward...")
         primx._lower_composite(blocks)
     return
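The guard at the top of `to_prim` is the single behavioral hook of the forward flag: lowering silently becomes a no-op when the flag is off. A reduced sketch of the same shape, assuming only `core._is_fwd_prim_enabled` from this commit; `maybe_lower` and its `lower` argument are illustrative stand-ins for the real entry point and `primx._lower_composite`:

    from paddle.fluid import core

    def maybe_lower(block, lower):
        # Do nothing unless forward composite lowering is switched on.
        if not core._is_fwd_prim_enabled():
            return
        lower(block)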
python/paddle/jit/dy2static/partial_program.py
View file @ 23d20e30
...
@@ -571,13 +571,10 @@ class PartialProgramLayer:
                 targets.append(program.global_block().var(out.name))

         if targets:
-            enable_prim = self._build_strategy.build_cinn_pass
-            if enable_prim and core.enable_prim_backward():
-                core.set_prim_enabled(True)
-                backward.gradients(targets=targets, inputs=[])
-                core.set_prim_enabled(False)
-            else:
-                backward.gradients(targets=targets, inputs=[])
+            if self._build_strategy.build_cinn_pass:
+                # TODO(Jiabin): Change this to True if we need this to be default option
+                core.check_and_set_prim_all_enabled()
+            backward.gradients(targets=targets, inputs=[])

         start_idx = len(main_program.block(0).ops) + 2 * len(
             self._outputs.tolist()
...
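This dy2static change replaces the manual enable/disable dance around `backward.gradients` with one helper that reads the environment and sets the prim flags together. A sketch of the resulting call-site shape; `append_backward_for` is an illustrative name, while the two library calls appear in the diff above:

    from paddle.fluid import backward, core

    def append_backward_for(targets, build_cinn_pass):
        if build_cinn_pass:
            # One call decides both forward and backward prim flags
            # from the environment, instead of toggling them inline.
            core.check_and_set_prim_all_enabled()
        if targets:
            backward.gradients(targets=targets, inputs=[])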
python/paddle/jit/dy2static/program_translator.py
View file @ 23d20e30
...
@@ -1092,8 +1092,9 @@ class ProgramCache:
     def _build_once(self, cache_key):
         # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim
         enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
-        if enable_prim and core.enable_prim_backward():
-            core.set_prim_enabled(True)
+        if enable_prim:
+            # TODO(Jiabin): Change this to True if we need this to be default option
+            core.check_and_set_prim_all_enabled()

         concrete_program = ConcreteProgram.from_func_spec(
             func_spec=cache_key.function_spec,
...
@@ -1103,9 +1104,7 @@ class ProgramCache:
             **cache_key.kwargs
         )
-        if enable_prim or core.enable_prim_forward() == "debug":
-            concrete_program._to_prim()
-        core.set_prim_enabled(False)
+        concrete_program._to_prim()
         return concrete_program, partial_program_from(concrete_program)

     def __getitem__(self, item):
...
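Taken together, the user-visible surface of the refactored flag system, as exercised across this commit, is small. Pairing the calls this way is an illustrative summary rather than an API reference:

    from paddle.fluid import core

    # Toggle composite backward rules explicitly.
    core._set_prim_backward_enabled(True)
    # Query whether forward lowering (to_prim) will run.
    if core._is_fwd_prim_enabled():
        pass
    # Or let one helper set both flags from the environment.
    core.check_and_set_prim_all_enabled()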