Unverified commit 8e6d5d2b authored by Aurelius84, committed by GitHub

[Dy2Stat]Consider InputSpec.name to calculate Cachekey hash id (#38273)

* Consider InputSpec.name to calculate Cachekey hash id

* fix function
Parent a858326a
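For context, here is a minimal sketch of the behavior this commit targets (the `TwoInputNet` layer and tensor names below are hypothetical, not part of the change): two calls whose arguments share shapes and dtypes but differ in how many distinct inputs they pass should be traced into separate programs rather than sharing one cache entry.

import paddle

class TwoInputNet(paddle.nn.Layer):
    # Hypothetical layer used only to illustrate the caching behavior.
    def __init__(self):
        super(TwoInputNet, self).__init__()
        self.fc = paddle.nn.Linear(8, 4)

    def forward(self, a, b):
        return self.fc(a) + self.fc(b)

net = paddle.jit.to_static(TwoInputNet())
x = paddle.randn([2, 8])
y = paddle.randn([2, 8])

net(x, y)  # two distinct inputs -> a program with two feed variables
net(x, x)  # same shapes/dtypes but a single distinct input -> a second program should be traced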
@@ -394,3 +394,34 @@ def _set_spec_stop_gradient(spec, stop_gradient):
"""
assert isinstance(spec, paddle.static.InputSpec)
spec.stop_gradient = stop_gradient
def _hash_spec_names(args_specs, kwargs_specs):
"""
Generate a hash from the names of the args/kwargs InputSpecs.
Consider the following InputSpecs, which share shape/dtype and differ only in name:
1. [InputSpec([3,3], 'float32', 'x'), InputSpec([3,3], 'float32', 'x')]
2. [InputSpec([3,3], 'float32', 'x'), InputSpec([3,3], 'float32', 'y')]
Under @to_static, we should generate two different programs, not just one, because
the former has a single input ('x'), while the latter has two inputs ('x' and 'y').
"""
spec_names = [
spec.name for spec in flatten(args_specs)
if isinstance(spec, paddle.static.InputSpec)
]
spec_names += [
spec.name for spec in flatten(kwargs_specs)
if isinstance(spec, paddle.static.InputSpec)
]
i, name_ids = 0, {}
def to_idx(name):
nonlocal i
if name not in name_ids:
name_ids[name] = i
i += 1
return name_ids[name]
value = [to_idx(name) for name in spec_names]
return tuple(value)
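A small usage sketch of the helper above (the import path matches this diff; the return values follow directly from the first-seen index mapping it builds):

from paddle.static import InputSpec
from paddle.fluid.dygraph.dygraph_to_static.function_spec import _hash_spec_names

# Duplicated name 'x': both specs map to index 0.
print(_hash_spec_names([InputSpec([3, 3], 'float32', 'x'),
                        InputSpec([3, 3], 'float32', 'x')], {}))  # (0, 0)

# Distinct names 'x' and 'y': indices 0 and 1, so the resulting CacheKey hash differs.
print(_hash_spec_names([InputSpec([3, 3], 'float32', 'x'),
                        InputSpec([3, 3], 'float32', 'y')], {}))  # (0, 1)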
@@ -43,7 +43,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import input_specs_compatible
from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap
from paddle.fluid.dygraph.dygraph_to_static.utils import make_hashable
from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec
from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec, _hash_spec_names
from paddle.fluid.dygraph.dygraph_to_static.function_spec import get_buffers, get_parameters
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
@@ -147,7 +147,7 @@ class CacheKey(object):
"""
__slots__ = [
'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec',
'class_instance', 'kwargs'
'class_instance', 'kwargs', '_spec_names_id'
]
def __init__(self, function_spec, input_args_with_spec,
@@ -168,6 +168,8 @@ class CacheKey(object):
self.class_instance = class_instance
# NOTE: `kwargs` is usually not considered a basic member for `__hash__`
self.kwargs = kwargs
self._spec_names_id = _hash_spec_names(input_args_with_spec,
input_kwargs_with_spec)
@classmethod
def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
@@ -197,7 +199,7 @@ class CacheKey(object):
return hash((id(self.function_spec),
make_hashable(self.input_args_with_spec, error_msg),
make_hashable(self.input_kwargs_with_spec, error_msg),
self.class_instance))
self._spec_names_id, self.class_instance))
def __eq__(self, other):
return (type(self) is type(other)) and hash(self) == hash(other)
@@ -703,6 +705,7 @@ class ProgramCache(object):
"""
def __init__(self):
# {hash_id : (concrete_program, partial_layer)}
self._caches = collections.OrderedDict()
def _build_once(self, cache_key):
@@ -718,9 +721,9 @@ class ProgramCache(object):
if not isinstance(item, CacheKey):
raise ValueError('type(item) should be CacheKey, but received %s' %
type_name(item))
if item not in self._caches:
self._caches[item] = self._build_once(item)
item_id = hash(item)
if item_id not in self._caches:
self._caches[item_id] = self._build_once(item)
# Note: raise a warning if the number of traced programs exceeds `max_tracing_count`
current_tracing_count = len(self._caches)
if current_tracing_count > MAX_TRACED_PROGRAM_COUNT:
@@ -729,18 +732,19 @@ class ProgramCache(object):
"The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors.".
format(current_tracing_count, MAX_TRACED_PROGRAM_COUNT))
return self._caches[item]
return self._caches[item_id]
def get_program(self, item):
if not isinstance(item, CacheKey):
raise ValueError(
"Input item's type should be FunctionSpec, but received %s" %
type_name(item))
if item not in self._caches:
item_id = hash(item)
if item_id not in self._caches:
raise RuntimeError(
"Failed to find program for input item, please decorate input function by `@paddle.jit.to_static`."
)
return self._caches[item]
return self._caches[item_id]
def last(self):
assert len(
......
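The `__getitem__` and `get_program` changes above key `self._caches` by `hash(item)` instead of by the `CacheKey` object itself, so entries are shared exactly when the hashes (which now include `_spec_names_id`) match. A minimal self-contained sketch of that keying scheme (the `_ToyKey` class is a hypothetical stand-in, independent of Paddle):

import collections

class _ToyKey(object):
    # Hypothetical stand-in for CacheKey: the hash is driven by the spec-name pattern.
    def __init__(self, spec_names_id):
        self._spec_names_id = spec_names_id

    def __hash__(self):
        return hash(self._spec_names_id)

caches = collections.OrderedDict()
for key in (_ToyKey((0, 1)), _ToyKey((0, 0)), _ToyKey((0, 1))):
    item_id = hash(key)
    if item_id not in caches:
        caches[item_id] = object()  # stands in for _build_once(key)

print(len(caches))  # 2 -- the repeated (0, 1) name pattern reuses its cached entry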
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.nn import Layer
import numpy as np
import unittest
class Net(Layer):
def __init__(self):
super(Net, self).__init__()
self.fc = paddle.nn.Linear(16, 3)
def forward(self, x, y, m, n):
inputs = [x, y, m, n]
outs = []
for var in inputs:
out = paddle.reshape(var, [-1, 16])  # flatten each input to [N, 16] before the shared fc layer
out = self.fc(out)
outs.append(out)
out = paddle.stack(outs)
return paddle.sum(out)
class TestArgsSpecName(unittest.TestCase):
def read_from_dataset(self):
self.x = paddle.randn([4, 2, 8])
self.y = paddle.randn([4, 2, 8])
self.m = paddle.randn([4, 2, 8])
self.n = paddle.randn([4, 2, 8])
def test_spec_name_hash(self):
net = Net()
net = paddle.jit.to_static(net)
# Convert into a program with four inputs
self.read_from_dataset()
self.run_test(net, [self.x, self.y, self.m, self.n], 1, [0, 1, 2, 3])
# Convert into a program with three inputs
self.read_from_dataset()
self.run_test(net, [self.x, self.x, self.m, self.n], 2, [0, 0, 1, 2])
# Convert into a program with two inputs
self.read_from_dataset()
self.run_test(net, [self.x, self.x, self.m, self.m], 3, [0, 0, 1, 1])
# Reuse the cached program
self.read_from_dataset()
self.run_test(net, [self.n, self.n, self.y, self.y], 3, [0, 0, 1, 1])
# Convert into a program with two inputs
self.read_from_dataset()
self.run_test(net, [self.x, self.y, self.x, self.y], 4, [0, 1, 0, 1])
# Reuse the cached program
self.read_from_dataset()
self.run_test(net, [self.m, self.n, self.m, self.n], 4, [0, 1, 0, 1])
# Convert into a program with one input
self.read_from_dataset()
self.run_test(net, [self.x, self.x, self.x, self.x], 5, [0, 0, 0, 0])
# Reuse the cached program
self.read_from_dataset()
self.run_test(net, [self.m, self.m, self.m, self.m], 5, [0, 0, 0, 0])
def run_test(self, net, inputs, trace_count, mode):
out = net(*inputs)
self.assertEqual(net.forward.get_traced_count(), trace_count)
self.assert_feed_mode(net.forward.inputs, mode)
def assert_feed_mode(self, inputs, expect_mode):
assert isinstance(inputs, list)
assert isinstance(expect_mode, list)
in_names = [var.name for var in inputs]
i, name_ids = 0, {}
def to_idx(name):
nonlocal i
if name not in name_ids:
name_ids[name] = i
i += 1
return name_ids[name]
mode = [to_idx(name) for name in in_names]
self.assertEqual(mode, expect_mode)
if __name__ == '__main__':
unittest.main()