未验证 提交 d15cbe70 编写于 作者: L LiuChiachi 提交者: GitHub

Remove Input requirement in dygraph for Model (#27557)

* remove input requirment in dygraph Model

* correct unittest

* upadte save inference model in dygraph without input

* fix unittets for test_model.py

* solve conflicts

* solve conflicts

* delete http.log

* fix test_model.py bug, correct initialization of MyModel

* fix unittests bugs

* set paddle manual seed for unittest

* fix Model bugs, because inputs can be list or dict when it is provided.

* add random seed for test_export_deploy_model

* delete redundant codes, because  calls

* Code optimization, error information optimization
上级 3a8bef1d
......@@ -200,6 +200,15 @@ def prepare_distributed_context(place=None):
return strategy
def _update_input_shapes(inputs):
shapes = None
if isinstance(inputs, list):
shapes = [list(input.shape) for input in inputs]
elif isinstance(inputs, dict):
shapes = [list(inputs[name].shape) for name in inputs]
return shapes
class StaticGraphAdapter(object):
"""
Model traning/inference with a static graph.
......@@ -598,6 +607,7 @@ class DynamicGraphAdapter(object):
'test_batch': 0
}
self._input_shapes = None
if self._nranks > 1:
stradegy = fluid.dygraph.parallel.ParallelStrategy()
stradegy.nranks = ParallelEnv().nranks
......@@ -622,6 +632,7 @@ class DynamicGraphAdapter(object):
self.model.network.train()
self.mode = 'train'
inputs = to_list(inputs)
self._input_shapes = _update_input_shapes(inputs)
labels = labels or []
labels = [to_variable(l) for l in to_list(labels)]
......@@ -656,6 +667,7 @@ class DynamicGraphAdapter(object):
self.model.network.eval()
self.mode = 'eval'
inputs = to_list(inputs)
self._input_shapes = _update_input_shapes(inputs)
labels = labels or []
labels = [to_variable(l) for l in to_list(labels)]
......@@ -704,6 +716,7 @@ class DynamicGraphAdapter(object):
self.model.network.eval()
self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)]
self._input_shapes = _update_input_shapes(inputs)
outputs = self.model.network.forward(*inputs)
if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace):
outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
......@@ -778,7 +791,7 @@ class DynamicGraphAdapter(object):
if not hasattr(self.model._optimizer, 'set_state_dict'):
warnings.warn(
"paddle.fluid.optimizer is deprecated in API 2.0, please use paddle.optimizer instead"
"paddle.fluid.optimizer is deprecated in API 2.0, please use paddle.optimizer instead."
)
self.model._optimizer.set_dict(converted_state)
else:
......@@ -792,14 +805,15 @@ class Model(object):
switched by `paddle.disable_static()`. The usage is as follows.
But note, the switching between dynamic and static should be before
instantiating a Model. The input description, i.e, paddle.static.InputSpec,
must be required.
must be required for static graph.
Args:
network (paddle.nn.Layer): The network is an instance of
paddle.nn.Layer.
inputs (InputSpec|list|dict|None): `inputs`, entry points of network,
could be a InputSpec instance, or lits of InputSpec instances,
or dict ({name: InputSpec}), and it couldn't be None.
or dict ({name: InputSpec}), and it couldn't be None in static
graph.
labels (InputSpec|list|None): `labels`, entry points of network,
could be a InputSpec instnace or lits of InputSpec instances,
or None. For static graph, if labels is required in loss,
......@@ -844,14 +858,18 @@ class Model(object):
self._loss = None
self._loss_weights = None
self._optimizer = None
self._optimizer = None
self._input_shapes = None
self._is_shape_inferred = False
self._test_dataloader = None
if not isinstance(inputs, (list, dict, Input)):
raise TypeError(
"'inputs' must be list or dict in static graph mode")
if not in_dygraph_mode():
if not isinstance(inputs, (list, dict, Input)):
raise TypeError(
"'inputs' must be list or dict, and couldn't be None.")
elif inputs:
self._input_shapes = _update_input_shapes(inputs)
self._inputs = self._verify_spec(inputs, True)
self._inputs = self._verify_spec(inputs, is_input=True)
self._labels = self._verify_spec(labels)
# init backend
......@@ -902,7 +920,12 @@ class Model(object):
loss = model.train_batch([data], [label])
print(loss)
"""
return self._adapter.train_batch(inputs, labels)
loss = self._adapter.train_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
return loss
def eval_batch(self, inputs, labels=None):
"""
......@@ -947,7 +970,12 @@ class Model(object):
loss = model.eval_batch([data], [label])
print(loss)
"""
return self._adapter.eval_batch(inputs, labels)
loss = self._adapter.eval_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
return loss
def test_batch(self, inputs):
"""
......@@ -987,7 +1015,12 @@ class Model(object):
out = model.test_batch([data])
print(out)
"""
return self._adapter.test_batch(inputs)
loss = self._adapter.test_batch(inputs)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
return loss
def save(self, path, training=True):
"""
......@@ -1677,6 +1710,14 @@ class Model(object):
if fluid.in_dygraph_mode():
with fluid.framework._dygraph_guard(None):
layer = self.network
if self._input_shapes is None: # No provided or inferred
raise RuntimeError(
"Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training zqqdata and perform a training for shape derivation."
)
if self._is_shape_inferred:
warnings.warn(
"'inputs' was not specified when Model initialization, so the input shape to be saved will be the shape derived from the user's actual inputs. The input shape to be saved is %s. For saving correct input shapes, please provide 'inputs' for Model initialization."
% self._input_shapes)
layer.forward = paddle.jit.to_static(
layer.forward, input_spec=self._inputs)
......@@ -1775,6 +1816,7 @@ class Model(object):
data = flatten(data)
# LoDTensor.shape is callable, where LoDTensor comes from
# DataLoader in static graph
batch_size = data[0].shape()[0] if callable(data[
0].shape) else data[0].shape[0]
......@@ -1864,10 +1906,26 @@ class Model(object):
_input_size = self._inputs
return summary(self.network, _input_size, dtype)
def _verify_spec(self, specs, is_input=False):
def _verify_spec(self, specs, shapes=None, is_input=False):
out_specs = []
if isinstance(specs, dict):
if specs is None:
# Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function
# to generate `Input`. But how can we know the actual shape of each input tensor?
if is_input:
arg_names = extract_args(self.network.forward)[1:]
if shapes is not None and fluid.in_dygraph_mode():
out_specs = [
Input(
name=n, shape=shapes[i])
for i, n in enumerate(arg_names)
]
else:
out_specs = [Input(name=n, shape=[None]) for n in arg_names]
else:
out_specs = to_list(specs)
elif isinstance(specs, dict):
assert is_input == False
out_specs = [specs[n] \
for n in extract_args(self.network.forward) if n != 'self']
......
......@@ -66,34 +66,6 @@ class LeNetDygraph(paddle.nn.Layer):
return x
class LeNetDeclarative(fluid.dygraph.Layer):
def __init__(self, num_classes=10):
super(LeNetDeclarative, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2d(
1, 6, 3, stride=1, padding=1),
ReLU(),
Pool2D(2, 'max', 2),
Conv2d(
6, 16, 5, stride=1, padding=0),
ReLU(),
Pool2D(2, 'max', 2))
if num_classes > 0:
self.fc = Sequential(
Linear(400, 120), Linear(120, 84), Linear(84, 10))
@declarative
def forward(self, inputs):
x = self.features(inputs)
if self.num_classes > 0:
x = fluid.layers.flatten(x, 1)
x = self.fc(x)
return x
class MnistDataset(MNIST):
def __init__(self, mode, return_label=True, sample_num=None):
super(MnistDataset, self).__init__(mode=mode)
......@@ -440,9 +412,7 @@ class TestModelFunction(unittest.TestCase):
# dynamic saving
device = paddle.set_device('cpu')
fluid.enable_dygraph(device)
inputs = [InputSpec([None, 20], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(MyModel(), inputs, labels)
model = Model(MyModel())
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=model.parameters())
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
......@@ -545,6 +515,8 @@ class TestModelFunction(unittest.TestCase):
paddle.summary(nlp_net, (1, 1, 2))
def test_export_deploy_model(self):
self.set_seed()
np.random.seed(2020)
for dynamic in [True, False]:
paddle.disable_static() if dynamic else None
prog_translator = ProgramTranslator()
......@@ -579,6 +551,35 @@ class TestModelFunction(unittest.TestCase):
shutil.rmtree(save_dir)
paddle.enable_static()
def test_dygraph_export_deploy_model_without_inputs(self):
mnist_data = MnistDataset(mode='train')
paddle.disable_static()
for initial in ["fit", "train_batch", "eval_batch", "test_batch"]:
save_dir = tempfile.mkdtemp()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
net = LeNet()
model = Model(net)
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
model.prepare(
optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
if initial == "fit":
model.fit(mnist_data, batch_size=64, verbose=0)
else:
img = np.array(
np.random.random((1, 1, 28, 28)), dtype=np.float32)
label = np.array(np.random.rand(1, 1), dtype=np.int64)
if initial == "train_batch":
model.train_batch([img], [label])
elif initial == "eval_batch":
model.eval_batch([img], [label])
else:
model.test_batch([img])
model.save(save_dir, training=False)
shutil.rmtree(save_dir)
class TestRaiseError(unittest.TestCase):
def test_input_without_name(self):
......@@ -589,13 +590,22 @@ class TestRaiseError(unittest.TestCase):
with self.assertRaises(ValueError):
model = Model(net, inputs, labels)
def test_input_without_input_spec(self):
for dynamic in [True, False]:
paddle.disable_static() if dynamic else None
net = MyModel()
with self.assertRaises(TypeError):
model = Model(net)
paddle.enable_static()
def test_static_without_inputs(self):
paddle.enable_static()
net = MyModel()
with self.assertRaises(TypeError):
model = Model(net)
def test_save_infer_model_without_inputs_and_run_in_dygraph(self):
paddle.disable_static()
net = MyModel()
save_dir = tempfile.mkdtemp()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
with self.assertRaises(RuntimeError):
model = Model(net)
model.save(save_dir, training=False)
paddle.enable_static()
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册