Unverified commit fca0a4bf, authored by Jiabin Yang, committed by GitHub

【Prim】Support fuse jit save (#52344)

* fix_prim

* fix bug

* add note

* fix logic

* fix

* add note

* fix check

* fix bug

* fix bug

* fix bug

* add debug

* fix check

* fix bug

* sync print log

* fix test case

* change default

* support jit save with fuse

* add more check

* sync with pr 52120

* add more ut

---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Parent abf3701d
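A minimal end-to-end sketch of what this change enables, distilled from the new test cases below (the layer definition, `SimpleNet`, and the save path are illustrative, not part of the diff):

```python
import paddle
from paddle.fluid import core


class SimpleNet(paddle.nn.Layer):  # illustrative stand-in for the PrimeNet defined in the tests below
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(2, 4, (3, 3), bias_attr=False)
        self.bn = paddle.nn.BatchNorm(4, act="relu")

    def forward(self, x):
        return self.bn(self.conv(x))


# Enable composite (prim) lowering for forward and backward; jit.save now
# detects this via core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled().
core._set_prim_all_enabled(True)

net = SimpleNet()
net.eval()
static_net = paddle.jit.to_static(net)
x = paddle.randn([4, 2, 6, 6], dtype="float32")
out = static_net(x)

# With this change, jit.save rebuilds the concrete program instead of reusing
# the cached, already-decomposed one, so the exported inference program keeps
# fused ops such as batch_norm and relu (see the assertions in the tests below).
paddle.jit.save(static_net, "./prime_net")  # "./prime_net" is an illustrative path
loaded = paddle.jit.load("./prime_net")
new_out = loaded(x)
```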
......@@ -908,6 +908,7 @@ def save(layer, path, input_spec=None, **configs):
# 1. input build & check
prog_translator = ProgramTranslator()
is_prim_infer = core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled()
if not prog_translator.enable_to_static:
raise RuntimeError(
"The paddle.jit.save doesn't work when setting 'paddle.jit.enable_to_static' to False."
......@@ -1021,7 +1022,9 @@ def save(layer, path, input_spec=None, **configs):
concrete_program = (
static_func.concrete_program_specify_input_spec(
inner_input_spec, with_hook=with_hook
inner_input_spec,
with_hook=with_hook,
is_prim_infer=is_prim_infer,
)
)
elif 'forward' == attr_func:
......@@ -1041,7 +1044,7 @@ def save(layer, path, input_spec=None, **configs):
)
concrete_program = (
static_forward.concrete_program_specify_input_spec(
with_hook=with_hook
with_hook=with_hook, is_prim_infer=is_prim_infer
)
)
# the input_spec has been used in declarative, which is equal to
......@@ -1061,7 +1064,7 @@ def save(layer, path, input_spec=None, **configs):
concrete_program = (
attr_func.concrete_program_specify_input_spec(
inner_input_spec
inner_input_spec, is_prim_infer=is_prim_infer
)
)
else:
......
......@@ -550,10 +550,13 @@ class StaticFunction:
with_hook = kwargs.get("with_hook", False)
is_train = kwargs.get("is_train", True)
is_prim_infer = kwargs.get("is_prim_infer", False)
if "is_train" in kwargs:
kwargs.pop("is_train")
if "with_hook" in kwargs:
kwargs.pop("with_hook")
if "is_prim_infer" in kwargs:
kwargs.pop("is_prim_infer")
# 1. unify args/kwargs and replace Tensor with InputSpec
if len(args) != len(self._function_spec.args_name):
args, kwargs = self._function_spec.unified_args_and_kwargs(
......@@ -574,9 +577,33 @@ class StaticFunction:
with_hook=with_hook,
is_train=is_train,
)
if is_prim_infer:
(
concrete_program,
partial_program_layer,
) = self._program_cache.get_program_without_cache(cache_key)
else:
# 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[
cache_key
]
return concrete_program, partial_program_layer
# 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key]
def get_concrete_program_with_cache_key(self, cached_key):
"""
Returns traced concrete program and inner executable partial layer by cached key.
Args:
cached_key(CacheKey): The cache key used to get the concrete program.
Returns:
Traced ConcreteProgram and executable translated Layer.
"""
self._raise_when_property()
(
concrete_program,
partial_program_layer,
) = self._program_cache.get_program_without_cache(cached_key)
return concrete_program, partial_program_layer
def get_traced_count(self):
......@@ -634,7 +661,7 @@ class StaticFunction:
return self.concrete_program_specify_input_spec(input_spec=None)
def concrete_program_specify_input_spec(
self, input_spec=None, with_hook=False
self, input_spec=None, with_hook=False, is_prim_infer=False
):
"""
Returns recent ConcreteProgram instance of decorated function while
......@@ -652,6 +679,8 @@ class StaticFunction:
# else, return the last one.
cached_program_len = len(self._program_cache)
# If `input_spec` is specified, apply the conversion from dygraph layers into a static Program.
# NOTE(jiabin): is_prim_infer indicates that this method is called by paddle.jit.save while prim mode is enabled.
if cached_program_len == 0:
desired_input_spec = input_spec
if self._function_spec.input_spec is not None:
......@@ -679,6 +708,7 @@ class StaticFunction:
*desired_input_spec,
with_hook=with_hook,
is_train=self._is_train_mode(),
is_prim_infer=is_prim_infer,
)
return concrete_program
else:
......@@ -690,9 +720,14 @@ class StaticFunction:
elif with_hook:
cache_key = self._program_cache._recent_cache_key
cache_key.kwargs["with_hook"] = True
concrete_program, _ = self._program_cache[cache_key]
return concrete_program
if not is_prim_infer:
concrete_program, _ = self._program_cache[cache_key]
return concrete_program
else:
concrete_program, _ = self.get_concrete_program_with_cache_key(
cache_key
)
return concrete_program
# If more than one program has been cached, return the most recently converted program by default.
elif cached_program_len > 1:
logging_utils.warn(
......@@ -700,12 +735,18 @@ class StaticFunction:
self._function_spec, cached_program_len
)
)
cache_key, (
concrete_program,
partial_layer,
) = self._program_cache.last()
return concrete_program
if not is_prim_infer:
cache_key, (
concrete_program,
partial_layer,
) = self._program_cache.last()
return concrete_program
else:
cache_key = self._program_cache._recent_cache_key
concrete_program, _ = self.get_concrete_program_with_cache_key(
cache_key
)
return concrete_program
def rollback(self):
"""
......@@ -1214,6 +1255,9 @@ class ProgramCache:
return self._caches[item_id]
def get_program_without_cache(self, cache_key):
return self._build_once(cache_key=cache_key)
def get_program(self, item):
if not isinstance(item, CacheKey):
raise ValueError(
......
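To summarize the changes to `StaticFunction.get_concrete_program` and `ProgramCache` above, here is a simplified paraphrase of the new dispatch (a hedged sketch; `_get_program` is a hypothetical helper, not the actual Paddle code):

```python
# Hedged sketch: paraphrases the is_prim_infer branch added above.
def _get_program(program_cache, cache_key, is_prim_infer):
    if is_prim_infer:
        # jit.save with prim enabled: build a fresh program via _build_once so
        # the saved graph is not the cached, already-decomposed training program.
        return program_cache.get_program_without_cache(cache_key)
    # Normal path: hit the cache, or build and cache a new program.
    return program_cache[cache_key]
```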
......@@ -20,8 +20,11 @@ import numpy as np
from test_fetch_feed import Linear
import paddle
from paddle import fluid
import paddle.nn.functional as F
from paddle import fluid, nn
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.nn import BatchNorm
np.random.seed(2020)
......@@ -30,6 +33,29 @@ place = (
)
class PrimeNet(paddle.nn.Layer):
def __init__(self, data_layout='NCHW'):
super().__init__()
self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
self.bn = BatchNorm(4, act="relu", data_layout=data_layout)
def forward(self, x):
y = self.conv(x)
out = self.bn(y)
res = F.max_pool2d(out, kernel_size=2, stride=2, padding=0)
return res
def apply_to_static(net):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = False
return paddle.jit.to_static(net, build_strategy=build_strategy)
def forward_post_hook_for_prim_net(layer, input, output):
return output * 2
class TestDyToStaticSaveLoad(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
......@@ -89,6 +115,87 @@ class TestDyToStaticSaveLoad(unittest.TestCase):
dygraph_loss.numpy(), static_loss.numpy(), rtol=1e-05
)
def test_save_load_prim(self):
with fluid.dygraph.guard(place):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
self.x.stop_gradient = False
net = PrimeNet(data_layout="NCHW")
core._set_prim_all_enabled(True)
net.eval()
static_net = apply_to_static(net)
res = static_net(self.x)
composite_program = static_net.forward.get_concrete_program(self.x)[
1
].train_program
comp_op_type_list = [
op.type for op in composite_program.block(0).ops
]
self.assertNotIn("batch_norm", comp_op_type_list)
self.assertNotIn("relu", comp_op_type_list)
self.assertNotIn("pow", comp_op_type_list)
self.assertNotIn("expand_v2", comp_op_type_list)
self.assertNotIn("unsqueeze2", comp_op_type_list)
self.assertNotIn("reduce_mean", comp_op_type_list)
self.assertNotIn("batch_norm_grad", comp_op_type_list)
self.assertNotIn("relu_grad", comp_op_type_list)
self.assertNotIn("pow_grad", comp_op_type_list)
self.assertNotIn("expand_v2_grad", comp_op_type_list)
self.assertNotIn("unsqueeze2_grad", comp_op_type_list)
self.assertNotIn("reduce_mean_grad", comp_op_type_list)
paddle.jit.save(static_net, self.model_path)
load_func = paddle.jit.load(self.model_path)
load_program = load_func.program()
print("load_program:", load_program)
load_op_type_list = [op.type for op in load_program.block(0).ops]
new_res = load_func(self.x)
self.assertIn("conv2d", load_op_type_list)
self.assertIn("batch_norm", load_op_type_list)
self.assertIn("relu", load_op_type_list)
self.assertIn("pool2d", load_op_type_list)
np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05)
def test_save_load_prim_with_hook(self):
with fluid.dygraph.guard(place):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
self.x.stop_gradient = False
net = PrimeNet(data_layout="NCHW")
net.register_forward_post_hook(forward_post_hook_for_prim_net)
core._set_prim_all_enabled(True)
net.eval()
static_net = apply_to_static(net)
res = static_net(self.x)
composite_program = static_net.forward.get_concrete_program(self.x)[
1
].train_program
comp_op_type_list = [
op.type for op in composite_program.block(0).ops
]
self.assertNotIn("batch_norm", comp_op_type_list)
self.assertNotIn("relu", comp_op_type_list)
self.assertNotIn("pow", comp_op_type_list)
self.assertNotIn("expand_v2", comp_op_type_list)
self.assertNotIn("unsqueeze2", comp_op_type_list)
self.assertNotIn("reduce_mean", comp_op_type_list)
self.assertNotIn("batch_norm_grad", comp_op_type_list)
self.assertNotIn("relu_grad", comp_op_type_list)
self.assertNotIn("pow_grad", comp_op_type_list)
self.assertNotIn("expand_v2_grad", comp_op_type_list)
self.assertNotIn("unsqueeze2_grad", comp_op_type_list)
self.assertNotIn("reduce_mean_grad", comp_op_type_list)
self.assertNotIn("multiply_grad", comp_op_type_list)
paddle.jit.save(static_net, self.model_path)
load_func = paddle.jit.load(self.model_path)
load_program = load_func.program()
print("load_program:", load_program)
load_op_type_list = [op.type for op in load_program.block(0).ops]
new_res = load_func(self.x)
self.assertIn("conv2d", load_op_type_list)
self.assertIn("batch_norm", load_op_type_list)
self.assertIn("relu", load_op_type_list)
self.assertIn("pool2d", load_op_type_list)
np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05)
if __name__ == '__main__':
unittest.main()