未验证 提交 1e7dc9c0 编写于 作者: W WangZhen 提交者: GitHub

Check -1 shape for input_spec and program when prim or cinn enabled (#50473)

* Check -1 shape for input_spec and program when prim or cinn enabled

* Polish neg shape check

* Polish code

* Fix UT

* Fix UT in static
上级 8ad635d5
......@@ -20,6 +20,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.jit.dy2static.utils import _compatible_non_tensor_spec
from paddle.static import InputSpec
......@@ -331,5 +332,32 @@ class TestCompatibleNonTensorSpec(unittest.TestCase):
)
class NegSpecNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.linear = paddle.nn.Linear(10, 5)
def forward(self, x):
return self.linear(x)
class TestNegSpecWithPrim(unittest.TestCase):
def setUp(self):
paddle.disable_static()
core._set_prim_all_enabled(True)
def tearDown(self):
core._set_prim_all_enabled(False)
def test_run(self):
net = NegSpecNet()
net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10])]
)
x = paddle.randn([2, 10])
out = net(x)
np.testing.assert_equal(out.shape, [2, 5])
if __name__ == '__main__':
unittest.main()
......@@ -1069,32 +1069,6 @@ class PartialProgramLayer:
return vars if vars else None
def _create_fake_var():
"""
Create a fake_var (force on CPU) to handle empty input or output
"""
if not framework.global_var._in_eager_mode_:
return [
core.VarBase(
core.VarDesc.VarType.FP32,
[],
"Fake_var",
core.VarDesc.VarType.RAW,
False,
)
]
else:
return [
core.eager.Tensor(
core.VarDesc.VarType.FP32,
[],
"Fake_var",
core.VarDesc.VarType.RAW,
False,
)
]
def partial_program_from(concrete_program):
inputs = concrete_program.inputs
if inputs and isinstance(inputs[0], layers.Layer):
......
......@@ -48,6 +48,7 @@ from .utils import (
func_to_source_code,
input_specs_compatible,
make_hashable,
prim_or_cinn_is_enabled,
type_name,
unwrap,
)
......@@ -320,6 +321,17 @@ class StaticFunction:
self._dygraph_function = function
self._class_instance = None
if input_spec is not None and prim_or_cinn_is_enabled(
kwargs.get("build_strategy", None)
):
for spec in input_spec:
if spec is not None and -1 in spec.shape:
input_spec = None
warnings.warn(
'Now prim and cinn do not support -1 shape, but input_spec has -1 shape so we set it to None.'
)
break
self._input_spec = input_spec
self._function_spec = FunctionSpec(function, input_spec)
self._program_cache = ProgramCache()
......@@ -1046,17 +1058,6 @@ class ConcreteProgram:
)
def _extract_indeed_params_buffers(class_instance):
"""
To filter not initialzed buffers.
"""
params = list(get_parameters(class_instance).values())
buffers = list(get_buffers(class_instance).values())
buffers = [buffer for buffer in buffers if len(buffer.shape) != 0]
return params + buffers
class ParametersRecorder:
def __init__(self):
self.params_dict = {}
......@@ -1177,6 +1178,15 @@ class ProgramCache:
else:
raise
if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy']):
for var in concrete_program.main_program.list_vars():
if -1 in var.shape:
warnings.warn(
"Now prim and cinn do not support -1 shape, but the shape of var {} is {}".format(
var.name, var.shape
)
)
concrete_program._to_prim()
return concrete_program, partial_program_from(concrete_program)
......
......@@ -1493,7 +1493,7 @@ def _param_grad_names(program_desc, params):
Parse PARAM@GARD name from original train and infer program.
"""
names = []
# NOTE: `names` and `self._params` must be in the same order so that
# NOTE: `names` and `params` must be in the same order so that
# the param grad name can be set correctly in the run_program.
for param in params:
candidate = [
......@@ -1523,3 +1523,25 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
var_name = op.output('Out')[0]
names.append(var_name)
return names
def prim_or_cinn_is_enabled(build_strategy):
if build_strategy is not None and build_strategy.build_cinn_pass:
return True
if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled():
return True
env_flags = [
'FLAGS_prim_forward',
'FLAGS_prim_backward',
'FLAGS_prim_all',
'FLAGS_use_cinn',
]
for flag in env_flags:
value = os.getenv(flag)
if value is None:
continue
elif value.lower() in ['true', '1']:
return True
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册