Unverified commit 23d20e30, authored by Jiabin Yang, committed via GitHub

【Prim】Refactor prim flags system (#49930)

Parent: 44855da3
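For orientation, this refactor splits the single prim switch into separate forward and backward switches. Below is a minimal, illustrative sketch of how the refactored controls are driven from Python; it is not part of the commit, and it uses only names that appear in the diff that follows:

    # Illustrative sketch only; all names below are introduced by this diff.
    from paddle.fluid import core

    # Backward-only: decompose grad ops into primitive ops.
    core._set_prim_backward_enabled(True)
    assert core._is_bwd_prim_enabled()

    # Forward-only: lower composite forward ops to primitive ops.
    core._set_prim_forward_enabled(True)
    assert core._is_fwd_prim_enabled()

    # Both directions at once, then switch everything off again.
    core._set_prim_all_enabled(True)
    core._set_prim_all_enabled(False)
    assert not core._is_fwd_prim_enabled() and not core._is_bwd_prim_enabled()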
@@ -1841,7 +1841,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
        if is_composite_grad_api and next_grad_node_creation_str != '':
            next_grad_node_creation_str = f"""
-  if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+  if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
    {next_grad_node_creation_str}
  }}
  """
@@ -2261,7 +2261,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
        # TODO(Ruting):using composite only when we don't have backward kernel in the future.
        elif is_composite_grad_api:
            grad_function_call_str = f"""
-  if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+  if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
  {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
  VLOG(4) << "Composite api {composite_grad_api_name} is called ";
  }}else{{
......
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <string.h>
#include <memory>
#include <sstream>
#include <string>
@@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape,
      phi::errors::InvalidArgument(
          "We only support float32/float16 for full, but we got data type: %s",
          phi::DataTypeToString(dtype)));
-  op->SetAttr("value", value.to<float>());
+  if (dtype == phi::DataType::FLOAT32) {
+    op->SetAttr("value", value.to<float>());
+  } else if (dtype == phi::DataType::FLOAT64) {
+    op->SetAttr("str_value", std::to_string(value.to<double>()));
+  } else if (dtype == phi::DataType::FLOAT16) {
+    op->SetAttr("str_value", std::to_string(value.to<float>()));
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "We only support float64/float32/float16 for full"));
+  }
  op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
......
@@ -192,7 +192,7 @@ void divide_grad(const Tensor& x,
  }  // indicate we will compute dy
  if (dx) {
    // dx = (1/y) * dout
-    auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
+    auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
    auto tmp0 = divide<T>(one_tensor, y);
    auto dx_res = multiply<T>(tmp0, out_grad);
    if (y.dims() != x.dims()) {
......
@@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) {
  paddle::experimental::Tensor out0 = tanh_ad_func(tensor0);
  std::vector<paddle::experimental::Tensor> outs0 = {out0};
  // Disable prim
-  PrimCommonUtils::SetPrimEnabled(false);
-  ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(false);
+  ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
  // 4. Run Backward
  egr::Backward(outs0, {}, false);
  paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
  std::vector<paddle::experimental::Tensor> outs1 = {out1};
  // Disable prim
-  PrimCommonUtils::SetPrimEnabled(true);
-  ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(true);
+  ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
  // 4. Run Backward
  ::egr::Backward(outs1, {}, false);
  VLOG(7)
@@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) {
}

TEST(EagerPrim, TestFlags) {
-  PrimCommonUtils::SetPrimEnabled(true);
-  ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
-  PrimCommonUtils::SetPrimEnabled(false);
-  ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(true);
+  ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(false);
+  ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}

}  // namespace prim
......
@@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
}

TEST(StaticPrim, TestFlags) {
-  PrimCommonUtils::SetPrimEnabled(true);
-  ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
-  PrimCommonUtils::SetPrimEnabled(false);
-  ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(true);
+  ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
+  PrimCommonUtils::SetBwdPrimEnabled(false);
+  ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}

}  // namespace prim
......
@@ -18,6 +18,7 @@ namespace paddle {
namespace prim {
StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
    new StaticCompositeContext();
-thread_local bool StaticCompositeContext::enable_prim_ = false;
+thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
+thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
}  // namespace prim
}  // namespace paddle
@@ -56,9 +56,18 @@ class StaticCompositeContext {
    return generator_->Generate(key);
  }

-  void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; }
-  bool IsPrimEnabled() { return enable_prim_; }
+  void SetBwdPrimEnabled(bool enable_prim) { enable_bwd_prim_ = enable_prim; }
+  bool IsBwdPrimEnabled() { return enable_bwd_prim_; }
+  void SetFwdPrimEnabled(bool enable_prim) { enable_fwd_prim_ = enable_prim; }
+  bool IsFwdPrimEnabled() { return enable_fwd_prim_; }
+  void SetAllPrimEnabled(bool enable_prim) {
+    enable_fwd_prim_ = enable_prim;
+    enable_bwd_prim_ = enable_prim;
+  }

 private:
  StaticCompositeContext()
@@ -66,7 +75,8 @@ class StaticCompositeContext {
  framework::BlockDesc* current_block_desc_;
  std::unique_ptr<UniqueNameGenerator> generator_;
-  static thread_local bool enable_prim_;
+  static thread_local bool enable_bwd_prim_;
+  static thread_local bool enable_fwd_prim_;
  static StaticCompositeContext* static_composite_context_;
  DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
};
......
@@ -19,12 +19,24 @@
PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
namespace paddle {
namespace prim {
-bool PrimCommonUtils::IsPrimEnabled() {
-  return StaticCompositeContext::Instance().IsPrimEnabled();
+bool PrimCommonUtils::IsBwdPrimEnabled() {
+  return StaticCompositeContext::Instance().IsBwdPrimEnabled();
}
-void PrimCommonUtils::SetPrimEnabled(bool enable_prim) {
-  return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim);
+void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
+  return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
+}
+bool PrimCommonUtils::IsFwdPrimEnabled() {
+  return StaticCompositeContext::Instance().IsFwdPrimEnabled();
+}
+void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
+  return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
+}
+void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
+  return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
}
}  // namespace prim
}  // namespace paddle
@@ -18,8 +18,11 @@ namespace paddle {
namespace prim {
class PrimCommonUtils {
 public:
-  static bool IsPrimEnabled();
-  static void SetPrimEnabled(bool enabled);
+  static bool IsBwdPrimEnabled();
+  static void SetBwdPrimEnabled(bool enabled);
+  static bool IsFwdPrimEnabled();
+  static void SetFwdPrimEnabled(bool enabled);
+  static void SetAllPrimEnabled(bool enabled);
};
}  // namespace prim
}  // namespace paddle
@@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) {
    return oss.str();
  });
-  m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled);
-  m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled);
+  m.def("__set_bwd_prim_enabled",
+        &paddle::prim::PrimCommonUtils::SetBwdPrimEnabled);
+  m.def("_is_bwd_prim_enabled",
+        &paddle::prim::PrimCommonUtils::IsBwdPrimEnabled);
+  m.def("__set_fwd_prim_enabled",
+        &paddle::prim::PrimCommonUtils::SetFwdPrimEnabled);
+  m.def("_is_fwd_prim_enabled",
+        &paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
+  m.def("__set_all_prim_enabled",
+        &paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
  m.def("set_num_threads", &platform::SetNumThreads);
  m.def("disable_signal_handler", &DisableSignalHandler);
@@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle.
          // priority of GradCompOpMaker is less than GradCompMaker for better
          // performance.
          std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
-          if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
+          if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
            if (grad_comp_op_maker != nullptr) {
+              VLOG(3) << "Running composite fun for " << op_desc.Type();
              grad_op_descs = grad_comp_op_maker(op_desc,
                                                 no_grad_set,
                                                 &grad_to_var,
......
@@ -42,7 +42,7 @@
  kernel :
    func : add_grad
  no_need_buffer : x, y
-  composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
+  composite : add_grad(x, y, out_grad, axis)
  backward : add_double_grad
  inplace : (out_grad -> x_grad)
@@ -390,7 +390,7 @@
  param : [x, y]
  kernel :
    func : divide_grad
-  composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1)
+  composite : divide_grad(x, y, out, out_grad, -1)
  backward : divide_double_grad

- backward_op : dropout_grad
@@ -1319,7 +1319,7 @@
  kernel :
    func : subtract_grad
  no_need_buffer : x, y
-  composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
+  composite : subtract_grad(x, y, out_grad, axis)
  backward : subtract_double_grad
  inplace : (out_grad -> x_grad)
......
@@ -1493,14 +1493,15 @@ def _append_backward_ops_(
        # remove some backward ops
        # TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
-        if not core.is_prim_enabled():
+        if not core._is_bwd_prim_enabled():
            not_need_ops = _find_not_need_ops(
                grad_op_descs, ops, input_grad_names_set
            )
            grad_op_descs = [
                op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
            ]
+        else:
+            logging.debug("Running backward composite and disable find_not_need_ops")

    # append op_desc in grad_op_descs to target_block
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
......
@@ -17,6 +17,7 @@ import sys
import os
import warnings
import platform
+import logging

has_paddle_dy_lib = False
@@ -305,8 +306,13 @@ try:
    from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
    from .libpaddle import _set_current_stream
    from .libpaddle import _get_phi_kernel_name
-    from .libpaddle import set_prim_enabled
-    from .libpaddle import is_prim_enabled
+    # prim controller flags
+    from .libpaddle import __set_bwd_prim_enabled
+    from .libpaddle import _is_bwd_prim_enabled
+    from .libpaddle import __set_fwd_prim_enabled
+    from .libpaddle import _is_fwd_prim_enabled
+    from .libpaddle import __set_all_prim_enabled

    if sys.platform != 'win32':
        from .libpaddle import _set_process_pids
@@ -373,36 +379,98 @@ def set_paddle_lib_path():
set_paddle_lib_path()

-def set_prim_forward(value):
-    """set flag FLAGS_prim_forward."""
-    flag = str(value)
-    if flag.lower() not in ["true", "false", "debug"]:
-        raise TypeError(f"flag {flag} should be string of bool or 'debug'.")
-    os.environ["FLAGS_prim_forward"] = flag
-    return

-def enable_prim_forward():
-    flag = os.getenv("FLAGS_prim_forward", "true").lower()
-    if flag == "false":
-        return False
-    if flag == "debug":
-        return "debug"
-    return True

-def set_prim_backward(value):
-    """set flag FLAGS_prim_backward."""
-    flag = str(value)
-    if flag.lower() not in ["true", "false"]:
-        raise TypeError(f"flag {flag} should be bool or string of bool.")
-    os.environ["FLAGS_prim_backward"] = flag
-    return

-def enable_prim_backward():
-    flag = os.getenv("FLAGS_prim_backward", "true")
-    if flag.lower() == "false":
-        return False
-    return True

+# We have 3 FLAGS to judge whether prim is enabled
+# FLAGS_prim_forward: open or close the forward prim strategy
+# FLAGS_prim_backward: open or close the backward prim strategy
+# FLAGS_prim_all: open or close both prim strategies
+#
+# Priorities:
+# if with CINN and Dy2St:
+#     _set_prim_all_enabled > FLAGS_prim_all > check_and_set_prim_all_enabled
+#     == _set_prim_backward_enabled == _set_prim_forward_enabled
+#     > FLAGS_prim_forward == FLAGS_prim_backward
+# else:
+#     _set_prim_all_enabled > FLAGS_prim_all == check_and_set_prim_all_enabled
+#     == _set_prim_backward_enabled == _set_prim_forward_enabled
+#     > FLAGS_prim_forward == FLAGS_prim_backward
+def __sync_stat_with_flag(flag):
+    if flag == "FLAGS_prim_forward":
+        flag_value = os.getenv("FLAGS_prim_forward")
+        assert flag_value is not None
+        flag_value = flag_value.lower()
+        if flag_value == "false":
+            __set_fwd_prim_enabled(False)
+        elif flag_value == "true":
+            __set_fwd_prim_enabled(True)
+        else:
+            raise TypeError(f"flag {flag} should be true or false.")
+        logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled()))
+    elif flag == "FLAGS_prim_backward":
+        flag_value = os.getenv("FLAGS_prim_backward")
+        assert flag_value is not None
+        flag_value = flag_value.lower()
+        if flag_value == "false":
+            __set_bwd_prim_enabled(False)
+        elif flag_value == "true":
+            __set_bwd_prim_enabled(True)
+        else:
+            raise TypeError(f"flag {flag} should be true or false.")
+        logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled()))
+    elif flag == "FLAGS_prim_all":
+        flag_value = os.getenv("FLAGS_prim_all")
+        assert flag_value is not None
+        flag_value = flag_value.lower()
+        if flag_value == "false":
+            __set_all_prim_enabled(False)
+        elif flag_value == "true":
+            __set_all_prim_enabled(True)
+        else:
+            raise TypeError(f"flag {flag} should be true or false.")
+        logging.debug(
+            "all prim enabled: %s",
+            bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
+        )
+    else:
+        raise TypeError(
+            f"We only support FLAGS_prim_forward/FLAGS_prim_backward/FLAGS_prim_all but we got {flag}."
+        )

+def _set_prim_backward_enabled(value):
+    __set_bwd_prim_enabled(bool(value))
+    logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled()))

+def _set_prim_forward_enabled(value):
+    __set_fwd_prim_enabled(bool(value))
+    logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled()))

+def _set_prim_all_enabled(value):
+    __set_all_prim_enabled(bool(value))
+    logging.debug(
+        "all prim enabled: %s",
+        bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
+    )

+def __sync_prim_backward_status():
+    flag_value = os.getenv("FLAGS_prim_backward")
+    if flag_value is None:
+        logging.debug("backward prim enabled: %s", bool(_is_bwd_prim_enabled()))
+    else:
+        __sync_stat_with_flag("FLAGS_prim_backward")

+def __sync_prim_forward_status():
+    flag_value = os.getenv("FLAGS_prim_forward")
+    if flag_value is None:
+        logging.debug("forward prim enabled: %s", bool(_is_fwd_prim_enabled()))
+    else:
+        __sync_stat_with_flag("FLAGS_prim_forward")

+def check_and_set_prim_all_enabled():
+    flag_value = os.getenv("FLAGS_prim_all")
+    if flag_value is None:
+        __sync_prim_backward_status()
+        __sync_prim_forward_status()
+    else:
+        __sync_stat_with_flag("FLAGS_prim_all")
@@ -19,6 +19,7 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
+from paddle.fluid import core

def generate_data(shape, dtype="float32"):
@@ -72,6 +73,7 @@ class TestCompositeSoftmax(unittest.TestCase):
    def cal_composite(self, inputs):
        paddle.enable_static()
+        core._set_prim_forward_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
@@ -95,6 +97,7 @@ class TestCompositeSoftmax(unittest.TestCase):
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
        paddle.disable_static()
+        core._set_prim_forward_enabled(False)
        return res

    def compare_forward(self):
......
@@ -78,6 +78,7 @@ class TestCompositeSoftmax(unittest.TestCase):
    def cal_composite_grad(self, inputs):
        paddle.enable_static()
+        core._set_prim_all_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
@@ -108,6 +109,7 @@ class TestCompositeSoftmax(unittest.TestCase):
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
+        core._set_prim_all_enabled(False)
        return res

    def compare_backward(self):
@@ -139,7 +141,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
    "test composite softmax and prim backward"

    def setUp(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        self.dtypes = ["float32"]
        self.shapes = [[2, 3, 4], [2, 3]]
        self.axes = [-1, 0, 1]
......
@@ -236,11 +236,11 @@ class TestBert(unittest.TestCase):
        self.verify_predict()

    def test_train_composite(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        static_loss, static_ppl = self.train_static(
            self.bert_config, self.data_reader
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
        dygraph_loss, dygraph_ppl = self.train_dygraph(
            self.bert_config, self.data_reader
        )
......
@@ -47,7 +47,6 @@ class TestPrimForward(unittest.TestCase):
    """

    def setUp(self):
-        core.set_prim_backward(False)
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False
@@ -58,6 +57,7 @@ class TestPrimForward(unittest.TestCase):
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
+        core._set_prim_forward_enabled(use_prim)
        if use_prim:
            net = apply_to_static(net, use_prim)
@@ -103,12 +103,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
        self.x.stop_gradient = False

    def train(self, use_prim):
-        core.set_prim_backward(True)
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
+        core._set_prim_all_enabled(use_prim)
        if use_prim:
            net = apply_to_static(net, use_prim)
......
@@ -427,10 +427,10 @@ class TestResnet(unittest.TestCase):
        )
        self.verify_predict()

-    def test_resnet_composite(self):
-        core.set_prim_enabled(True)
+    def test_resnet_composite_backward(self):
+        core._set_prim_backward_enabled(True)
        static_loss = self.train(to_static=True)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
        dygraph_loss = self.train(to_static=True)
        np.testing.assert_allclose(
            static_loss,
@@ -440,65 +440,13 @@ class TestResnet(unittest.TestCase):
                static_loss, dygraph_loss
            ),
        )
-        core.set_prim_enabled(False)

-    def test_in_static_mode_mkldnn(self):
-        fluid.set_flags({'FLAGS_use_mkldnn': True})
-        try:
-            if paddle.fluid.core.is_compiled_with_mkldnn():
-                self.resnet_helper.train(to_static=True)
-        finally:
-            fluid.set_flags({'FLAGS_use_mkldnn': False})

-class TestResnetPrim(unittest.TestCase):
-    "test prim forward + prim backward + to_static"
-
-    def setUp(self):
-        self.resnet_helper = ResNetHelper()
-
-    def train(self, to_static):
-        paddle.jit.enable_to_static(to_static)
-        return self.resnet_helper.train(to_static)
-
-    def verify_predict(self):
-        image = np.random.random([1, 3, 224, 224]).astype('float32')
-        dy_pre = self.resnet_helper.predict_dygraph(image)
-        st_pre = self.resnet_helper.predict_static(image)
-        dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
-        predictor_pre = self.resnet_helper.predict_analysis_inference(image)
-        np.testing.assert_allclose(
-            dy_pre,
-            st_pre,
-            rtol=1e-05,
-            err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre),
-        )
-        np.testing.assert_allclose(
-            dy_jit_pre,
-            st_pre,
-            rtol=1e-05,
-            err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format(
-                dy_jit_pre, st_pre
-            ),
-        )
-        np.testing.assert_allclose(
-            predictor_pre,
-            st_pre,
-            rtol=1e-05,
-            err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format(
-                predictor_pre, st_pre
-            ),
-        )

-    def test_resnet_composite(self):
+    def test_resnet_composite_forward_backward(self):
        plat = platform.system()
        if plat == "Linux":
-            print("=================== origin resnet ===================")
-            core.set_prim_enabled(False)
+            core._set_prim_all_enabled(True)
            static_loss = self.train(to_static=True)
-            print("======= resnet with prim forward and backward =======")
-            core.set_prim_enabled(True)
-            core.set_prim_forward("debug")
+            core._set_prim_all_enabled(False)
            dygraph_loss = self.train(to_static=True)
            np.testing.assert_allclose(
                static_loss,
@@ -508,10 +456,17 @@ class TestResnetPrim(unittest.TestCase):
                    static_loss, dygraph_loss
                ),
            )
-            core.set_prim_enabled(False)
        else:
            pass

+    def test_in_static_mode_mkldnn(self):
+        fluid.set_flags({'FLAGS_use_mkldnn': True})
+        try:
+            if paddle.fluid.core.is_compiled_with_mkldnn():
+                self.resnet_helper.train(to_static=True)
+        finally:
+            fluid.set_flags({'FLAGS_use_mkldnn': False})

if __name__ == '__main__':
    unittest.main()
@@ -130,9 +130,9 @@ class TestResnet(unittest.TestCase):
        )

    def test_resnet_composite(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        static_loss = self.train(to_static=True)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
        dygraph_loss = self.train(to_static=False)
        np.testing.assert_allclose(
            static_loss,
......
@@ -137,9 +137,9 @@ class TestResnet(unittest.TestCase):
    def test_resnet_composite(self):
        if fluid.is_compiled_with_cuda():
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            static_loss = self.train(to_static=True)
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            dygraph_loss = self.train(to_static=False)
            # NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here.
            np.testing.assert_allclose(
......
@@ -426,9 +426,9 @@ class TestResnet(unittest.TestCase):
        self.verify_predict()

    def test_resnet_composite(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        static_loss = self.train(to_static=True)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
        dygraph_loss = self.train(to_static=False)
        np.testing.assert_allclose(
            static_loss,
......
@@ -9,3 +9,4 @@ foreach(TEST_OP ${TEST_OPS})
endforeach()

add_subdirectory(vjp)
+add_subdirectory(flags)

file(
  GLOB TEST_OPS
  RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
  "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
  py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from paddle.fluid import core


class TestPrimFlags(unittest.TestCase):
    def test_prim_flags(self):
        self.assertFalse(core._is_bwd_prim_enabled())
        self.assertFalse(core._is_fwd_prim_enabled())

        os.environ['FLAGS_prim_backward'] = "True"
        core.check_and_set_prim_all_enabled()
        self.assertTrue(core._is_bwd_prim_enabled())
        os.environ['FLAGS_prim_forward'] = "True"
        core.check_and_set_prim_all_enabled()
        self.assertTrue(core._is_fwd_prim_enabled())

        os.environ['FLAGS_prim_all'] = "False"
        core.check_and_set_prim_all_enabled()
        self.assertFalse(core._is_bwd_prim_enabled())
        self.assertFalse(core._is_fwd_prim_enabled())

        os.environ['FLAGS_prim_all'] = "True"
        core.check_and_set_prim_all_enabled()
        self.assertTrue(core._is_bwd_prim_enabled())
        self.assertTrue(core._is_fwd_prim_enabled())

        del os.environ['FLAGS_prim_all']
        os.environ['FLAGS_prim_backward'] = "False"
        core.check_and_set_prim_all_enabled()
        self.assertFalse(core._is_bwd_prim_enabled())
        os.environ['FLAGS_prim_forward'] = "False"
        core.check_and_set_prim_all_enabled()
        self.assertFalse(core._is_fwd_prim_enabled())


if __name__ == '__main__':
    unittest.main()
@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)


@param.parameterized_class(
@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
    def test_tanh_grad_comp(self):
        def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            paddle.disable_static()
            x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
            y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
            return res[0].numpy(), res[1].numpy()

        def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            paddle.disable_static()
            x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
            y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)


@param.parameterized_class(
@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
    def test_tanh_grad_comp(self):
        def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            paddle.disable_static()
            x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
            y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
            return res[0].numpy(), res[1].numpy()

        def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            paddle.disable_static()
            x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
            y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -32,14 +32,14 @@ from paddle.fluid import core
class TestExpGradComp(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        cls.primal = cls.primal.astype(cls.dtype)
        if cls.cotangent is not None:
            cls.cotangent = cls.cotangent.astype(cls.dtype)

    @classmethod
    def tearDownClass(cls):
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

    def test_exp_grad_comp(self):
        def actual(primal, cotangent):
......
@@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase):
    @classmethod
    def tearDownClass(cls):
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

    def test_comp(self):
        def func(primal, cotangent, shape):
@@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase):
            ]

        def actual(primal, cotangent, shape):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            return func(primal, cotangent, shape)

        def desired(primal, cotangent, shape):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            return func(primal, cotangent, shape)

        np.testing.assert_allclose(
......
@@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase):
        return [g for g in grads if g is not None]

    def test_comp(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        actual = self.vjp()
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
        desired = self.vjp()

        for i, j in zip(actual, desired):
......
@@ -22,7 +22,7 @@ import parameterized as param
import paddle
from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)


@param.parameterized_class(
@@ -63,7 +63,7 @@ class TestSqrtGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)


@param.parameterized_class(
@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
    def test_tanh_grad_comp(self):
        def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            paddle.disable_static()
            x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
            y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
            return res[0].numpy(), res[1].numpy()

        def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            paddle.disable_static()
            x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
            y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -21,7 +21,7 @@ from paddle.fluid import core

def actual(primal, cotangent, axis, keep_dim):
-    core.set_prim_enabled(False)
+    core._set_prim_backward_enabled(False)
    x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
    v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
    y = paddle.sum(x, axis=axis, keepdim=keep_dim)
@@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim):

def desired(primal, cotangent, axis, keep_dim):
-    core.set_prim_enabled(True)
+    core._set_prim_backward_enabled(True)
    x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
    v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
    y = paddle.sum(x, axis=axis, keepdim=keep_dim)
......
@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)


@param.parameterized_class(
@@ -74,7 +74,7 @@ class TestTanhGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -81,7 +81,7 @@ class TestAddGradComp(unittest.TestCase):
        self.x.stop_gradient = False
        self.y.stop_gradient = False
        net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
        net = apply_to_static(net, use_cinn)
        out = net(self.x, self.y)
        res = paddle.autograd.grad(out, [self.x, self.y])
@@ -104,7 +104,7 @@ class TestAddGradComp(unittest.TestCase):
    def test_tanh_grad_comp(self):
        def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
@@ -126,7 +126,7 @@ class TestAddGradComp(unittest.TestCase):
            return out[0], out[1]

        def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data(
@@ -167,7 +167,7 @@ class TestAddGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase):
        self.x.stop_gradient = False
        self.y.stop_gradient = False
        net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
        net = apply_to_static(net, use_cinn)
        out = net(self.x, self.y)
        res = paddle.autograd.grad(out, [self.x, self.y])
@@ -107,7 +107,7 @@ class TestDivGradComp(unittest.TestCase):
        paddle.enable_static()

        def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
@@ -130,7 +130,7 @@ class TestDivGradComp(unittest.TestCase):
            return out[0], out[1]

        def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data(
@@ -172,7 +172,7 @@ class TestDivGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
        paddle.disable_static()
......
@@ -81,7 +81,7 @@ class TestDivGradComp(unittest.TestCase):
        self.x.stop_gradient = False
        self.y.stop_gradient = False
        net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
        net = apply_to_static(net, use_cinn)
        out = net(self.x, self.y)
        res = paddle.autograd.grad(out, [self.x, self.y])
@@ -104,7 +104,7 @@ class TestDivGradComp(unittest.TestCase):
    def test_tanh_grad_comp(self):
        def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
@@ -126,7 +126,7 @@ class TestDivGradComp(unittest.TestCase):
            return out[0], out[1]

        def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data(
@@ -167,7 +167,7 @@ class TestDivGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -33,14 +33,14 @@ from paddle.fluid import core
class TestExpGradComp(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        cls.primal = cls.primal.astype(cls.dtype)
        if cls.cotangent is not None:
            cls.cotangent = cls.cotangent.astype(cls.dtype)

    @classmethod
    def tearDownClass(cls):
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

    def setUp(self):
        paddle.enable_static()
......
@@ -71,7 +71,7 @@ class TestExpandGradComp(unittest.TestCase):
    @classmethod
    def tearDownClass(cls):
        paddle.disable_static()
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)

    def test_comp(self):
        def func(primal, cotangent, shape):
@@ -93,11 +93,11 @@ class TestExpandGradComp(unittest.TestCase):
            )[0]

        def actual(primal, cotangent, shape):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            return func(primal, cotangent, shape)

        def desired(primal, cotangent, shape):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            return func(primal, cotangent, shape)

        np.testing.assert_allclose(
......
@@ -108,10 +108,10 @@ class TestMultiplyGradComp(unittest.TestCase):

    def test_comp(self):
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        actual = self.vjp()
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
        desired = self.vjp()

        self.assertEqual(len(actual), len(desired))
......
@@ -16,7 +16,7 @@ import unittest

from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

import autograd
import autograd.numpy
@@ -60,7 +60,7 @@ class TestSqrtGradComp(unittest.TestCase):
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False
        net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
        net = apply_to_static(net, use_cinn)
        out = net(self.x)
        res = paddle.autograd.grad(out, [self.x])
@@ -109,7 +109,7 @@ class TestSqrtGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
if __name__ == '__main__': if __name__ == '__main__':
......
@@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase):
        self.x.stop_gradient = False
        self.y.stop_gradient = False
        net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
        net = apply_to_static(net, use_cinn)
        out = net(self.x, self.y)
        res = paddle.autograd.grad(out, [self.x, self.y])
@@ -105,7 +105,7 @@ class TestDivGradComp(unittest.TestCase):
    def test_tanh_grad_comp(self):
        def actual(primal0, primal1):
-            core.set_prim_enabled(True)
+            core._set_prim_backward_enabled(True)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
@@ -127,7 +127,7 @@ class TestDivGradComp(unittest.TestCase):
            return out[0], out[1]

        def desired(primal0, primal1):
-            core.set_prim_enabled(False)
+            core._set_prim_backward_enabled(False)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                x = paddle.static.data(
@@ -168,7 +168,7 @@ class TestDivGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -21,7 +21,7 @@ from paddle.fluid import core

def actual(primal, cotangent, axis, keep_dim):
-    core.set_prim_enabled(False)
+    core._set_prim_backward_enabled(False)
    mp, sp = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(mp, sp):
        x = paddle.static.data('primal', primal.shape, primal.dtype)
@@ -40,7 +40,7 @@ def actual(primal, cotangent, axis, keep_dim):

def desired(primal, cotangent, axis, keep_dim):
-    core.set_prim_enabled(True)
+    core._set_prim_backward_enabled(True)
    mp, sp = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(mp, sp):
        x = paddle.static.data('primal', primal.shape, primal.dtype)
......
@@ -16,7 +16,7 @@ import unittest

from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

import autograd
import autograd.numpy
@@ -60,7 +60,7 @@ class TestTanhGradComp(unittest.TestCase):
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False
        net = PrimeNet()
-        core.set_prim_enabled(use_prim)
+        core._set_prim_backward_enabled(use_prim)
        net = apply_to_static(net, use_cinn)
        out = net(self.x)
        res = paddle.autograd.grad(out, [self.x])
@@ -109,7 +109,7 @@ class TestTanhGradComp(unittest.TestCase):
            rtol=1e-6,
            atol=0,
        )
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -17,7 +17,7 @@ import unittest

from paddle.fluid import core

-core.set_prim_enabled(False)
+core._set_prim_backward_enabled(False)

import parameterized as param
......
@@ -17,7 +17,7 @@ import unittest

from paddle.fluid import core

-core.set_prim_enabled(True)
+core._set_prim_backward_enabled(True)

import parameterized as param
@@ -77,7 +77,7 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
        )
        print(actual)
        self.assertEquals(actual, self.desired_ops)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)


if __name__ == '__main__':
......
@@ -135,9 +135,9 @@ class TestResnet50Accuracy(unittest.TestCase):
        loop_num = 10
        feed = self.generate_random_data(loop_num)
-        core.set_prim_enabled(True)
+        core._set_prim_backward_enabled(True)
        loss_c = self.train(place, loop_num, feed, use_cinn=True)
-        core.set_prim_enabled(False)
+        core._set_prim_backward_enabled(False)
        loss_p = self.train(place, loop_num, feed, use_cinn=True)
        print("Losses of Composite + CINN:")
        print(loss_c)
......
@@ -218,7 +218,7 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only
def to_prim(blocks):
    """Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
-    if not core.enable_prim_forward():
+    if not core._is_fwd_prim_enabled():
        return
    if isinstance(blocks, paddle.fluid.framework.Block):
        logging.info("Atomize composite op to primitive ops begin.")
@@ -235,5 +235,6 @@ def to_prim(blocks):
            f"Expect block or sequence of blocks, but got {type(blocks)}."
        )
    with framework.program_guard(main_program):
+        print("Running lowering for forward...")
        primx._lower_composite(blocks)
    return
@@ -571,13 +571,10 @@ class PartialProgramLayer:
            targets.append(program.global_block().var(out.name))

        if targets:
-            enable_prim = self._build_strategy.build_cinn_pass
-            if enable_prim and core.enable_prim_backward():
-                core.set_prim_enabled(True)
-                backward.gradients(targets=targets, inputs=[])
-                core.set_prim_enabled(False)
-            else:
-                backward.gradients(targets=targets, inputs=[])
+            if self._build_strategy.build_cinn_pass:
+                # TODO(Jiabin): Change this to True if we need this to be default option
+                core.check_and_set_prim_all_enabled()
+            backward.gradients(targets=targets, inputs=[])

        start_idx = len(main_program.block(0).ops) + 2 * len(
            self._outputs.tolist()
......
@@ -1092,8 +1092,9 @@ class ProgramCache:
    def _build_once(self, cache_key):
        # TODO(Aurelius84): Need a global FLAGS to enable/disable to_prim
        enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
-        if enable_prim and core.enable_prim_backward():
-            core.set_prim_enabled(True)
+        if enable_prim:
+            # TODO(Jiabin): Change this to True if we need this to be default option
+            core.check_and_set_prim_all_enabled()

        concrete_program = ConcreteProgram.from_func_spec(
            func_spec=cache_key.function_spec,
@@ -1103,9 +1104,7 @@ class ProgramCache:
            **cache_key.kwargs
        )

-        if enable_prim or core.enable_prim_forward() == "debug":
-            concrete_program._to_prim()
-        core.set_prim_enabled(False)
+        concrete_program._to_prim()

        return concrete_program, partial_program_from(concrete_program)

    def __getitem__(self, item):
......