未验证 提交 5579edfb 编写于 作者: L LiuChiachi 提交者: GitHub

save dtype of inputs (#28301)

上级 c47bfe98
...@@ -200,16 +200,22 @@ def prepare_distributed_context(place=None): ...@@ -200,16 +200,22 @@ def prepare_distributed_context(place=None):
return strategy return strategy
def _update_input_shapes(inputs): def _update_input_info(inputs):
"Get input shape list by given inputs in Model initialization." "Get input shape list by given inputs in Model initialization."
shapes = None shapes = None
dtypes = None
if isinstance(inputs, Input): if isinstance(inputs, Input):
shapes = [list(inputs.shape)] shapes = [list(inputs.shape)]
dtypes = [inputs.dtype]
elif isinstance(inputs, list): elif isinstance(inputs, list):
shapes = [list(input.shape) for input in inputs] shapes = [list(input.shape) for input in inputs]
dtypes = [input.dtype for input in inputs]
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
shapes = [list(inputs[name].shape) for name in inputs] shapes = [list(inputs[name].shape) for name in inputs]
return shapes dtypes = [inputs[name].dtype for name in inputs]
else:
return None
return shapes, dtypes
class StaticGraphAdapter(object): class StaticGraphAdapter(object):
...@@ -617,7 +623,7 @@ class DynamicGraphAdapter(object): ...@@ -617,7 +623,7 @@ class DynamicGraphAdapter(object):
'test_batch': 0 'test_batch': 0
} }
self._input_shapes = None self._input_info = None
if self._nranks > 1: if self._nranks > 1:
stradegy = fluid.dygraph.parallel.ParallelStrategy() stradegy = fluid.dygraph.parallel.ParallelStrategy()
stradegy.nranks = ParallelEnv().nranks stradegy.nranks = ParallelEnv().nranks
...@@ -642,7 +648,7 @@ class DynamicGraphAdapter(object): ...@@ -642,7 +648,7 @@ class DynamicGraphAdapter(object):
self.model.network.train() self.model.network.train()
self.mode = 'train' self.mode = 'train'
inputs = to_list(inputs) inputs = to_list(inputs)
self._input_shapes = _update_input_shapes(inputs) self._input_info = _update_input_info(inputs)
labels = labels or [] labels = labels or []
labels = [to_variable(l) for l in to_list(labels)] labels = [to_variable(l) for l in to_list(labels)]
...@@ -679,7 +685,7 @@ class DynamicGraphAdapter(object): ...@@ -679,7 +685,7 @@ class DynamicGraphAdapter(object):
self.model.network.eval() self.model.network.eval()
self.mode = 'eval' self.mode = 'eval'
inputs = to_list(inputs) inputs = to_list(inputs)
self._input_shapes = _update_input_shapes(inputs) self._input_info = _update_input_info(inputs)
labels = labels or [] labels = labels or []
labels = [to_variable(l) for l in to_list(labels)] labels = [to_variable(l) for l in to_list(labels)]
...@@ -728,7 +734,7 @@ class DynamicGraphAdapter(object): ...@@ -728,7 +734,7 @@ class DynamicGraphAdapter(object):
self.model.network.eval() self.model.network.eval()
self.mode = 'test' self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)] inputs = [to_variable(x) for x in to_list(inputs)]
self._input_shapes = _update_input_shapes(inputs) self._input_info = _update_input_info(inputs)
outputs = self.model.network.forward(*inputs) outputs = self.model.network.forward(*inputs)
if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace): if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace):
outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)] outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
...@@ -875,7 +881,7 @@ class Model(object): ...@@ -875,7 +881,7 @@ class Model(object):
self._loss = None self._loss = None
self._loss_weights = None self._loss_weights = None
self._optimizer = None self._optimizer = None
self._input_shapes = None self._input_info = None
self._is_shape_inferred = False self._is_shape_inferred = False
self._test_dataloader = None self._test_dataloader = None
...@@ -884,7 +890,7 @@ class Model(object): ...@@ -884,7 +890,7 @@ class Model(object):
raise TypeError( raise TypeError(
"'inputs' must be list or dict, and couldn't be None.") "'inputs' must be list or dict, and couldn't be None.")
elif inputs: elif inputs:
self._input_shapes = _update_input_shapes(inputs) self._input_info = _update_input_info(inputs)
self._inputs = self._verify_spec(inputs, is_input=True) self._inputs = self._verify_spec(inputs, is_input=True)
self._labels = self._verify_spec(labels) self._labels = self._verify_spec(labels)
...@@ -941,7 +947,7 @@ class Model(object): ...@@ -941,7 +947,7 @@ class Model(object):
print(loss) print(loss)
""" """
loss = self._adapter.train_batch(inputs, labels) loss = self._adapter.train_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None: if fluid.in_dygraph_mode() and self._input_info is None:
self._update_inputs() self._update_inputs()
return loss return loss
...@@ -992,7 +998,7 @@ class Model(object): ...@@ -992,7 +998,7 @@ class Model(object):
print(loss) print(loss)
""" """
loss = self._adapter.eval_batch(inputs, labels) loss = self._adapter.eval_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None: if fluid.in_dygraph_mode() and self._input_info is None:
self._update_inputs() self._update_inputs()
return loss return loss
...@@ -1036,7 +1042,7 @@ class Model(object): ...@@ -1036,7 +1042,7 @@ class Model(object):
print(out) print(out)
""" """
loss = self._adapter.predict_batch(inputs) loss = self._adapter.predict_batch(inputs)
if fluid.in_dygraph_mode() and self._input_shapes is None: if fluid.in_dygraph_mode() and self._input_info is None:
self._update_inputs() self._update_inputs()
return loss return loss
...@@ -1750,14 +1756,15 @@ class Model(object): ...@@ -1750,14 +1756,15 @@ class Model(object):
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
with fluid.framework._dygraph_guard(None): with fluid.framework._dygraph_guard(None):
layer = self.network layer = self.network
if self._input_shapes is None: # No provided or inferred if self._input_info is None: # No provided or inferred
raise RuntimeError( raise RuntimeError(
"Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training data and perform a training for shape derivation." "Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training data and perform a training for shape derivation."
) )
if self._is_shape_inferred: if self._is_shape_inferred:
warnings.warn( warnings.warn(
"'inputs' was not specified when Model initialization, so the input shape to be saved will be the shape derived from the user's actual inputs. The input shape to be saved is %s. For saving correct input shapes, please provide 'inputs' for Model initialization." "'inputs' was not specified when Model initialization, so the input shape to be saved will be the shape derived from the user's actual inputs. The input shape to be saved is %s. For saving correct input shapes, please provide 'inputs' for Model initialization."
% self._input_shapes) % self._input_info[0])
layer.forward = paddle.jit.to_static( layer.forward = paddle.jit.to_static(
layer.forward, input_spec=self._inputs) layer.forward, input_spec=self._inputs)
...@@ -1945,7 +1952,7 @@ class Model(object): ...@@ -1945,7 +1952,7 @@ class Model(object):
_input_size = self._inputs _input_size = self._inputs
return summary(self.network, _input_size, dtype) return summary(self.network, _input_size, dtype)
def _verify_spec(self, specs, shapes=None, is_input=False): def _verify_spec(self, specs, shapes=None, dtypes=None, is_input=False):
out_specs = [] out_specs = []
if specs is None: if specs is None:
...@@ -1954,10 +1961,12 @@ class Model(object): ...@@ -1954,10 +1961,12 @@ class Model(object):
if is_input: if is_input:
arg_names = extract_args(self.network.forward)[1:] arg_names = extract_args(self.network.forward)[1:]
if shapes is not None and fluid.in_dygraph_mode(): # While Saving inference model in dygraph, and providing inputs only in running.
if shapes is not None and dtypes is not None and fluid.in_dygraph_mode(
):
out_specs = [ out_specs = [
Input( Input(
name=n, shape=shapes[i]) name=n, dtype=dtypes[i], shape=shapes[i])
for i, n in enumerate(arg_names) for i, n in enumerate(arg_names)
] ]
else: else:
...@@ -2000,6 +2009,8 @@ class Model(object): ...@@ -2000,6 +2009,8 @@ class Model(object):
def _update_inputs(self): def _update_inputs(self):
"Update self._inputs according to given inputs." "Update self._inputs according to given inputs."
self._input_shapes = self._adapter._input_shapes self._input_info = self._adapter._input_info
self._is_shape_inferred = True if self._input_info is not None and len(self._input_info) == 2:
self._inputs = self._verify_spec(None, self._input_shapes, True) self._inputs = self._verify_spec(None, self._input_info[0],
self._input_info[1], True)
self._is_shape_inferred = True
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册