diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 17ea95e3f4babd6b8fc766244eb057c078b2cd87..499eb42ea5ca3ea8dc546b0d9c278f46bed61c16 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -1841,7 +1841,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): if is_composite_grad_api and next_grad_node_creation_str != '': next_grad_node_creation_str = f""" - if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{ + if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{ {next_grad_node_creation_str} }} """ @@ -2261,7 +2261,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): # TODO(Ruting):using composite only when we don't have backward kernel in the future. elif is_composite_grad_api: grad_function_call_str = f""" - if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{ + if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{ {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str}); VLOG(4) << "Composite api {composite_grad_api_name} is called "; }}else{{ diff --git a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc b/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc index fd309750ed6014048421d370501bad0a1fe71eff..30a82b4989972b4a0dd6f24b077b0c662306115e 100644 --- a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc +++ b/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -166,7 +167,16 @@ Tensor full(const IntArray& shape, phi::errors::InvalidArgument( "We only support float32/float16 for full, but we got data type: %s", phi::DataTypeToString(dtype))); - op->SetAttr("value", value.to()); + if (dtype == phi::DataType::FLOAT32) { + op->SetAttr("value", value.to()); + } else if (dtype == phi::DataType::FLOAT64) { + op->SetAttr("str_value", std::to_string(value.to())); + } else if (dtype == phi::DataType::FLOAT16) { + op->SetAttr("str_value", std::to_string(value.to())); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "We only support float64/float32/float16 for full")); + } op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype)); op->SetOutput( "Out", {std::static_pointer_cast(out.impl())->Name()}); diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h index 9c12de9fe56607aeca3487c657a677ec0bf83da4..99ef82d08881c28f13cc88a18371c23d447c88d8 100644 --- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h @@ -192,7 +192,7 @@ void divide_grad(const Tensor& x, } // indicate we will compute dy if (dx) { // dx = (1/y) * dout - auto one_tensor = full(phi::vectorize(y.dims()), 1.0); + auto one_tensor = full(phi::vectorize(y.dims()), 1.0, y.dtype()); auto tmp0 = divide(one_tensor, y); auto dx_res = multiply(tmp0, out_grad); if (y.dims() != x.dims()) { diff --git a/paddle/fluid/prim/tests/test_eager_prim.cc b/paddle/fluid/prim/tests/test_eager_prim.cc index 7bb9a389828f28d5cfe691a649057a893ebbc133..35902797ea24517d834715711a670f6ece4b899d 100644 --- a/paddle/fluid/prim/tests/test_eager_prim.cc +++ b/paddle/fluid/prim/tests/test_eager_prim.cc @@ -68,16 
+68,16 @@ TEST(EagerPrim, TanhBackwardTest) { paddle::experimental::Tensor out0 = tanh_ad_func(tensor0); std::vector outs0 = {out0}; // Disable prim - PrimCommonUtils::SetPrimEnabled(false); - ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled()); + PrimCommonUtils::SetBwdPrimEnabled(false); + ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled()); // 4. Run Backward egr::Backward(outs0, {}, false); paddle::experimental::Tensor out1 = tanh_ad_func(tensor1); std::vector outs1 = {out1}; // Disable prim - PrimCommonUtils::SetPrimEnabled(true); - ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled()); + PrimCommonUtils::SetBwdPrimEnabled(true); + ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled()); // 4. Run Backward ::egr::Backward(outs1, {}, false); VLOG(7) @@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) { } TEST(EagerPrim, TestFlags) { - PrimCommonUtils::SetPrimEnabled(true); - ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled()); - PrimCommonUtils::SetPrimEnabled(false); - ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled()); + PrimCommonUtils::SetBwdPrimEnabled(true); + ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled()); + PrimCommonUtils::SetBwdPrimEnabled(false); + ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled()); } } // namespace prim diff --git a/paddle/fluid/prim/tests/test_static_prim.cc b/paddle/fluid/prim/tests/test_static_prim.cc index 87475559617fb6b70c89417769c62695891ea443..fe7a6ca4040448306f79978e527dafa10c9a9a27 100644 --- a/paddle/fluid/prim/tests/test_static_prim.cc +++ b/paddle/fluid/prim/tests/test_static_prim.cc @@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) { } TEST(StaticPrim, TestFlags) { - PrimCommonUtils::SetPrimEnabled(true); - ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled()); - PrimCommonUtils::SetPrimEnabled(false); - ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled()); + PrimCommonUtils::SetBwdPrimEnabled(true); + ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled()); + PrimCommonUtils::SetBwdPrimEnabled(false); + ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled()); } } // namespace prim diff --git a/paddle/fluid/prim/utils/static/static_global_utils.cc b/paddle/fluid/prim/utils/static/static_global_utils.cc index 3e3a0f56977e3c78e0e5de72b53654b540907eee..9631994ab2bce79e24023fe95c77934fedd2acda 100644 --- a/paddle/fluid/prim/utils/static/static_global_utils.cc +++ b/paddle/fluid/prim/utils/static/static_global_utils.cc @@ -18,6 +18,7 @@ namespace paddle { namespace prim { StaticCompositeContext* StaticCompositeContext::static_composite_context_ = new StaticCompositeContext(); -thread_local bool StaticCompositeContext::enable_prim_ = false; +thread_local bool StaticCompositeContext::enable_bwd_prim_ = false; +thread_local bool StaticCompositeContext::enable_fwd_prim_ = false; } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/utils/static/static_global_utils.h b/paddle/fluid/prim/utils/static/static_global_utils.h index f70659c278aeca0ba81d5096aa27163d972b5003..08407013673621a364c177aa1c453e8904fcac63 100644 --- a/paddle/fluid/prim/utils/static/static_global_utils.h +++ b/paddle/fluid/prim/utils/static/static_global_utils.h @@ -56,9 +56,18 @@ class StaticCompositeContext { return generator_->Generate(key); } - void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; } + void SetBwdPrimEnabled(bool enable_prim) { enable_bwd_prim_ = enable_prim; } - bool IsPrimEnabled() { return enable_prim_; } + bool IsBwdPrimEnabled() { return enable_bwd_prim_; } + + void SetFwdPrimEnabled(bool enable_prim) { enable_fwd_prim_ = enable_prim; } + + bool 
IsFwdPrimEnabled() { return enable_fwd_prim_; } + + void SetAllPrimEnabled(bool enable_prim) { + enable_fwd_prim_ = enable_prim; + enable_bwd_prim_ = enable_prim; + } private: StaticCompositeContext() @@ -66,7 +75,8 @@ class StaticCompositeContext { framework::BlockDesc* current_block_desc_; std::unique_ptr generator_; - static thread_local bool enable_prim_; + static thread_local bool enable_bwd_prim_; + static thread_local bool enable_fwd_prim_; static StaticCompositeContext* static_composite_context_; DISABLE_COPY_AND_ASSIGN(StaticCompositeContext); }; diff --git a/paddle/fluid/prim/utils/utils.cc b/paddle/fluid/prim/utils/utils.cc index ddb97ab640d20b84e0c9ab143ead2129b45c884d..fb415262c8d13e2e0ca297f98eda8288c5ceb53c 100644 --- a/paddle/fluid/prim/utils/utils.cc +++ b/paddle/fluid/prim/utils/utils.cc @@ -19,12 +19,24 @@ PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not"); namespace paddle { namespace prim { -bool PrimCommonUtils::IsPrimEnabled() { - return StaticCompositeContext::Instance().IsPrimEnabled(); +bool PrimCommonUtils::IsBwdPrimEnabled() { + return StaticCompositeContext::Instance().IsBwdPrimEnabled(); } -void PrimCommonUtils::SetPrimEnabled(bool enable_prim) { - return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim); +void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) { + return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim); +} + +bool PrimCommonUtils::IsFwdPrimEnabled() { + return StaticCompositeContext::Instance().IsFwdPrimEnabled(); +} + +void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) { + return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim); +} + +void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) { + return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); } } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/utils/utils.h b/paddle/fluid/prim/utils/utils.h index 14757c4eecde6aaf693b0b6fd2f476b56efd0e04..38973dc87b8adf9408e0fc62dd85d11cad754551 100644 --- a/paddle/fluid/prim/utils/utils.h +++ b/paddle/fluid/prim/utils/utils.h @@ -18,8 +18,11 @@ namespace paddle { namespace prim { class PrimCommonUtils { public: - static bool IsPrimEnabled(); - static void SetPrimEnabled(bool enabled); + static bool IsBwdPrimEnabled(); + static void SetBwdPrimEnabled(bool enabled); + static bool IsFwdPrimEnabled(); + static void SetFwdPrimEnabled(bool enabled); + static void SetAllPrimEnabled(bool enabled); }; } // namespace prim } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 43ee2d479b0b76b0d6851fe2c1b58e06e977fb76..d2f622537216b3954298b60d1514b5781b1b5eb1 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) { return oss.str(); }); - m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled); - m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled); + m.def("__set_bwd_prim_enabled", + &paddle::prim::PrimCommonUtils::SetBwdPrimEnabled); + m.def("_is_bwd_prim_enabled", + &paddle::prim::PrimCommonUtils::IsBwdPrimEnabled); + m.def("__set_fwd_prim_enabled", + &paddle::prim::PrimCommonUtils::SetFwdPrimEnabled); + m.def("_is_fwd_prim_enabled", + &paddle::prim::PrimCommonUtils::IsFwdPrimEnabled); + m.def("__set_all_prim_enabled", + &paddle::prim::PrimCommonUtils::SetAllPrimEnabled); m.def("set_num_threads", &platform::SetNumThreads); m.def("disable_signal_handler", &DisableSignalHandler); @@ 
-1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle. // priority of GradCompOpMaker is less than GradCompMaker for better // performance. std::vector> grad_op_descs; - if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) { + if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) { if (grad_comp_op_maker != nullptr) { + VLOG(3) << "Running composite function for " << op_desc.Type(); grad_op_descs = grad_comp_op_maker(op_desc, no_grad_set, &grad_to_var, diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index f47e206c7ce2fe2742529382ef18092f92571cde..615008a8291c5599132386119ffbf985559a2f7b 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -42,7 +42,7 @@ kernel : func : add_grad no_need_buffer : x, y - composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis) + composite : add_grad(x, y, out_grad, axis) backward : add_double_grad inplace : (out_grad -> x_grad) @@ -390,7 +390,7 @@ param : [x, y] kernel : func : divide_grad - composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1) + composite : divide_grad(x, y, out, out_grad, -1) backward : divide_double_grad - backward_op : dropout_grad @@ -1319,7 +1319,7 @@ kernel : func : subtract_grad no_need_buffer : x, y - composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis) + composite : subtract_grad(x, y, out_grad, axis) backward : subtract_double_grad inplace : (out_grad -> x_grad) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 76401d5c47a9aba70631a88ad05edf4b37db2f79..5169f9f085fe24d7611fd5181bc00366782f72df 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1493,14 +1493,15 @@ def _append_backward_ops_( # remove some backward ops # TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem - if not core.is_prim_enabled(): + if not core._is_bwd_prim_enabled(): not_need_ops = _find_not_need_ops( grad_op_descs, ops, input_grad_names_set ) - grad_op_descs = [ op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops ] + else: + logging.debug("Running backward composite and disabling find_not_need_ops") # append op_desc in grad_op_descs to target_block op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 771caa4ef3c4fa822dd40d0077b90e778bbbb5d6..b17c29a97868aab5d6efd0d96dfa26c97e2245e5 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -17,6 +17,7 @@ import sys import os import warnings import platform +import logging has_paddle_dy_lib = False @@ -305,8 +306,13 @@ try: from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent from .libpaddle import _set_current_stream from .libpaddle import _get_phi_kernel_name - from .libpaddle import set_prim_enabled - from .libpaddle import is_prim_enabled + + # prim controller flags + from .libpaddle import __set_bwd_prim_enabled + from .libpaddle import _is_bwd_prim_enabled + from .libpaddle import __set_fwd_prim_enabled + from .libpaddle import _is_fwd_prim_enabled + from .libpaddle import __set_all_prim_enabled if sys.platform != 'win32': from .libpaddle import _set_process_pids @@ -373,36 +379,98 @@ def set_paddle_lib_path(): set_paddle_lib_path() +# We use 3 FLAGS to control whether prim is enabled +# FLAGS_prim_forward: Enable or disable the forward prim strategy +# FLAGS_prim_backward: Enable or disable the backward prim 
strategy +# FLAGS_prim_all: Enable or disable all prim strategies +# +# +# Priorities: +# if with CINN and Dy2St: +# # # _set_prim_all_enabled > FLAGS_prim_all > check_and_set_prim_all_enabled == _set_prim_forward_enabled == _set_prim_backward_enabled > FLAGS_prim_forward == FLAGS_prim_backward +# else: +# # # _set_prim_all_enabled > FLAGS_prim_all == check_and_set_prim_all_enabled == _set_prim_forward_enabled == _set_prim_backward_enabled > FLAGS_prim_forward == FLAGS_prim_backward +def __sync_stat_with_flag(flag): + if flag == "FLAGS_prim_forward": + flag_value = os.getenv("FLAGS_prim_forward") + assert flag_value is not None + flag_value = flag_value.lower() + if flag_value == "false": + __set_fwd_prim_enabled(False) + elif flag_value == "true": + __set_fwd_prim_enabled(True) + else: + raise TypeError(f"flag {flag} should be true or false.") + logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled())) + elif flag == "FLAGS_prim_backward": + flag_value = os.getenv("FLAGS_prim_backward") + assert flag_value is not None + flag_value = flag_value.lower() + if flag_value == "false": + __set_bwd_prim_enabled(False) + elif flag_value == "true": + __set_bwd_prim_enabled(True) + else: + raise TypeError(f"flag {flag} should be true or false.") + logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled())) + elif flag == "FLAGS_prim_all": + flag_value = os.getenv("FLAGS_prim_all") + assert flag_value is not None + flag_value = flag_value.lower() + if flag_value == "false": + __set_all_prim_enabled(False) + elif flag_value == "true": + __set_all_prim_enabled(True) + else: + raise TypeError(f"flag {flag} should be true or false.") + logging.debug( + "all prim enabled: %s", + bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), + ) + else: + raise TypeError( + f"We only support FLAGS_prim_forward/FLAGS_prim_backward/FLAGS_prim_all but we got {flag}." 
+ ) -def set_prim_forward(value): - """set flag FLAGS_prim_forward.""" - flag = str(value) - if flag.lower() not in ["true", "false", "debug"]: - raise TypeError(f"flag {flag} should be string of bool or 'debug'.") - os.environ["FLAGS_prim_forward"] = flag - return +def _set_prim_backward_enabled(value): + __set_bwd_prim_enabled(bool(value)) + logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled())) -def enable_prim_forward(): - flag = os.getenv("FLAGS_prim_forward", "true").lower() - if flag == "false": - return False - if flag == "debug": - return "debug" - return True +def _set_prim_forward_enabled(value): + __set_fwd_prim_enabled(bool(value)) + logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled())) -def set_prim_backward(value): - """set flag FLAGS_prim_backward,""" - flag = str(value) - if flag.lower() not in ["true", "false"]: - raise TypeError(f"flag {flag} should be bool or string of bool.") - os.environ["FLAGS_prim_backward"] = flag - return +def _set_prim_all_enabled(value): + __set_all_prim_enabled(bool(value)) + logging.debug( + "all prim enabled: %s", + bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), + ) -def enable_prim_backward(): - flag = os.getenv("FLAGS_prim_backward", "true") - if flag.lower() == "false": - return False - return True + +def __sync_prim_backward_status(): + flag_value = os.getenv("FLAGS_prim_backward") + if flag_value is None: + logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled())) + else: + __sync_stat_with_flag("FLAGS_prim_backward") + + +def __sync_prim_forward_status(): + flag_value = os.getenv("FLAGS_prim_forward") + if flag_value is None: + logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled())) + else: + __sync_stat_with_flag("FLAGS_prim_forward") + + +def check_and_set_prim_all_enabled(): + flag_value = os.getenv("FLAGS_prim_all") + if flag_value is None: + __sync_prim_backward_status() + __sync_prim_forward_status() + else: + __sync_stat_with_flag("FLAGS_prim_all") diff --git a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py index c7c876b8f8fea6c79bfdf91c95e4dc16d7cefb93..fd54850b2cb6f2a55d18f28c4644b7ea4b859fd1 100644 --- a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py +++ b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py @@ -19,6 +19,7 @@ from utils import TOLERANCE import paddle import paddle.nn.functional as F +from paddle.fluid import core def generate_data(shape, dtype="float32"): @@ -72,6 +73,7 @@ class TestCompositeSoftmax(unittest.TestCase): def cal_composite(self, inputs): paddle.enable_static() + core._set_prim_forward_enabled(True) startup_program = paddle.static.Program() main_program = paddle.static.Program() with paddle.static.program_guard(main_program, startup_program): @@ -95,6 +97,7 @@ class TestCompositeSoftmax(unittest.TestCase): exe.run(startup_program) res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y]) paddle.disable_static() + core._set_prim_forward_enabled(False) return res def compare_forward(self): diff --git a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py index 808c5f8324b65a87efa5c46005c553f5f58703fb..9b6e5db7953565c1289800a550e7b9dca7e9b399 100644 --- a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py +++ 
b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py @@ -78,6 +78,7 @@ class TestCompositeSoftmax(unittest.TestCase): def cal_composite_grad(self, inputs): paddle.enable_static() + core._set_prim_all_enabled(True) startup_program = paddle.static.Program() main_program = paddle.static.Program() with paddle.static.program_guard(main_program, startup_program): @@ -108,6 +109,7 @@ class TestCompositeSoftmax(unittest.TestCase): exe.run(startup_program) res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) paddle.disable_static() + core._set_prim_all_enabled(False) return res def compare_backward(self): @@ -139,7 +141,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase): "test composite softmax and prim backward" def setUp(self): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) self.dtypes = ["float32"] self.shapes = [[2, 3, 4], [2, 3]] self.axes = [-1, 0, 1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py index 4ac7a3dfe4cf39076d4a550cff04466aeaacad7c..f4d59f1a1552f978460e6c23ffaf209f497b154e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py @@ -236,11 +236,11 @@ class TestBert(unittest.TestCase): self.verify_predict() def test_train_composite(self): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) static_loss, static_ppl = self.train_static( self.bert_config, self.data_reader ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) dygraph_loss, dygraph_ppl = self.train_dygraph( self.bert_config, self.data_reader ) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py index 2811a348f46561423449eb9f646e750c7935e3cc..a807e1eef234048fd0860ea616a7f6aaf241da7e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py @@ -47,7 +47,6 @@ class TestPrimForward(unittest.TestCase): """ def setUp(self): - core.set_prim_backward(False) paddle.seed(2022) self.x = paddle.randn([2, 4]) self.x.stop_gradient = False @@ -58,6 +57,7 @@ class TestPrimForward(unittest.TestCase): sgd = paddle.optimizer.SGD( learning_rate=0.1, parameters=net.parameters() ) + core._set_prim_forward_enabled(use_prim) if use_prim: net = apply_to_static(net, use_prim) @@ -103,12 +103,12 @@ class TestPrimForwardAndBackward(unittest.TestCase): self.x.stop_gradient = False def train(self, use_prim): - core.set_prim_backward(True) paddle.seed(2022) net = PrimeNet() sgd = paddle.optimizer.SGD( learning_rate=0.1, parameters=net.parameters() ) + core._set_prim_all_enabled(use_prim) if use_prim: net = apply_to_static(net, use_prim) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py index b195c7d342a724598bfa175c60442d3bca418048..911ca2ec9016f122b0ed16abd506fda22d4aaccd 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py @@ -427,10 +427,10 @@ class TestResnet(unittest.TestCase): ) self.verify_predict() - def test_resnet_composite(self): - core.set_prim_enabled(True) + def test_resnet_composite_backward(self): + 
core._set_prim_backward_enabled(True) static_loss = self.train(to_static=True) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) dygraph_loss = self.train(to_static=True) np.testing.assert_allclose( static_loss, @@ -440,65 +440,13 @@ class TestResnet(unittest.TestCase): static_loss, dygraph_loss ), ) - core.set_prim_enabled(False) - def test_in_static_mode_mkldnn(self): - fluid.set_flags({'FLAGS_use_mkldnn': True}) - try: - if paddle.fluid.core.is_compiled_with_mkldnn(): - self.resnet_helper.train(to_static=True) - finally: - fluid.set_flags({'FLAGS_use_mkldnn': False}) - - -class TestResnetPrim(unittest.TestCase): - "test prim forward + prim backward + to_static" - - def setUp(self): - self.resnet_helper = ResNetHelper() - - def train(self, to_static): - paddle.jit.enable_to_static(to_static) - return self.resnet_helper.train(to_static) - - def verify_predict(self): - image = np.random.random([1, 3, 224, 224]).astype('float32') - dy_pre = self.resnet_helper.predict_dygraph(image) - st_pre = self.resnet_helper.predict_static(image) - dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image) - predictor_pre = self.resnet_helper.predict_analysis_inference(image) - np.testing.assert_allclose( - dy_pre, - st_pre, - rtol=1e-05, - err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre), - ) - np.testing.assert_allclose( - dy_jit_pre, - st_pre, - rtol=1e-05, - err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format( - dy_jit_pre, st_pre - ), - ) - np.testing.assert_allclose( - predictor_pre, - st_pre, - rtol=1e-05, - err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format( - predictor_pre, st_pre - ), - ) - - def test_resnet_composite(self): + def test_resnet_composite_forward_backward(self): plat = platform.system() if plat == "Linux": - print("=================== origin resnet ===================") - core.set_prim_enabled(False) + core._set_prim_all_enabled(True) static_loss = self.train(to_static=True) - print("======= resnet with prim forward and backward =======") - core.set_prim_enabled(True) - core.set_prim_forward("debug") + core._set_prim_all_enabled(False) dygraph_loss = self.train(to_static=True) np.testing.assert_allclose( static_loss, @@ -508,10 +456,17 @@ class TestResnetPrim(unittest.TestCase): static_loss, dygraph_loss ), ) - core.set_prim_enabled(False) else: pass + def test_in_static_mode_mkldnn(self): + fluid.set_flags({'FLAGS_use_mkldnn': True}) + try: + if paddle.fluid.core.is_compiled_with_mkldnn(): + self.resnet_helper.train(to_static=True) + finally: + fluid.set_flags({'FLAGS_use_mkldnn': False}) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py index 8e6872c079cec5157eef6b7debb654ff88a43261..f0cd98c2c110bbb9592c69b3373c3da2e74f75fb 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py @@ -130,9 +130,9 @@ class TestResnet(unittest.TestCase): ) def test_resnet_composite(self): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) static_loss = self.train(to_static=True) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) dygraph_loss = self.train(to_static=False) np.testing.assert_allclose( static_loss, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py 
b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py index 6213f6fae2415a90eaea437044df0a975f69406a..252b63a646b7a42e288cc2b16d2262de223bfdcf 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py @@ -137,9 +137,9 @@ class TestResnet(unittest.TestCase): def test_resnet_composite(self): if fluid.is_compiled_with_cuda(): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) static_loss = self.train(to_static=True) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) dygraph_loss = self.train(to_static=False) # NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here. np.testing.assert_allclose( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py index 1b4d01114f8c2acd680b1bd1f232cbdafd139cfa..5bbeba860f590bd6b51930807e928cdd28225c19 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py @@ -426,9 +426,9 @@ class TestResnet(unittest.TestCase): self.verify_predict() def test_resnet_composite(self): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) static_loss = self.train(to_static=True) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) dygraph_loss = self.train(to_static=False) np.testing.assert_allclose( static_loss, diff --git a/python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt index db4822bce3f91bfdfff8dfeeedb7ba1ae0ba45be..80c5c8fe1538f8e378f1d3b0f9f37eeeba1fcbb8 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/prim/prim/CMakeLists.txt @@ -9,3 +9,4 @@ foreach(TEST_OP ${TEST_OPS}) endforeach() add_subdirectory(vjp) +add_subdirectory(flags) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..72c6bbd7d05e8fdf99fce350ad15c216dcac5c92 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/CMakeLists.txt @@ -0,0 +1,9 @@ +file( + GLOB TEST_OPS + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) +endforeach() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3053af919e926c274a0f5d5bfba7d75ed12805 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from paddle.fluid import core + + +class TestPrimFlags(unittest.TestCase): + def test_prim_flags(self): + self.assertFalse(core._is_bwd_prim_enabled()) + self.assertFalse(core._is_fwd_prim_enabled()) + + os.environ['FLAGS_prim_backward'] = "True" + core.check_and_set_prim_all_enabled() + self.assertTrue(core._is_bwd_prim_enabled()) + os.environ['FLAGS_prim_forward'] = "True" + core.check_and_set_prim_all_enabled() + self.assertTrue(core._is_fwd_prim_enabled()) + os.environ['FLAGS_prim_all'] = "False" + core.check_and_set_prim_all_enabled() + self.assertFalse(core._is_bwd_prim_enabled()) + self.assertFalse(core._is_fwd_prim_enabled()) + + os.environ['FLAGS_prim_all'] = "True" + core.check_and_set_prim_all_enabled() + self.assertTrue(core._is_bwd_prim_enabled()) + self.assertTrue(core._is_fwd_prim_enabled()) + + del os.environ['FLAGS_prim_all'] + os.environ['FLAGS_prim_backward'] = "False" + core.check_and_set_prim_all_enabled() + self.assertFalse(core._is_bwd_prim_enabled()) + os.environ['FLAGS_prim_forward'] = "False" + core.check_and_set_prim_all_enabled() + self.assertFalse(core._is_fwd_prim_enabled()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py index e05fef3b18d129bb702e82129131aef645cf02bb..b5a183add8cd0a3191dc670a5cf710eabb5351ec 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core.set_prim_enabled(True) +core._set_prim_backward_enabled(True) @param.parameterized_class( @@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase): def test_tanh_grad_comp(self): def actual(primal0, primal1): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase): return res[0].numpy(), res[1].numpy() def desired(primal0, primal1): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py index c9ae5cd7ecbafd7f554dc24a38a40b0a42eec854..96e186e32e91041afd734e3e3b6109f1d762f766 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core.set_prim_enabled(True) +core._set_prim_backward_enabled(True) @param.parameterized_class( @@ -67,7 +67,7 @@ class 
TestTanhGradComp(unittest.TestCase): def test_tanh_grad_comp(self): def actual(primal0, primal1): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase): return res[0].numpy(), res[1].numpy() def desired(primal0, primal1): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py index e81314ba041ef7f91f0bdb4f1c266d4bcc92bb72..85974031280820dd0c056b8101340b3d344dc83e 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py @@ -32,14 +32,14 @@ from paddle.fluid import core class TestExpGradComp(unittest.TestCase): @classmethod def setUpClass(cls): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) cls.primal = cls.primal.astype(cls.dtype) if cls.cotangent is not None: cls.cotangent = cls.cotangent.astype(cls.dtype) @classmethod def tearDownClass(cls): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) def test_exp_grad_comp(self): def actual(primal, cotangent): diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py index c4de565dc504f73c9e505d091c4e05ec06798d57..92b0b98942caaff065ce513aa0aee11550bf7ab7 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py @@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase): @classmethod def tearDownClass(cls): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) def test_comp(self): def func(primal, cotangent, shape): @@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase): ] def actual(primal, cotangent, shape): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) return func(primal, cotangent, shape) def desired(primal, cotangent, shape): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) return func(primal, cotangent, shape) np.testing.assert_allclose( diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py index 59daf91ab8b84b391e971ae6c28b75ea7e05b89f..fdef2779a41b8a05fee7159e9190989095b93f36 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py @@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase): return [g for g in grads if g is not None] def test_comp(self): - core.set_prim_enabled(True) + 
core._set_prim_backward_enabled(True) actual = self.vjp() - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) desired = self.vjp() for i, j in zip(actual, desired): diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py index 7abb91e912ac4fa74cd5b8d4ca1e9d59d4ca219b..a97cb37420145f87aab2036495faa1bedf125015 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py @@ -22,7 +22,7 @@ import parameterized as param import paddle from paddle.fluid import core -core.set_prim_enabled(True) +core._set_prim_backward_enabled(True) @param.parameterized_class( @@ -63,7 +63,7 @@ class TestSqrtGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py index e508ae63803cb121e6cdce099c178b3cce28c9c2..2a6e758ba42cf3575ee6819ed9ed4d6aec850a88 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core.set_prim_enabled(True) +core._set_prim_backward_enabled(True) @param.parameterized_class( @@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase): def test_tanh_grad_comp(self): def actual(primal0, primal1): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase): return res[0].numpy(), res[1].numpy() def desired(primal0, primal1): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py index 5586f7c0ccaf64bd924b30e6053ac9f2932bab30..e7f8b23542e6b8f562796e25b50bf5f9e68bc3a4 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py @@ -21,7 +21,7 @@ from paddle.fluid import core def actual(primal, cotangent, axis, keep_dim): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False) v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False) y = paddle.sum(x, axis=axis, keepdim=keep_dim) @@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim): def desired(primal, cotangent, axis, keep_dim): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) 
x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False) v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False) y = paddle.sum(x, axis=axis, keepdim=keep_dim) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py index 438f71b573a717efde1def08a118e9fbf1fbfa81..11cc010b2ee130880bff5e4a9bbeccc10c51f057 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core.set_prim_enabled(True) +core._set_prim_backward_enabled(True) @param.parameterized_class( @@ -74,7 +74,7 @@ class TestTanhGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py index b7d7969d9aa0469d98e8d460c25bc17058235648..1673ff083e7cf4081b300ab7de6e585e7b7d1c21 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py @@ -81,7 +81,7 @@ class TestAddGradComp(unittest.TestCase): self.x.stop_gradient = False self.y.stop_gradient = False net = PrimeNet() - core.set_prim_enabled(use_prim) + core._set_prim_backward_enabled(use_prim) net = apply_to_static(net, use_cinn) out = net(self.x, self.y) res = paddle.autograd.grad(out, [self.x, self.y]) @@ -104,7 +104,7 @@ class TestAddGradComp(unittest.TestCase): def test_tanh_grad_comp(self): def actual(primal0, primal1): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data('primal0', primal0.shape, primal0.dtype) @@ -126,7 +126,7 @@ class TestAddGradComp(unittest.TestCase): return out[0], out[1] def desired(primal0, primal1): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data( @@ -167,7 +167,7 @@ class TestAddGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py index 45cae351a73ebb98d93e526efc26e148a96ef764..5dd7417130bc1137b751ea420384d63350c216b0 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py @@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase): self.x.stop_gradient = False self.y.stop_gradient = False net = PrimeNet() - core.set_prim_enabled(use_prim) + core._set_prim_backward_enabled(use_prim) net = apply_to_static(net, use_cinn) out = net(self.x, self.y) res = paddle.autograd.grad(out, [self.x, self.y]) @@ -107,7 +107,7 @@ class TestDivGradComp(unittest.TestCase): paddle.enable_static() def actual(primal0, primal1): - 
core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data('primal0', primal0.shape, primal0.dtype) @@ -130,7 +130,7 @@ class TestDivGradComp(unittest.TestCase): return out[0], out[1] def desired(primal0, primal1): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data( @@ -172,7 +172,7 @@ class TestDivGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) paddle.disable_static() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py index 1d675e8bd097968ed660f52de0c1f658803837c6..95d3c3027fd9d28e4b054806959c8ad8ec391e9a 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py @@ -81,7 +81,7 @@ class TestDivGradComp(unittest.TestCase): self.x.stop_gradient = False self.y.stop_gradient = False net = PrimeNet() - core.set_prim_enabled(use_prim) + core._set_prim_backward_enabled(use_prim) net = apply_to_static(net, use_cinn) out = net(self.x, self.y) res = paddle.autograd.grad(out, [self.x, self.y]) @@ -104,7 +104,7 @@ class TestDivGradComp(unittest.TestCase): def test_tanh_grad_comp(self): def actual(primal0, primal1): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data('primal0', primal0.shape, primal0.dtype) @@ -126,7 +126,7 @@ class TestDivGradComp(unittest.TestCase): return out[0], out[1] def desired(primal0, primal1): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data( @@ -167,7 +167,7 @@ class TestDivGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py index c1c76631232c007f4a2a81cb2227035910bb57d8..2e720f6934f5cd8975b8f4137b15de873cc06277 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py @@ -33,14 +33,14 @@ from paddle.fluid import core class TestExpGradComp(unittest.TestCase): @classmethod def setUpClass(cls): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) cls.primal = cls.primal.astype(cls.dtype) if cls.cotangent is not None: cls.cotangent = cls.cotangent.astype(cls.dtype) @classmethod def tearDownClass(cls): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) def setUp(self): paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py index c322074d34d88715aa9aec84ca5ca6e05c88aba8..2772719a81820a9e46599b1a3881e10c5e015f95 100644 --- 
a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py @@ -71,7 +71,7 @@ class TestExpandGradComp(unittest.TestCase): @classmethod def tearDownClass(cls): paddle.disable_static() - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) def test_comp(self): def func(primal, cotangent, shape): @@ -93,11 +93,11 @@ class TestExpandGradComp(unittest.TestCase): )[0] def actual(primal, cotangent, shape): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) return func(primal, cotangent, shape) def desired(primal, cotangent, shape): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) return func(primal, cotangent, shape) np.testing.assert_allclose( diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py index 63e8a4f1bbf3451bed5c9402a40ffa13a0bbd319..2d1a10a6d4b5794d938b1685f03057ee9b63ca89 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py @@ -108,10 +108,10 @@ class TestMultiplyGradComp(unittest.TestCase): def test_comp(self): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) actual = self.vjp() - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) desired = self.vjp() self.assertEqual(len(actual), len(desired)) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py index 505a4391138e95adb376924499cb95bc43fcb5cb..8df50c768c2b72e11d5de955b3b88e65183c0aad 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py @@ -16,7 +16,7 @@ import unittest from paddle.fluid import core -core.set_prim_enabled(True) +core._set_prim_backward_enabled(True) import autograd import autograd.numpy @@ -60,7 +60,7 @@ class TestSqrtGradComp(unittest.TestCase): self.x = paddle.randn([2, 4]) self.x.stop_gradient = False net = PrimeNet() - core.set_prim_enabled(use_prim) + core._set_prim_backward_enabled(use_prim) net = apply_to_static(net, use_cinn) out = net(self.x) res = paddle.autograd.grad(out, [self.x]) @@ -109,7 +109,7 @@ class TestSqrtGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py index f98a6af621f96f32e99f6d5f46afd5c297e6a528..693bf8b942bab23e9af6b10c5456b8e76936d38b 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py @@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase): self.x.stop_gradient = False self.y.stop_gradient = False net = PrimeNet() - core.set_prim_enabled(use_prim) + core._set_prim_backward_enabled(use_prim) net = apply_to_static(net, use_cinn) out = net(self.x, self.y) res = paddle.autograd.grad(out, [self.x, self.y]) @@ -105,7 +105,7 @@ class TestDivGradComp(unittest.TestCase): def 
test_tanh_grad_comp(self): def actual(primal0, primal1): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data('primal0', primal0.shape, primal0.dtype) @@ -127,7 +127,7 @@ class TestDivGradComp(unittest.TestCase): return out[0], out[1] def desired(primal0, primal1): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data( @@ -168,7 +168,7 @@ class TestDivGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py index b9b2ad03913cb7127c63c1a57c0d1af5944cff2f..a6b12c7cf623c8d9dede05aa9c7d0fcc64549572 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sum_grad.py @@ -21,7 +21,7 @@ from paddle.fluid import core def actual(primal, cotangent, axis, keep_dim): - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data('primal', primal.shape, primal.dtype) @@ -40,7 +40,7 @@ def actual(primal, cotangent, axis, keep_dim): def desired(primal, cotangent, axis, keep_dim): - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) mp, sp = paddle.static.Program(), paddle.static.Program() with paddle.static.program_guard(mp, sp): x = paddle.static.data('primal', primal.shape, primal.dtype) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py index c7c9109eeaab0403b87e74a0d9edea8ebe995e21..e643cf620a8118fb1c36202ada5543b60b3f0012 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py @@ -16,7 +16,7 @@ import unittest from paddle.fluid import core -core.set_prim_enabled(True) +core._set_prim_backward_enabled(True) import autograd import autograd.numpy @@ -60,7 +60,7 @@ class TestTanhGradComp(unittest.TestCase): self.x = paddle.randn([2, 4]) self.x.stop_gradient = False net = PrimeNet() - core.set_prim_enabled(use_prim) + core._set_prim_backward_enabled(use_prim) net = apply_to_static(net, use_cinn) out = net(self.x) res = paddle.autograd.grad(out, [self.x]) @@ -109,7 +109,7 @@ class TestTanhGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py b/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py index 3170313a0d845f383f8b8503ad9c0bc5782f5828..9292a1c4276d6b5b9087e4d52003519010a7a45f 100644 --- a/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py +++ b/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py @@ -17,7 +17,7 @@ import unittest from paddle.fluid import core 
-core.set_prim_enabled(False) +core._set_prim_backward_enabled(False) import parameterized as param diff --git a/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py b/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py index d76f99dc601063aca2e2e6160f4c68fa52280335..18b445f38da3a8ea74f671b52ac74d33624be54a 100644 --- a/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py +++ b/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py @@ -17,7 +17,7 @@ import unittest from paddle.fluid import core -core.set_prim_enabled(True) +core._set_prim_backward_enabled(True) import parameterized as param @@ -77,7 +77,7 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase): ) print(actual) self.assertEquals(actual, self.desired_ops) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py b/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py index 277f3558d09d4f6d744b0edd1f9fc9ffcf2e4f9e..d580636ce50d7efe02454cb57b3b23c0994c7f93 100644 --- a/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py +++ b/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py @@ -135,9 +135,9 @@ class TestResnet50Accuracy(unittest.TestCase): loop_num = 10 feed = self.generate_random_data(loop_num) - core.set_prim_enabled(True) + core._set_prim_backward_enabled(True) loss_c = self.train(place, loop_num, feed, use_cinn=True) - core.set_prim_enabled(False) + core._set_prim_backward_enabled(False) loss_p = self.train(place, loop_num, feed, use_cinn=True) print("Losses of Composite + CINN:") print(loss_c) diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py index 76e0802194272927c6318bba7def02e67314cdfd..476f7125c443ecc9ae0d045be8cade974f634f09 100644 --- a/python/paddle/incubate/autograd/primapi.py +++ b/python/paddle/incubate/autograd/primapi.py @@ -218,7 +218,7 @@ def grad(outputs, inputs, grad_outputs=None): @framework.static_only def to_prim(blocks): """Search nonbasic ops which have be registered composite rules and replace them with primitive ops.""" - if not core.enable_prim_forward(): + if not core._is_fwd_prim_enabled(): return if isinstance(blocks, paddle.fluid.framework.Block): logging.info("Atomize composite op to primitive ops begin.") @@ -235,5 +235,6 @@ def to_prim(blocks): f"Expect block or sequence of blocks, but got {type(blocks)}." 
) with framework.program_guard(main_program): + print("Running lowering for forward...") primx._lower_composite(blocks) return diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index 293c8b40f7752bffa3819e459a929052780cff0e..19d44b8e35c8ba5d9d6f414eacd4c4a266eac856 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -571,13 +571,10 @@ class PartialProgramLayer: targets.append(program.global_block().var(out.name)) if targets: - enable_prim = self._build_strategy.build_cinn_pass - if enable_prim and core.enable_prim_backward(): - core.set_prim_enabled(True) - backward.gradients(targets=targets, inputs=[]) - core.set_prim_enabled(False) - else: - backward.gradients(targets=targets, inputs=[]) + if self._build_strategy.build_cinn_pass: + # TODO(Jiabin): Change this to True if we need this to be default option + core.check_and_set_prim_all_enabled() + backward.gradients(targets=targets, inputs=[]) start_idx = len(main_program.block(0).ops) + 2 * len( self._outputs.tolist() diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 5b8493977e904b9fefdb4ae448b8df1abd499e13..5a66cd103a7fe70efb455b2840b09e86fed67f53 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -1092,8 +1092,9 @@ class ProgramCache: def _build_once(self, cache_key): # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass - if enable_prim and core.enable_prim_backward(): - core.set_prim_enabled(True) + if enable_prim: + # TODO(Jiabin): Change this to True if we need this to be default option + core.check_and_set_prim_all_enabled() concrete_program = ConcreteProgram.from_func_spec( func_spec=cache_key.function_spec, @@ -1103,9 +1104,7 @@ class ProgramCache: **cache_key.kwargs ) - if enable_prim or core.enable_prim_forward() == "debug": - concrete_program._to_prim() - core.set_prim_enabled(False) + concrete_program._to_prim() return concrete_program, partial_program_from(concrete_program) def __getitem__(self, item):
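For reference, the sketch below is not part of the patch; it shows how the prim switches introduced above are exercised from Python, following the added test_prim_flags.py. The per-direction setters flip the thread-local C++ flags directly, while check_and_set_prim_all_enabled() synchronizes them from the FLAGS_prim_forward / FLAGS_prim_backward / FLAGS_prim_all environment variables, with FLAGS_prim_all taking precedence when it is set.

import os

from paddle.fluid import core

# Programmatic control: enable only the backward decomposition,
# as most unit tests in this patch do around their grad() calls.
core._set_prim_backward_enabled(True)
assert core._is_bwd_prim_enabled()
core._set_prim_backward_enabled(False)

# Environment-driven control: sync the C++ flags from FLAGS_* variables.
os.environ["FLAGS_prim_forward"] = "True"
os.environ["FLAGS_prim_backward"] = "False"
core.check_and_set_prim_all_enabled()
assert core._is_fwd_prim_enabled() and not core._is_bwd_prim_enabled()

# FLAGS_prim_all overrides the per-direction flags when it is set.
os.environ["FLAGS_prim_all"] = "True"
core.check_and_set_prim_all_enabled()
assert core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled()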