Unverified commit 23d20e30, authored by Jiabin Yang, committed by GitHub

【Prim】Refactor prim flags system (#49930)

Parent 44855da3
......@@ -1841,7 +1841,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
if is_composite_grad_api and next_grad_node_creation_str != '':
next_grad_node_creation_str = f"""
if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
{next_grad_node_creation_str}
}}
"""
......@@ -2261,7 +2261,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif is_composite_grad_api:
grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << "Composite api {composite_grad_api_name} is called ";
}}else{{
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string.h>
#include <memory>
#include <sstream>
#include <string>
......@@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape,
phi::errors::InvalidArgument(
"We only support float32/float16 for full, but we got data type: %s",
phi::DataTypeToString(dtype)));
op->SetAttr("value", value.to<float>());
if (dtype == phi::DataType::FLOAT32) {
op->SetAttr("value", value.to<float>());
} else if (dtype == phi::DataType::FLOAT64) {
op->SetAttr("str_value", std::to_string(value.to<double>()));
} else if (dtype == phi::DataType::FLOAT16) {
op->SetAttr("str_value", std::to_string(value.to<float>()));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"We only support float64/float32/float16 for full"));
}
op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
......
......@@ -192,7 +192,7 @@ void divide_grad(const Tensor& x,
} // indicate we will compute dy
if (dx) {
// dx = (1/y) * dout
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad);
if (y.dims() != x.dims()) {
......
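Editor's note: the hunk above encodes the composite rule dx = (1 / y) * dout, now building the ones tensor with y's dtype. A minimal dygraph sketch, not part of the patch, that checks this rule against the regular divide_grad kernel using standard Paddle APIs:
import numpy as np
import paddle

# random positive inputs so the division is well conditioned
x = paddle.to_tensor(np.random.rand(2, 3) + 0.5, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(np.random.rand(2, 3) + 0.5, dtype='float32', stop_gradient=False)
out = paddle.divide(x, y)
dout = paddle.ones_like(out)
dx, dy = paddle.autograd.grad(out, [x, y], grad_outputs=dout)
# the kernel gradient should match the composite rule dx = (1 / y) * dout
np.testing.assert_allclose(dx.numpy(), (1.0 / y).numpy() * dout.numpy(), rtol=1e-6)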
......@@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) {
paddle::experimental::Tensor out0 = tanh_ad_func(tensor0);
std::vector<paddle::experimental::Tensor> outs0 = {out0};
// Disable prim
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
egr::Backward(outs0, {}, false);
paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
std::vector<paddle::experimental::Tensor> outs1 = {out1};
// Disable prim
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
::egr::Backward(outs1, {}, false);
VLOG(7)
......@@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) {
}
TEST(EagerPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}
} // namespace prim
......
......@@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
}
TEST(StaticPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}
} // namespace prim
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace prim {
StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
new StaticCompositeContext();
thread_local bool StaticCompositeContext::enable_prim_ = false;
thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
} // namespace prim
} // namespace paddle
......@@ -56,9 +56,18 @@ class StaticCompositeContext {
return generator_->Generate(key);
}
void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; }
void SetBwdPrimEnabled(bool enable_prim) { enable_bwd_prim_ = enable_prim; }
bool IsPrimEnabled() { return enable_prim_; }
bool IsBwdPrimEnabled() { return enable_bwd_prim_; }
void SetFwdPrimEnabled(bool enable_prim) { enable_fwd_prim_ = enable_prim; }
bool IsFwdPrimEnabled() { return enable_fwd_prim_; }
void SetAllPrimEnabled(bool enable_prim) {
enable_fwd_prim_ = enable_prim;
enable_bwd_prim_ = enable_prim;
}
private:
StaticCompositeContext()
......@@ -66,7 +75,8 @@ class StaticCompositeContext {
framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_;
static thread_local bool enable_prim_;
static thread_local bool enable_bwd_prim_;
static thread_local bool enable_fwd_prim_;
static StaticCompositeContext* static_composite_context_;
DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
};
......
......@@ -19,12 +19,24 @@
PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
namespace paddle {
namespace prim {
bool PrimCommonUtils::IsPrimEnabled() {
return StaticCompositeContext::Instance().IsPrimEnabled();
bool PrimCommonUtils::IsBwdPrimEnabled() {
return StaticCompositeContext::Instance().IsBwdPrimEnabled();
}
void PrimCommonUtils::SetPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim);
void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
}
bool PrimCommonUtils::IsFwdPrimEnabled() {
return StaticCompositeContext::Instance().IsFwdPrimEnabled();
}
void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
}
void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
}
} // namespace prim
} // namespace paddle
......@@ -18,8 +18,11 @@ namespace paddle {
namespace prim {
class PrimCommonUtils {
public:
static bool IsPrimEnabled();
static void SetPrimEnabled(bool enabled);
static bool IsBwdPrimEnabled();
static void SetBwdPrimEnabled(bool enabled);
static bool IsFwdPrimEnabled();
static void SetFwdPrimEnabled(bool enabled);
static void SetAllPrimEnabled(bool enabled);
};
} // namespace prim
} // namespace paddle
......@@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) {
return oss.str();
});
m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled);
m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled);
m.def("__set_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetBwdPrimEnabled);
m.def("_is_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsBwdPrimEnabled);
m.def("__set_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetFwdPrimEnabled);
m.def("_is_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
m.def("__set_all_prim_enabled",
&paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
m.def("set_num_threads", &platform::SetNumThreads);
m.def("disable_signal_handler", &DisableSignalHandler);
......@@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle.
// priority of GradCompOpMaker is less than GradCompMaker for better
// performance.
std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
if (grad_comp_op_maker != nullptr) {
VLOG(3) << "Runing composite fun for " << op_desc.Type();
grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set,
&grad_to_var,
......
......@@ -42,7 +42,7 @@
kernel :
func : add_grad
no_need_buffer : x, y
composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
composite : add_grad(x, y, out_grad, axis)
backward : add_double_grad
inplace : (out_grad -> x_grad)
......@@ -390,7 +390,7 @@
param : [x, y]
kernel :
func : divide_grad
composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1)
composite : divide_grad(x, y, out, out_grad, -1)
backward : divide_double_grad
- backward_op : dropout_grad
......@@ -1319,7 +1319,7 @@
kernel :
func : subtract_grad
no_need_buffer : x, y
composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
composite : subtract_grad(x, y, out_grad, axis)
backward : subtract_double_grad
inplace : (out_grad -> x_grad)
......
......@@ -1493,14 +1493,15 @@ def _append_backward_ops_(
# remove some backward ops
# TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
if not core.is_prim_enabled():
if not core._is_bwd_prim_enabled():
not_need_ops = _find_not_need_ops(
grad_op_descs, ops, input_grad_names_set
)
grad_op_descs = [
op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
]
else:
logging.debug("Runing backward composite and disable find_not_need_ops")
# append op_desc in grad_op_descs to target_block
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
......
......@@ -17,6 +17,7 @@ import sys
import os
import warnings
import platform
import logging
has_paddle_dy_lib = False
......@@ -305,8 +306,13 @@ try:
from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
from .libpaddle import _set_current_stream
from .libpaddle import _get_phi_kernel_name
from .libpaddle import set_prim_enabled
from .libpaddle import is_prim_enabled
# prim controller flags
from .libpaddle import __set_bwd_prim_enabled
from .libpaddle import _is_bwd_prim_enabled
from .libpaddle import __set_fwd_prim_enabled
from .libpaddle import _is_fwd_prim_enabled
from .libpaddle import __set_all_prim_enabled
if sys.platform != 'win32':
from .libpaddle import _set_process_pids
......@@ -373,36 +379,98 @@ def set_paddle_lib_path():
set_paddle_lib_path()
# We have 3 FLAGS to control whether prim is enabled
# FLAGS_prim_forward: enable or disable the forward prim strategy
# FLAGS_prim_backward: enable or disable the backward prim strategy
# FLAGS_prim_all: enable or disable all prim strategies
#
#
# Priorities:
# if with CINN and Dy2St:
# # # _set_prim_all_enabled > FLAGS_prim_all > check_and_set_prim_all_enabled == _set_prim_forward_enabled == _set_prim_backward_enabled > FLAGS_prim_forward == FLAGS_prim_backward
# else:
# # # _set_prim_all_enabled > FLAGS_prim_all == check_and_set_prim_all_enabled == _set_prim_forward_enabled == _set_prim_backward_enabled > FLAGS_prim_forward == FLAGS_prim_backward
def __sync_stat_with_flag(flag):
if flag is "FLAGS_prim_forward":
flag_value = os.getenv("FLAGS_prim_forward")
assert flag_value is not None
flag_value = flag_value.lower()
if flag_value == "false":
__set_fwd_prim_enabled(False)
elif flag_value == "true":
__set_fwd_prim_enabled(True)
else:
raise TypeError(f"flag {flag} should be true or false.")
logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
elif flag is "FLAGS_prim_backward":
flag_value = os.getenv("FLAGS_prim_backward")
assert flag_value is not None
flag_value = flag_value.lower()
if flag_value == "false":
__set_bwd_prim_enabled(False)
elif flag_value == "true":
__set_bwd_prim_enabled(True)
else:
raise TypeError(f"flag {flag} should be true or false.")
logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
elif flag is "FLAGS_prim_all":
flag_value = os.getenv("FLAGS_prim_all")
assert flag_value is not None
flag_value = flag_value.lower()
if flag_value == "false":
__set_all_prim_enabled(False)
elif flag_value == "true":
__set_all_prim_enabled(True)
else:
raise TypeError(f"flag {flag} should be true or false.")
logging.debug(
"all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
)
else:
raise TypeError(
f"We only support FLAGS_prim_forward/FLAGS_prim_backward/FLAGS_prim_all but we got {flag}."
)
def set_prim_forward(value):
"""set flag FLAGS_prim_forward."""
flag = str(value)
if flag.lower() not in ["true", "false", "debug"]:
raise TypeError(f"flag {flag} should be string of bool or 'debug'.")
os.environ["FLAGS_prim_forward"] = flag
return
def _set_prim_backward_enabled(value):
__set_bwd_prim_enabled(bool(value))
logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
def enable_prim_forward():
flag = os.getenv("FLAGS_prim_forward", "true").lower()
if flag == "false":
return False
if flag == "debug":
return "debug"
return True
def _set_prim_forward_enabled(value):
__set_fwd_prim_enabled(bool(value))
logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
def set_prim_backward(value):
"""set flag FLAGS_prim_backward,"""
flag = str(value)
if flag.lower() not in ["true", "false"]:
raise TypeError(f"flag {flag} should be bool or string of bool.")
os.environ["FLAGS_prim_backward"] = flag
return
def _set_prim_all_enabled(value):
__set_all_prim_enabled(bool(value))
logging.debug(
"all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
)
def enable_prim_backward():
flag = os.getenv("FLAGS_prim_backward", "true")
if flag.lower() == "false":
return False
return True
def __sync_prim_backward_status():
flag_value = os.getenv("FLAGS_prim_backward")
if flag_value is None:
logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
else:
__sync_stat_with_flag("FLAGS_prim_backward")
def __sync_prim_forward_status():
flag_value = os.getenv("FLAGS_prim_forward")
if flag_value is None:
logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
else:
__sync_stat_with_flag("FLAGS_prim_forward")
def check_and_set_prim_all_enabled():
flag_value = os.getenv("FLAGS_prim_all")
if flag_value is None:
__sync_prim_backward_status()
__sync_prim_forward_status()
else:
__sync_stat_with_flag("FLAGS_prim_all")
......@@ -19,6 +19,7 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
def generate_data(shape, dtype="float32"):
......@@ -72,6 +73,7 @@ class TestCompositeSoftmax(unittest.TestCase):
def cal_composite(self, inputs):
paddle.enable_static()
core._set_prim_forward_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
......@@ -95,6 +97,7 @@ class TestCompositeSoftmax(unittest.TestCase):
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
paddle.disable_static()
core._set_prim_forward_enabled(False)
return res
def compare_forward(self):
......
......@@ -78,6 +78,7 @@ class TestCompositeSoftmax(unittest.TestCase):
def cal_composite_grad(self, inputs):
paddle.enable_static()
core._set_prim_all_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
......@@ -108,6 +109,7 @@ class TestCompositeSoftmax(unittest.TestCase):
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static()
core._set_prim_all_enabled(False)
return res
def compare_backward(self):
......@@ -139,7 +141,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
"test composite softmax and prim backward"
def setUp(self):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
self.dtypes = ["float32"]
self.shapes = [[2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]
......
......@@ -236,11 +236,11 @@ class TestBert(unittest.TestCase):
self.verify_predict()
def test_train_composite(self):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
dygraph_loss, dygraph_ppl = self.train_dygraph(
self.bert_config, self.data_reader
)
......
......@@ -47,7 +47,6 @@ class TestPrimForward(unittest.TestCase):
"""
def setUp(self):
core.set_prim_backward(False)
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
......@@ -58,6 +57,7 @@ class TestPrimForward(unittest.TestCase):
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)
core._set_prim_forward_enabled(use_prim)
if use_prim:
net = apply_to_static(net, use_prim)
......@@ -103,12 +103,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.x.stop_gradient = False
def train(self, use_prim):
core.set_prim_backward(True)
paddle.seed(2022)
net = PrimeNet()
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)
core._set_prim_all_enabled(use_prim)
if use_prim:
net = apply_to_static(net, use_prim)
......
......@@ -427,10 +427,10 @@ class TestResnet(unittest.TestCase):
)
self.verify_predict()
def test_resnet_composite(self):
core.set_prim_enabled(True)
def test_resnet_composite_backward(self):
core._set_prim_backward_enabled(True)
static_loss = self.train(to_static=True)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
dygraph_loss = self.train(to_static=True)
np.testing.assert_allclose(
static_loss,
......@@ -440,65 +440,13 @@ class TestResnet(unittest.TestCase):
static_loss, dygraph_loss
),
)
core.set_prim_enabled(False)
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
if paddle.fluid.core.is_compiled_with_mkldnn():
self.resnet_helper.train(to_static=True)
finally:
fluid.set_flags({'FLAGS_use_mkldnn': False})
class TestResnetPrim(unittest.TestCase):
"test prim forward + prim backward + to_static"
def setUp(self):
self.resnet_helper = ResNetHelper()
def train(self, to_static):
paddle.jit.enable_to_static(to_static)
return self.resnet_helper.train(to_static)
def verify_predict(self):
image = np.random.random([1, 3, 224, 224]).astype('float32')
dy_pre = self.resnet_helper.predict_dygraph(image)
st_pre = self.resnet_helper.predict_static(image)
dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
predictor_pre = self.resnet_helper.predict_analysis_inference(image)
np.testing.assert_allclose(
dy_pre,
st_pre,
rtol=1e-05,
err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre),
)
np.testing.assert_allclose(
dy_jit_pre,
st_pre,
rtol=1e-05,
err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format(
dy_jit_pre, st_pre
),
)
np.testing.assert_allclose(
predictor_pre,
st_pre,
rtol=1e-05,
err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format(
predictor_pre, st_pre
),
)
def test_resnet_composite(self):
def test_resnet_composite_forward_backward(self):
plat = platform.system()
if plat == "Linux":
print("=================== origin resnet ===================")
core.set_prim_enabled(False)
core._set_prim_all_enabled(True)
static_loss = self.train(to_static=True)
print("======= resnet with prim forward and backward =======")
core.set_prim_enabled(True)
core.set_prim_forward("debug")
core._set_prim_all_enabled(False)
dygraph_loss = self.train(to_static=True)
np.testing.assert_allclose(
static_loss,
......@@ -508,10 +456,17 @@ class TestResnetPrim(unittest.TestCase):
static_loss, dygraph_loss
),
)
core.set_prim_enabled(False)
else:
pass
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
if paddle.fluid.core.is_compiled_with_mkldnn():
self.resnet_helper.train(to_static=True)
finally:
fluid.set_flags({'FLAGS_use_mkldnn': False})
if __name__ == '__main__':
unittest.main()
......@@ -130,9 +130,9 @@ class TestResnet(unittest.TestCase):
)
def test_resnet_composite(self):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
static_loss = self.train(to_static=True)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
dygraph_loss = self.train(to_static=False)
np.testing.assert_allclose(
static_loss,
......
......@@ -137,9 +137,9 @@ class TestResnet(unittest.TestCase):
def test_resnet_composite(self):
if fluid.is_compiled_with_cuda():
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
static_loss = self.train(to_static=True)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
dygraph_loss = self.train(to_static=False)
# NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here.
np.testing.assert_allclose(
......
......@@ -426,9 +426,9 @@ class TestResnet(unittest.TestCase):
self.verify_predict()
def test_resnet_composite(self):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
static_loss = self.train(to_static=True)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
dygraph_loss = self.train(to_static=False)
np.testing.assert_allclose(
static_loss,
......
......@@ -9,3 +9,4 @@ foreach(TEST_OP ${TEST_OPS})
endforeach()
add_subdirectory(vjp)
add_subdirectory(flags)
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from paddle.fluid import core
class TestPrimFlags(unittest.TestCase):
def test_prim_flags(self):
self.assertFalse(core._is_bwd_prim_enabled())
self.assertFalse(core._is_fwd_prim_enabled())
os.environ['FLAGS_prim_backward'] = "True"
core.check_and_set_prim_all_enabled()
self.assertTrue(core._is_bwd_prim_enabled())
os.environ['FLAGS_prim_forward'] = "True"
core.check_and_set_prim_all_enabled()
self.assertTrue(core._is_fwd_prim_enabled())
os.environ['FLAGS_prim_all'] = "False"
core.check_and_set_prim_all_enabled()
self.assertFalse(core._is_bwd_prim_enabled())
self.assertFalse(core._is_fwd_prim_enabled())
os.environ['FLAGS_prim_all'] = "True"
core.check_and_set_prim_all_enabled()
self.assertTrue(core._is_bwd_prim_enabled())
self.assertTrue(core._is_fwd_prim_enabled())
del os.environ['FLAGS_prim_all']
os.environ['FLAGS_prim_backward'] = "False"
core.check_and_set_prim_all_enabled()
self.assertFalse(core._is_bwd_prim_enabled())
os.environ['FLAGS_prim_forward'] = "False"
core.check_and_set_prim_all_enabled()
self.assertFalse(core._is_fwd_prim_enabled())
if __name__ == '__main__':
unittest.main()
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
@param.parameterized_class(
......@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
def test_tanh_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
@param.parameterized_class(
......@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
def test_tanh_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -32,14 +32,14 @@ from paddle.fluid import core
class TestExpGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
cls.primal = cls.primal.astype(cls.dtype)
if cls.cotangent is not None:
cls.cotangent = cls.cotangent.astype(cls.dtype)
@classmethod
def tearDownClass(cls):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
def test_exp_grad_comp(self):
def actual(primal, cotangent):
......
......@@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase):
@classmethod
def tearDownClass(cls):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
def test_comp(self):
def func(primal, cotangent, shape):
......@@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase):
]
def actual(primal, cotangent, shape):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
return func(primal, cotangent, shape)
def desired(primal, cotangent, shape):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
return func(primal, cotangent, shape)
np.testing.assert_allclose(
......
......@@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase):
return [g for g in grads if g is not None]
def test_comp(self):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
actual = self.vjp()
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
desired = self.vjp()
for i, j in zip(actual, desired):
......
......@@ -22,7 +22,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
@param.parameterized_class(
......@@ -63,7 +63,7 @@ class TestSqrtGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
@param.parameterized_class(
......@@ -67,7 +67,7 @@ class TestTanhGradComp(unittest.TestCase):
def test_tanh_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -78,7 +78,7 @@ class TestTanhGradComp(unittest.TestCase):
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
......@@ -104,7 +104,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -21,7 +21,7 @@ from paddle.fluid import core
def actual(primal, cotangent, axis, keep_dim):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
y = paddle.sum(x, axis=axis, keepdim=keep_dim)
......@@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim):
def desired(primal, cotangent, axis, keep_dim):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False)
v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False)
y = paddle.sum(x, axis=axis, keepdim=keep_dim)
......
......@@ -20,7 +20,7 @@ import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
@param.parameterized_class(
......@@ -74,7 +74,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -81,7 +81,7 @@ class TestAddGradComp(unittest.TestCase):
self.x.stop_gradient = False
self.y.stop_gradient = False
net = PrimeNet()
core.set_prim_enabled(use_prim)
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x, self.y)
res = paddle.autograd.grad(out, [self.x, self.y])
......@@ -104,7 +104,7 @@ class TestAddGradComp(unittest.TestCase):
def test_tanh_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
......@@ -126,7 +126,7 @@ class TestAddGradComp(unittest.TestCase):
return out[0], out[1]
def desired(primal0, primal1):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data(
......@@ -167,7 +167,7 @@ class TestAddGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase):
self.x.stop_gradient = False
self.y.stop_gradient = False
net = PrimeNet()
core.set_prim_enabled(use_prim)
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x, self.y)
res = paddle.autograd.grad(out, [self.x, self.y])
......@@ -107,7 +107,7 @@ class TestDivGradComp(unittest.TestCase):
paddle.enable_static()
def actual(primal0, primal1):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
......@@ -130,7 +130,7 @@ class TestDivGradComp(unittest.TestCase):
return out[0], out[1]
def desired(primal0, primal1):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data(
......@@ -172,7 +172,7 @@ class TestDivGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
paddle.disable_static()
......
......@@ -81,7 +81,7 @@ class TestDivGradComp(unittest.TestCase):
self.x.stop_gradient = False
self.y.stop_gradient = False
net = PrimeNet()
core.set_prim_enabled(use_prim)
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x, self.y)
res = paddle.autograd.grad(out, [self.x, self.y])
......@@ -104,7 +104,7 @@ class TestDivGradComp(unittest.TestCase):
def test_tanh_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
......@@ -126,7 +126,7 @@ class TestDivGradComp(unittest.TestCase):
return out[0], out[1]
def desired(primal0, primal1):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data(
......@@ -167,7 +167,7 @@ class TestDivGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -33,14 +33,14 @@ from paddle.fluid import core
class TestExpGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
cls.primal = cls.primal.astype(cls.dtype)
if cls.cotangent is not None:
cls.cotangent = cls.cotangent.astype(cls.dtype)
@classmethod
def tearDownClass(cls):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
def setUp(self):
paddle.enable_static()
......
......@@ -71,7 +71,7 @@ class TestExpandGradComp(unittest.TestCase):
@classmethod
def tearDownClass(cls):
paddle.disable_static()
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
def test_comp(self):
def func(primal, cotangent, shape):
......@@ -93,11 +93,11 @@ class TestExpandGradComp(unittest.TestCase):
)[0]
def actual(primal, cotangent, shape):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
return func(primal, cotangent, shape)
def desired(primal, cotangent, shape):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
return func(primal, cotangent, shape)
np.testing.assert_allclose(
......
......@@ -108,10 +108,10 @@ class TestMultiplyGradComp(unittest.TestCase):
def test_comp(self):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
actual = self.vjp()
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
desired = self.vjp()
self.assertEqual(len(actual), len(desired))
......
......@@ -16,7 +16,7 @@ import unittest
from paddle.fluid import core
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
import autograd
import autograd.numpy
......@@ -60,7 +60,7 @@ class TestSqrtGradComp(unittest.TestCase):
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
net = PrimeNet()
core.set_prim_enabled(use_prim)
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x)
res = paddle.autograd.grad(out, [self.x])
......@@ -109,7 +109,7 @@ class TestSqrtGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -82,7 +82,7 @@ class TestDivGradComp(unittest.TestCase):
self.x.stop_gradient = False
self.y.stop_gradient = False
net = PrimeNet()
core.set_prim_enabled(use_prim)
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x, self.y)
res = paddle.autograd.grad(out, [self.x, self.y])
......@@ -105,7 +105,7 @@ class TestDivGradComp(unittest.TestCase):
def test_tanh_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
......@@ -127,7 +127,7 @@ class TestDivGradComp(unittest.TestCase):
return out[0], out[1]
def desired(primal0, primal1):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data(
......@@ -168,7 +168,7 @@ class TestDivGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -21,7 +21,7 @@ from paddle.fluid import core
def actual(primal, cotangent, axis, keep_dim):
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal', primal.shape, primal.dtype)
......@@ -40,7 +40,7 @@ def actual(primal, cotangent, axis, keep_dim):
def desired(primal, cotangent, axis, keep_dim):
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal', primal.shape, primal.dtype)
......
......@@ -16,7 +16,7 @@ import unittest
from paddle.fluid import core
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
import autograd
import autograd.numpy
......@@ -60,7 +60,7 @@ class TestTanhGradComp(unittest.TestCase):
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
net = PrimeNet()
core.set_prim_enabled(use_prim)
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x)
res = paddle.autograd.grad(out, [self.x])
......@@ -109,7 +109,7 @@ class TestTanhGradComp(unittest.TestCase):
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -17,7 +17,7 @@ import unittest
from paddle.fluid import core
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
import parameterized as param
......
......@@ -17,7 +17,7 @@ import unittest
from paddle.fluid import core
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
import parameterized as param
......@@ -77,7 +77,7 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
)
print(actual)
self.assertEquals(actual, self.desired_ops)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
......
......@@ -135,9 +135,9 @@ class TestResnet50Accuracy(unittest.TestCase):
loop_num = 10
feed = self.generate_random_data(loop_num)
core.set_prim_enabled(True)
core._set_prim_backward_enabled(True)
loss_c = self.train(place, loop_num, feed, use_cinn=True)
core.set_prim_enabled(False)
core._set_prim_backward_enabled(False)
loss_p = self.train(place, loop_num, feed, use_cinn=True)
print("Losses of Composite + CINN:")
print(loss_c)
......
......@@ -218,7 +218,7 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only
def to_prim(blocks):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
if not core.enable_prim_forward():
if not core._is_fwd_prim_enabled():
return
if isinstance(blocks, paddle.fluid.framework.Block):
logging.info("Atomize composite op to primitive ops begin.")
......@@ -235,5 +235,6 @@ def to_prim(blocks):
f"Expect block or sequence of blocks, but got {type(blocks)}."
)
with framework.program_guard(main_program):
print("Running lowering for forward...")
primx._lower_composite(blocks)
return
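Editor's note: to_prim() returns early unless the forward prim switch is on (see the check above). A hedged static-graph sketch, assuming to_prim is importable from paddle.incubate.autograd.primapi (the module path is not shown in this diff):
import paddle
from paddle.fluid import core
from paddle.incubate.autograd.primapi import to_prim   # assumed import path

paddle.enable_static()
core._set_prim_forward_enabled(True)
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', [2, 3], 'float32')
    y = paddle.nn.functional.softmax(x)
to_prim(main.blocks)   # lowers ops that have registered composite rules into primitive ops
core._set_prim_forward_enabled(False)
paddle.disable_static()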
......@@ -571,13 +571,10 @@ class PartialProgramLayer:
targets.append(program.global_block().var(out.name))
if targets:
enable_prim = self._build_strategy.build_cinn_pass
if enable_prim and core.enable_prim_backward():
core.set_prim_enabled(True)
backward.gradients(targets=targets, inputs=[])
core.set_prim_enabled(False)
else:
backward.gradients(targets=targets, inputs=[])
if self._build_strategy.build_cinn_pass:
# TODO(Jiabin): Change this to True if we need this to be default option
core.check_and_set_prim_all_enabled()
backward.gradients(targets=targets, inputs=[])
start_idx = len(main_program.block(0).ops) + 2 * len(
self._outputs.tolist()
......
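Editor's note: after this change the dy2st path no longer toggles the prim switch itself; when build_cinn_pass is set it defers to check_and_set_prim_all_enabled(), so the FLAGS_prim_* environment variables decide. A hedged sketch of the user-facing flow (requires a CINN-enabled build; BuildStrategy and paddle.jit.to_static are existing APIs, mirroring the apply_to_static helpers used in the tests above):
import paddle

build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True   # the condition checked by PartialProgramLayer above
net = paddle.nn.Linear(4, 4)
static_net = paddle.jit.to_static(net, build_strategy=build_strategy)
# With FLAGS_prim_all=true (or FLAGS_prim_forward / FLAGS_prim_backward) set in the
# environment, check_and_set_prim_all_enabled() switches prim on before the backward
# program is built for this static graph.
out = static_net(paddle.randn([2, 4]))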
......@@ -1092,8 +1092,9 @@ class ProgramCache:
def _build_once(self, cache_key):
# TODO(Aurelius84): Need a global FLAGS to enable/disable to_prim
enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
if enable_prim and core.enable_prim_backward():
core.set_prim_enabled(True)
if enable_prim:
# TODO(Jiabin): Change this to True if we need this to be default option
core.check_and_set_prim_all_enabled()
concrete_program = ConcreteProgram.from_func_spec(
func_spec=cache_key.function_spec,
......@@ -1103,9 +1104,7 @@ class ProgramCache:
**cache_key.kwargs
)
if enable_prim or core.enable_prim_forward() == "debug":
concrete_program._to_prim()
core.set_prim_enabled(False)
concrete_program._to_prim()
return concrete_program, partial_program_from(concrete_program)
def __getitem__(self, item):
......