未验证 提交 5579edfb 编写于 作者: L LiuChiachi 提交者: GitHub

save dtype of inputs (#28301)

上级 c47bfe98
......@@ -200,16 +200,22 @@ def prepare_distributed_context(place=None):
return strategy
def _update_input_shapes(inputs):
def _update_input_info(inputs):
"Get input shape list by given inputs in Model initialization."
shapes = None
dtypes = None
if isinstance(inputs, Input):
shapes = [list(inputs.shape)]
dtypes = [inputs.dtype]
elif isinstance(inputs, list):
shapes = [list(input.shape) for input in inputs]
dtypes = [input.dtype for input in inputs]
elif isinstance(inputs, dict):
shapes = [list(inputs[name].shape) for name in inputs]
return shapes
dtypes = [inputs[name].dtype for name in inputs]
else:
return None
return shapes, dtypes
class StaticGraphAdapter(object):
......@@ -617,7 +623,7 @@ class DynamicGraphAdapter(object):
'test_batch': 0
}
self._input_shapes = None
self._input_info = None
if self._nranks > 1:
stradegy = fluid.dygraph.parallel.ParallelStrategy()
stradegy.nranks = ParallelEnv().nranks
......@@ -642,7 +648,7 @@ class DynamicGraphAdapter(object):
self.model.network.train()
self.mode = 'train'
inputs = to_list(inputs)
self._input_shapes = _update_input_shapes(inputs)
self._input_info = _update_input_info(inputs)
labels = labels or []
labels = [to_variable(l) for l in to_list(labels)]
......@@ -679,7 +685,7 @@ class DynamicGraphAdapter(object):
self.model.network.eval()
self.mode = 'eval'
inputs = to_list(inputs)
self._input_shapes = _update_input_shapes(inputs)
self._input_info = _update_input_info(inputs)
labels = labels or []
labels = [to_variable(l) for l in to_list(labels)]
......@@ -728,7 +734,7 @@ class DynamicGraphAdapter(object):
self.model.network.eval()
self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)]
self._input_shapes = _update_input_shapes(inputs)
self._input_info = _update_input_info(inputs)
outputs = self.model.network.forward(*inputs)
if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace):
outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
......@@ -875,7 +881,7 @@ class Model(object):
self._loss = None
self._loss_weights = None
self._optimizer = None
self._input_shapes = None
self._input_info = None
self._is_shape_inferred = False
self._test_dataloader = None
......@@ -884,7 +890,7 @@ class Model(object):
raise TypeError(
"'inputs' must be list or dict, and couldn't be None.")
elif inputs:
self._input_shapes = _update_input_shapes(inputs)
self._input_info = _update_input_info(inputs)
self._inputs = self._verify_spec(inputs, is_input=True)
self._labels = self._verify_spec(labels)
......@@ -941,7 +947,7 @@ class Model(object):
print(loss)
"""
loss = self._adapter.train_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None:
if fluid.in_dygraph_mode() and self._input_info is None:
self._update_inputs()
return loss
......@@ -992,7 +998,7 @@ class Model(object):
print(loss)
"""
loss = self._adapter.eval_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None:
if fluid.in_dygraph_mode() and self._input_info is None:
self._update_inputs()
return loss
......@@ -1036,7 +1042,7 @@ class Model(object):
print(out)
"""
loss = self._adapter.predict_batch(inputs)
if fluid.in_dygraph_mode() and self._input_shapes is None:
if fluid.in_dygraph_mode() and self._input_info is None:
self._update_inputs()
return loss
......@@ -1750,14 +1756,15 @@ class Model(object):
if fluid.in_dygraph_mode():
with fluid.framework._dygraph_guard(None):
layer = self.network
if self._input_shapes is None: # No provided or inferred
if self._input_info is None: # No provided or inferred
raise RuntimeError(
"Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training data and perform a training for shape derivation."
)
if self._is_shape_inferred:
warnings.warn(
"'inputs' was not specified when Model initialization, so the input shape to be saved will be the shape derived from the user's actual inputs. The input shape to be saved is %s. For saving correct input shapes, please provide 'inputs' for Model initialization."
% self._input_shapes)
% self._input_info[0])
layer.forward = paddle.jit.to_static(
layer.forward, input_spec=self._inputs)
......@@ -1945,7 +1952,7 @@ class Model(object):
_input_size = self._inputs
return summary(self.network, _input_size, dtype)
def _verify_spec(self, specs, shapes=None, is_input=False):
def _verify_spec(self, specs, shapes=None, dtypes=None, is_input=False):
out_specs = []
if specs is None:
......@@ -1954,10 +1961,12 @@ class Model(object):
if is_input:
arg_names = extract_args(self.network.forward)[1:]
if shapes is not None and fluid.in_dygraph_mode():
# While Saving inference model in dygraph, and providing inputs only in running.
if shapes is not None and dtypes is not None and fluid.in_dygraph_mode(
):
out_specs = [
Input(
name=n, shape=shapes[i])
name=n, dtype=dtypes[i], shape=shapes[i])
for i, n in enumerate(arg_names)
]
else:
......@@ -2000,6 +2009,8 @@ class Model(object):
def _update_inputs(self):
"Update self._inputs according to given inputs."
self._input_shapes = self._adapter._input_shapes
self._input_info = self._adapter._input_info
if self._input_info is not None and len(self._input_info) == 2:
self._inputs = self._verify_spec(None, self._input_info[0],
self._input_info[1], True)
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册