未验证 提交 0cc63cc3 编写于 作者: A Aurelius84 提交者: GitHub

[Paddle2.0] Rename hapi.Input and move `data` api (#26396)

* Rename `Input` into `InputSpec`

* fix argument place of Input api
上级 a57d63a0
......@@ -93,8 +93,9 @@ def create_paddle_case(op_type, callback):
def test_broadcast_api_1(self):
with program_guard(Program(), Program()):
x = paddle.nn.data(name='x', shape=[1, 2, 1, 3], dtype='int32')
y = paddle.nn.data(name='y', shape=[1, 2, 3], dtype='int32')
x = paddle.static.data(
name='x', shape=[1, 2, 1, 3], dtype='int32')
y = paddle.static.data(name='y', shape=[1, 2, 3], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
......
......@@ -54,7 +54,7 @@ class TestCumsumOp(unittest.TestCase):
def run_static(self, use_gpu=False):
with fluid.program_guard(fluid.Program()):
data_np = np.random.random((100, 100)).astype(np.float32)
x = paddle.nn.data('X', [100, 100])
x = paddle.static.data('X', [100, 100])
y = paddle.cumsum(x)
y2 = paddle.cumsum(x, axis=0)
y3 = paddle.cumsum(x, axis=-1)
......@@ -100,7 +100,7 @@ class TestCumsumOp(unittest.TestCase):
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = paddle.nn.data('x', [3, 4])
x = paddle.static.data('x', [3, 4])
y = paddle.cumsum(x, name='out')
self.assertTrue('out' in y.name)
......
......@@ -145,19 +145,22 @@ class TestFlatten2OpError(unittest.TestCase):
x = x.astype('float32')
def test_ValueError1():
x_var = paddle.nn.data(name="x", shape=image_shape, dtype='float32')
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32')
out = paddle.flatten(x_var, start_axis=2, stop_axis=1)
self.assertRaises(ValueError, test_ValueError1)
def test_ValueError2():
x_var = paddle.nn.data(name="x", shape=image_shape, dtype='float32')
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32')
paddle.flatten(x_var, start_axis=10, stop_axis=1)
self.assertRaises(ValueError, test_ValueError2)
def test_ValueError3():
x_var = paddle.nn.data(name="x", shape=image_shape, dtype='float32')
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32')
paddle.flatten(x_var, start_axis=2, stop_axis=10)
self.assertRaises(ValueError, test_ValueError3)
......
......@@ -32,7 +32,7 @@ class ApiMaxTest(unittest.TestCase):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="float32")
data = paddle.static.data("data", shape=[10, 10], dtype="float32")
result_max = paddle.max(x=data, axis=1)
exe = paddle.static.Executor(self.place)
input_data = np.random.rand(10, 10).astype(np.float32)
......@@ -41,7 +41,7 @@ class ApiMaxTest(unittest.TestCase):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
data = paddle.static.data("data", shape=[10, 10], dtype="int64")
result_max = paddle.max(x=data, axis=0)
exe = paddle.static.Executor(self.place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int64)
......@@ -50,7 +50,7 @@ class ApiMaxTest(unittest.TestCase):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
data = paddle.static.data("data", shape=[10, 10], dtype="int64")
result_max = paddle.max(x=data, axis=(0, 1))
exe = paddle.static.Executor(self.place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int64)
......@@ -71,8 +71,8 @@ class ApiMaxTest(unittest.TestCase):
def test_axis_type():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
axis = paddle.nn.data("axis", shape=[10, 10], dtype="int64")
data = paddle.static.data("data", shape=[10, 10], dtype="int64")
axis = paddle.static.data("axis", shape=[10, 10], dtype="int64")
result_min = paddle.min(data, axis)
self.assertRaises(TypeError, test_axis_type)
......
......@@ -36,8 +36,8 @@ class ApiMaximumTest(unittest.TestCase):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.nn.data("x", shape=[10, 15], dtype="float32")
data_y = paddle.nn.data("y", shape=[10, 15], dtype="float32")
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_y = paddle.static.data("y", shape=[10, 15], dtype="float32")
result_max = paddle.maximum(data_x, data_y)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
......@@ -48,8 +48,8 @@ class ApiMaximumTest(unittest.TestCase):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.nn.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.nn.data("z", shape=[15], dtype="float32")
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.static.data("z", shape=[15], dtype="float32")
result_max = paddle.maximum(data_x, data_z, axis=1)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
......
......@@ -32,7 +32,7 @@ class ApiMinTest(unittest.TestCase):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="float32")
data = paddle.static.data("data", shape=[10, 10], dtype="float32")
result_min = paddle.min(x=data, axis=1)
exe = paddle.static.Executor(self.place)
input_data = np.random.rand(10, 10).astype(np.float32)
......@@ -41,7 +41,7 @@ class ApiMinTest(unittest.TestCase):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
data = paddle.static.data("data", shape=[10, 10], dtype="int64")
result_min = paddle.min(x=data, axis=0)
exe = paddle.static.Executor(self.place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int64)
......@@ -50,7 +50,7 @@ class ApiMinTest(unittest.TestCase):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
data = paddle.static.data("data", shape=[10, 10], dtype="int64")
result_min = paddle.min(x=data, axis=(0, 1))
exe = paddle.static.Executor(self.place)
input_data = np.random.randint(10, size=(10, 10)).astype(np.int64)
......@@ -71,8 +71,8 @@ class ApiMinTest(unittest.TestCase):
def test_axis_type():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data = paddle.nn.data("data", shape=[10, 10], dtype="int64")
axis = paddle.nn.data("axis", shape=[10, 10], dtype="int64")
data = paddle.static.data("data", shape=[10, 10], dtype="int64")
axis = paddle.static.data("axis", shape=[10, 10], dtype="int64")
result_min = paddle.min(data, axis)
self.assertRaises(TypeError, test_axis_type)
......
......@@ -36,8 +36,8 @@ class ApiMinimumTest(unittest.TestCase):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.nn.data("x", shape=[10, 15], dtype="float32")
data_y = paddle.nn.data("y", shape=[10, 15], dtype="float32")
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_y = paddle.static.data("y", shape=[10, 15], dtype="float32")
result_min = paddle.minimum(data_x, data_y)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
......@@ -48,8 +48,8 @@ class ApiMinimumTest(unittest.TestCase):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.nn.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.nn.data("z", shape=[15], dtype="float32")
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.static.data("z", shape=[15], dtype="float32")
result_min = paddle.minimum(data_x, data_z, axis=1)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
......
......@@ -26,8 +26,10 @@ class TestMultiplyAPI(unittest.TestCase):
def __run_static_graph_case(self, x_data, y_data, axis=-1):
with program_guard(Program(), Program()):
x = paddle.nn.data(name='x', shape=x_data.shape, dtype=x_data.dtype)
y = paddle.nn.data(name='y', shape=y_data.shape, dtype=y_data.dtype)
x = paddle.static.data(
name='x', shape=x_data.shape, dtype=x_data.dtype)
y = paddle.static.data(
name='y', shape=y_data.shape, dtype=y_data.dtype)
res = tensor.multiply(x, y, axis=axis)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
......@@ -109,14 +111,14 @@ class TestMultiplyError(unittest.TestCase):
# test static computation graph: dtype can not be int8
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.nn.data(name='x', shape=[100], dtype=np.int8)
y = paddle.nn.data(name='y', shape=[100], dtype=np.int8)
x = paddle.static.data(name='x', shape=[100], dtype=np.int8)
y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
self.assertRaises(TypeError, tensor.multiply, x, y)
# test static computation graph: inputs must be broadcastable
with program_guard(Program(), Program()):
x = paddle.nn.data(name='x', shape=[20, 50], dtype=np.float64)
y = paddle.nn.data(name='y', shape=[20], dtype=np.float64)
x = paddle.static.data(name='x', shape=[20, 50], dtype=np.float64)
y = paddle.static.data(name='y', shape=[20], dtype=np.float64)
self.assertRaises(fluid.core.EnforceNotMet, tensor.multiply, x, y)
np.random.seed(7)
......
......@@ -54,9 +54,11 @@ def create_test_case(margin, reduction):
margin=margin,
reduction=reduction)
with program_guard(Program(), Program()):
x = paddle.nn.data(name="x", shape=[10, 10], dtype="float64")
y = paddle.nn.data(name="y", shape=[10, 10], dtype="float64")
label = paddle.nn.data(
x = paddle.static.data(
name="x", shape=[10, 10], dtype="float64")
y = paddle.static.data(
name="y", shape=[10, 10], dtype="float64")
label = paddle.static.data(
name="label", shape=[10, 10], dtype="float64")
result = paddle.nn.functional.margin_ranking_loss(
x, y, label, margin, reduction)
......@@ -78,9 +80,11 @@ def create_test_case(margin, reduction):
margin=margin,
reduction=reduction)
with program_guard(Program(), Program()):
x = paddle.nn.data(name="x", shape=[10, 10], dtype="float64")
y = paddle.nn.data(name="y", shape=[10, 10], dtype="float64")
label = paddle.nn.data(
x = paddle.static.data(
name="x", shape=[10, 10], dtype="float64")
y = paddle.static.data(
name="y", shape=[10, 10], dtype="float64")
label = paddle.static.data(
name="label", shape=[10, 10], dtype="float64")
margin_rank_loss = paddle.nn.loss.MarginRankingLoss(
margin=margin, reduction=reduction)
......
......@@ -45,7 +45,7 @@ class TestNNSigmoidAPI(unittest.TestCase):
main_program = paddle.static.Program()
mysigmoid = nn.Sigmoid(name="api_sigmoid")
with paddle.static.program_guard(main_program):
x = paddle.nn.data(name='x', shape=self.x_shape)
x = paddle.static.data(name='x', shape=self.x_shape)
x.stop_gradient = False
y = mysigmoid(x)
fluid.backward.append_backward(paddle.mean(y))
......@@ -86,7 +86,7 @@ class TestNNFunctionalSigmoidAPI(unittest.TestCase):
paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.nn.data(name='x', shape=self.x_shape)
x = paddle.static.data(name='x', shape=self.x_shape)
y = functional.sigmoid(x, name="api_sigmoid")
exe = paddle.static.Executor(fluid.CPUPlace())
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
......
......@@ -125,7 +125,7 @@ class TestRandintAPI(unittest.TestCase):
out4 = paddle.randint(
low=-100, high=100, shape=[dim_1, 5, dim_2], dtype='int32')
# shape is a tensor and dtype is 'float64'
var_shape = paddle.nn.data(
var_shape = paddle.static.data(
name='var_shape', shape=[2], dtype="int64")
out5 = paddle.randint(
low=1, high=1000, shape=var_shape, dtype='int64')
......
......@@ -34,7 +34,7 @@ class TestRandnOp(unittest.TestCase):
dim_2 = paddle.fill_constant([1], "int32", 50)
x3 = paddle.randn([dim_1, dim_2, 784])
var_shape = paddle.nn.data('X', [2], 'int32')
var_shape = paddle.static.data('X', [2], 'int32')
x4 = paddle.randn(var_shape)
place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
......
......@@ -295,8 +295,8 @@ class ProgBarLogger(Callback):
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
inputs = [hapi.Input('image', [-1, 1, 28, 28], 'float32')]
labels = [hapi.Input('label', [None, 1], 'int64')]
inputs = [hapi.Input([-1, 1, 28, 28], 'float32', 'image')]
labels = [hapi.Input([None, 1], 'int64', 'label')]
train_dataset = hapi.datasets.MNIST(mode='train')
......@@ -431,8 +431,8 @@ class ModelCheckpoint(Callback):
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
inputs = [hapi.Input('image', [-1, 1, 28, 28], 'float32')]
labels = [hapi.Input('label', [None, 1], 'int64')]
inputs = [hapi.Input([-1, 1, 28, 28], 'float32', 'image')]
labels = [hapi.Input([None, 1], 'int64', 'label')]
train_dataset = hapi.datasets.MNIST(mode='train')
......
......@@ -25,6 +25,8 @@ import warnings
from collections import Iterable
from paddle import fluid
# Note: Use alias `Input` temporarily before releasing hapi feature.
from paddle.static import InputSpec as Input
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
......@@ -47,40 +49,6 @@ __all__ = [
]
class Input(fluid.dygraph.Layer):
"""
Define inputs the model.
Args:
name (str): The name/alias of the variable, see :ref:`api_guide_Name`
for more details.
shape (tuple(integers)|list[integers]): List|Tuple of integers
declaring the shape. You can set "None" or -1 at a dimension
to indicate the dimension can be of any size. For example,
it is useful to set changeable batch size as "None" or -1.
dtype (np.dtype|VarType|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: float32.
Examples:
.. code-block:: python
import paddle.incubate.hapi as hapi
input = hapi.Input('x', [None, 784], 'float32')
label = hapi.Input('label', [None, 1], 'int64')
"""
def __init__(self, name, shape=None, dtype='float32'):
super(Input, self).__init__()
self.shape = shape
self.dtype = dtype
self.name = name
def forward(self):
return fluid.data(self.name, shape=self.shape, dtype=self.dtype)
class StaticGraphAdapter(object):
"""
Model traning/inference with a static graph.
......@@ -388,8 +356,8 @@ class StaticGraphAdapter(object):
with fluid.program_guard(prog, self._startup_prog):
inputs = self.model._inputs
labels = self.model._labels if self.model._labels else []
inputs = [k.forward() for k in to_list(inputs)]
labels = [k.forward() for k in to_list(labels)]
inputs = [k._create_feed_layer() for k in to_list(inputs)]
labels = [k._create_feed_layer() for k in to_list(labels)]
self._label_vars[mode] = labels
outputs = to_list(self.model.network.forward(*inputs))
......@@ -704,8 +672,8 @@ class Model(object):
fluid.enable_dygraph(device)
# inputs and labels are not required for dynamic graph.
input = hapi.Input('x', [None, 784], 'float32')
label = hapi.Input('label', [None, 1], 'int64')
input = hapi.Input([None, 784], 'float32', 'x')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(MyNet(), input, label)
optim = fluid.optimizer.SGD(learning_rate=1e-3,
......@@ -734,16 +702,8 @@ class Model(object):
if not isinstance(inputs, (list, dict, Input)):
raise TypeError(
"'inputs' must be list or dict in static graph mode")
if inputs is None:
self._inputs = [Input(name=n) \
for n in extract_args(self.network.forward) if n != 'self']
elif isinstance(input, dict):
self._inputs = [inputs[n] \
for n in extract_args(self.network.forward) if n != 'self']
else:
self._inputs = to_list(inputs)
self._labels = to_list(labels)
self._inputs = self._verify_spec(inputs, True)
self._labels = self._verify_spec(labels)
# init backend
if fluid.in_dygraph_mode():
......@@ -787,8 +747,8 @@ class Model(object):
device = hapi.set_device('gpu')
fluid.enable_dygraph(device)
input = hapi.Input('x', [None, 784], 'float32')
label = hapi.Input('label', [None, 1], 'int64')
input = hapi.Input([None, 784], 'float32', 'x')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(MyNet(), input, label)
optim = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
......@@ -836,8 +796,8 @@ class Model(object):
device = hapi.set_device('gpu')
fluid.enable_dygraph(device)
input = hapi.Input('x', [None, 784], 'float32')
label = hapi.Input('label', [None, 1], 'int64')
input = hapi.Input([None, 784], 'float32', 'x')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(MyNet(), input, label)
optim = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
......@@ -1194,8 +1154,8 @@ class Model(object):
train_dataset = hapi.datasets.MNIST(mode='train')
val_dataset = hapi.datasets.MNIST(mode='test')
input = hapi.Input('image', [None, 1, 28, 28], 'float32')
label = hapi.Input('label', [None, 1], 'int64')
input = hapi.Input([None, 1, 28, 28], 'float32', 'image')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(hapi.vision.LeNet(classifier_activation=None),
input, label)
......@@ -1231,8 +1191,8 @@ class Model(object):
val_loader = fluid.io.DataLoader(val_dataset,
places=device, batch_size=64)
input = hapi.Input('image', [None, 1, 28, 28], 'float32')
label = hapi.Input('label', [None, 1], 'int64')
input = hapi.Input([None, 1, 28, 28], 'float32', 'image')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(hapi.vision.LeNet(classifier_activation=None),
input, label)
......@@ -1359,8 +1319,8 @@ class Model(object):
# declarative mode
val_dataset = hapi.datasets.MNIST(mode='test')
input = hapi.Input('image', [-1, 1, 28, 28], 'float32')
label = hapi.Input('label', [None, 1], 'int64')
input = hapi.Input([-1, 1, 28, 28], 'float32', 'image')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(hapi.vision.LeNet(), input, label)
model.prepare(metrics=hapi.metrics.Accuracy())
......@@ -1433,12 +1393,13 @@ class Model(object):
num_workers (int): The number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this argument will be ignored. Default: 0.
stack_output (bool): Whether stack output field like a batch, as for an output
stack_outputs (bool): Whether stack output field like a batch, as for an output
filed of a sample is in shape [X, Y], test_data contains N samples, predict
output field will be in shape [N, X, Y] if stack_output is True, and will
be a length N list in shape [[X, Y], [X, Y], ....[X, Y]] if stack_outputs
is False. stack_outputs as False is used for LoDTensor output situation,
it is recommended set as True if outputs contains no LoDTensor. Default: False.
callbacks(Callback): A Callback instance, default None.
Returns:
list: output of models.
......@@ -1466,7 +1427,7 @@ class Model(object):
test_dataset = MnistDataset(mode='test', return_label=False)
# declarative mode
input = hapi.Input('image', [-1, 1, 28, 28], 'float32')
input = hapi.Input([-1, 1, 28, 28], 'float32', 'image')
model = hapi.Model(hapi.vision.LeNet(), input)
model.prepare()
......@@ -1548,7 +1509,7 @@ class Model(object):
import paddle.fluid as fluid
import paddle.incubate.hapi as hapi
input = hapi.Input('image', [-1, 1, 28, 28], 'float32')
input = hapi.Input([-1, 1, 28, 28], 'float32', 'image')
model = hapi.Model(hapi.vision.LeNet(), input)
model.prepare()
......@@ -1639,6 +1600,36 @@ class Model(object):
return logs, outputs
return logs
def _verify_spec(self, specs, is_input=False):
out_specs = []
if specs is None:
# If not specific specs of `Input`, using argument names of `forward` function
# to generate `Input`.
if is_input:
out_specs = [
Input(name=n) for n in extract_args(self.network.forward)
if n != 'self'
]
else:
out_specs = to_list(specs)
elif isinstance(specs, dict):
assert is_input == False
out_specs = [specs[n] \
for n in extract_args(self.network.forward) if n != 'self']
else:
out_specs = to_list(specs)
# Note: checks each element has specificed `name`.
if out_specs is not None:
for i, spec in enumerate(out_specs):
assert isinstance(spec, Input)
if spec.name is None:
raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}.".
format(i, spec))
return out_specs
def _reset_metrics(self):
for metric in self._metrics:
metric.reset()
......
......@@ -64,8 +64,8 @@ class TestDistTraning(unittest.TestCase):
im_shape = (-1, 1, 28, 28)
batch_size = 128
inputs = [Input('image', im_shape, 'float32')]
labels = [Input('label', [None, 1], 'int64')]
inputs = [Input(im_shape, 'float32', 'image')]
labels = [Input([None, 1], 'int64', 'label')]
model = Model(LeNet(classifier_activation=None), inputs, labels)
optim = fluid.optimizer.Momentum(
......
......@@ -63,8 +63,8 @@ class TestDistTraning(unittest.TestCase):
im_shape = (-1, 1, 28, 28)
batch_size = 128
inputs = [Input('image', im_shape, 'float32')]
labels = [Input('label', [None, 1], 'int64')]
inputs = [Input(im_shape, 'float32', 'image')]
labels = [Input([None, 1], 'int64', 'label')]
model = Model(LeNet(classifier_activation=None), inputs, labels)
optim = fluid.optimizer.Momentum(
......
......@@ -36,7 +36,7 @@ class TestCallbacks(unittest.TestCase):
freq = 2
eval_steps = 20
inputs = [Input('image', [None, 1, 28, 28], 'float32')]
inputs = [Input([None, 1, 28, 28], 'float32', 'image')]
lenet = Model(LeNet(), inputs)
lenet.prepare()
......
......@@ -150,8 +150,8 @@ class TestModel(unittest.TestCase):
cls.acc1 = dynamic_evaluate(dy_lenet, cls.val_loader)
cls.inputs = [Input('image', [-1, 1, 28, 28], 'float32')]
cls.labels = [Input('label', [None, 1], 'int64')]
cls.inputs = [Input([-1, 1, 28, 28], 'float32', 'image')]
cls.labels = [Input([None, 1], 'int64', 'label')]
cls.save_dir = tempfile.mkdtemp()
cls.weight_path = os.path.join(cls.save_dir, 'lenet')
......@@ -330,8 +330,8 @@ class TestModelFunction(unittest.TestCase):
optim2 = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=net.parameters())
inputs = [Input('x', [None, dim], 'float32')]
labels = [Input('label', [None, 1], 'int64')]
inputs = [Input([None, dim], 'float32', 'x')]
labels = [Input([None, 1], 'int64', 'label')]
model = Model(net, inputs, labels)
model.prepare(
optim2, loss_function=CrossEntropyLoss(reduction="sum"))
......@@ -359,7 +359,7 @@ class TestModelFunction(unittest.TestCase):
fluid.enable_dygraph(device) if dynamic else None
self.set_seed()
net = MyModel()
inputs = [Input('x', [None, dim], 'float32')]
inputs = [Input([None, dim], 'float32', 'x')]
model = Model(net, inputs)
model.prepare()
out, = model.test_batch([data])
......@@ -373,8 +373,8 @@ class TestModelFunction(unittest.TestCase):
device = hapi.set_device('cpu')
fluid.enable_dygraph(device) if dynamic else None
net = MyModel(classifier_activation=None)
inputs = [Input('x', [None, 20], 'float32')]
labels = [Input('label', [None, 1], 'int64')]
inputs = [Input([None, 20], 'float32', 'x')]
labels = [Input([None, 1], 'int64', 'label')]
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=net.parameters())
model = Model(net, inputs, labels)
......@@ -399,8 +399,8 @@ class TestModelFunction(unittest.TestCase):
model.save(path + '/test')
fluid.disable_dygraph()
inputs = [Input('x', [None, 20], 'float32')]
labels = [Input('label', [None, 1], 'int64')]
inputs = [Input([None, 20], 'float32', 'x')]
labels = [Input([None, 1], 'int64', 'label')]
model = Model(MyModel(classifier_activation=None), inputs, labels)
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=model.parameters())
......@@ -413,8 +413,8 @@ class TestModelFunction(unittest.TestCase):
path = tempfile.mkdtemp()
net = MyModel(classifier_activation=None)
inputs = [Input('x', [None, 20], 'float32')]
labels = [Input('label', [None, 1], 'int64')]
inputs = [Input([None, 20], 'float32', 'x')]
labels = [Input([None, 1], 'int64', 'label')]
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=net.parameters())
model = Model(net, inputs, labels)
......@@ -426,8 +426,8 @@ class TestModelFunction(unittest.TestCase):
fluid.enable_dygraph(device) #if dynamic else None
net = MyModel(classifier_activation=None)
inputs = [Input('x', [None, 20], 'float32')]
labels = [Input('label', [None, 1], 'int64')]
inputs = [Input([None, 20], 'float32', 'x')]
labels = [Input([None, 1], 'int64', 'label')]
optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=net.parameters())
model = Model(net, inputs, labels)
......@@ -442,7 +442,7 @@ class TestModelFunction(unittest.TestCase):
device = hapi.set_device('cpu')
fluid.enable_dygraph(device) if dynamic else None
net = MyModel()
inputs = [Input('x', [None, 20], 'float32')]
inputs = [Input([None, 20], 'float32', 'x')]
model = Model(net, inputs)
model.prepare()
params = model.parameters()
......@@ -452,7 +452,7 @@ class TestModelFunction(unittest.TestCase):
def test_export_deploy_model(self):
net = LeNet()
inputs = [Input('image', [-1, 1, 28, 28], 'float32')]
inputs = [Input([-1, 1, 28, 28], 'float32', 'image')]
model = Model(net, inputs)
model.prepare()
save_dir = tempfile.mkdtemp()
......@@ -480,5 +480,15 @@ class TestModelFunction(unittest.TestCase):
shutil.rmtree(save_dir)
class TestRaiseError(unittest.TestCase):
def test_input_without_name(self):
net = MyModel(classifier_activation=None)
inputs = [Input([None, 10], 'float32')]
labels = [Input([None, 1], 'int64', 'label')]
with self.assertRaises(ValueError):
model = Model(net, inputs, labels)
if __name__ == '__main__':
unittest.main()
......@@ -28,7 +28,7 @@ class TestPretrainedModel(unittest.TestCase):
fluid.enable_dygraph()
net = models.__dict__[arch](pretrained=True, classifier_activation=None)
inputs = [Input('image', [None, 3, 224, 224], 'float32')]
inputs = [Input([None, 3, 224, 224], 'float32', 'image')]
model = Model(network=net, inputs=inputs)
model.prepare()
res = model.test_batch(x)
......
......@@ -142,7 +142,7 @@ class TestBasicLSTM(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("input", [None, None, self.inputs[-1].shape[-1]], "float32"),
Input([None, None, self.inputs[-1].shape[-1]], "float32", "input"),
]
return inputs
......@@ -168,7 +168,7 @@ class TestBasicGRU(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("input", [None, None, self.inputs[-1].shape[-1]], "float32"),
Input([None, None, self.inputs[-1].shape[-1]], "float32", "input"),
]
return inputs
......@@ -219,8 +219,8 @@ class TestBeamSearch(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("init_hidden", [None, self.inputs[0].shape[-1]], "float32"),
Input("init_cell", [None, self.inputs[1].shape[-1]], "float32"),
Input([None, self.inputs[0].shape[-1]], "float32", "init_hidden"),
Input([None, self.inputs[1].shape[-1]], "float32", "init_cell"),
]
return inputs
......@@ -272,10 +272,10 @@ class TestTransformerEncoder(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("enc_input", [None, None, self.inputs[0].shape[-1]],
"float32"),
Input("attn_bias", [None, self.inputs[1].shape[1], None, None],
"float32"),
Input([None, None, self.inputs[0].shape[-1]], "float32",
"enc_input"),
Input([None, self.inputs[1].shape[1], None, None], "float32",
"attn_bias"),
]
return inputs
......@@ -336,14 +336,14 @@ class TestTransformerDecoder(TestTransformerEncoder):
def make_inputs(self):
inputs = [
Input("dec_input", [None, None, self.inputs[0].shape[-1]],
"float32"),
Input("enc_output", [None, None, self.inputs[0].shape[-1]],
"float32"),
Input("self_attn_bias",
[None, self.inputs[-1].shape[1], None, None], "float32"),
Input("cross_attn_bias",
[None, self.inputs[-1].shape[1], None, None], "float32"),
Input([None, None, self.inputs[0].shape[-1]], "float32",
"dec_input"),
Input([None, None, self.inputs[0].shape[-1]], "float32",
"enc_output"),
Input([None, self.inputs[-1].shape[1], None, None], "float32",
"self_attn_bias"),
Input([None, self.inputs[-1].shape[1], None, None], "float32",
"cross_attn_bias"),
]
return inputs
......@@ -431,10 +431,10 @@ class TestTransformerBeamSearchDecoder(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("enc_output", [None, None, self.inputs[0].shape[-1]],
"float32"),
Input("trg_src_attn_bias",
[None, self.inputs[1].shape[1], None, None], "float32"),
Input([None, None, self.inputs[0].shape[-1]], "float32",
"enc_output"),
Input([None, self.inputs[1].shape[1], None, None], "float32",
"trg_src_attn_bias"),
]
return inputs
......@@ -473,9 +473,9 @@ class TestSequenceTagging(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("word", [None, None], "int64"),
Input("lengths", [None], "int64"),
Input("target", [None, None], "int64"),
Input([None, None], "int64", "word"),
Input([None], "int64", "lengths"),
Input([None, None], "int64", "target"),
]
return inputs
......@@ -517,7 +517,7 @@ class TestStackedRNN(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("input", [None, None, self.inputs[-1].shape[-1]], "float32"),
Input([None, None, self.inputs[-1].shape[-1]], "float32", "input"),
]
return inputs
......@@ -543,7 +543,7 @@ class TestLSTM(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("input", [None, None, self.inputs[-1].shape[-1]], "float32"),
Input([None, None, self.inputs[-1].shape[-1]], "float32", "input"),
]
return inputs
......@@ -579,7 +579,7 @@ class TestBiLSTM(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("input", [None, None, self.inputs[-1].shape[-1]], "float32"),
Input([None, None, self.inputs[-1].shape[-1]], "float32", "input"),
]
return inputs
......@@ -609,7 +609,7 @@ class TestGRU(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("input", [None, None, self.inputs[-1].shape[-1]], "float32"),
Input([None, None, self.inputs[-1].shape[-1]], "float32", "input"),
]
return inputs
......@@ -645,7 +645,7 @@ class TestBiGRU(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("input", [None, None, self.inputs[-1].shape[-1]], "float32"),
Input([None, None, self.inputs[-1].shape[-1]], "float32", "input"),
]
return inputs
......@@ -680,7 +680,7 @@ class TestCNNEncoder(ModuleApiTest):
def make_inputs(self):
inputs = [
Input("input", [None, self.inputs[-1].shape[1], None], "float32"),
Input([None, self.inputs[-1].shape[1], None], "float32", "input"),
]
return inputs
......
......@@ -28,7 +28,7 @@ class TestVisonModels(unittest.TestCase):
else:
net = models.__dict__[arch](pretrained=pretrained)
input = hapi.Input('image', [None, 3, 224, 224], 'float32')
input = hapi.Input([None, 3, 224, 224], 'float32', 'image')
model = hapi.Model(net, input)
model.prepare()
......@@ -71,7 +71,7 @@ class TestVisonModels(unittest.TestCase):
self.models_infer('resnet152')
def test_lenet(self):
input = hapi.Input('x', [None, 1, 28, 28], 'float32')
input = hapi.Input([None, 1, 28, 28], 'float32', 'x')
lenet = hapi.Model(models.__dict__['LeNet'](), input)
lenet.prepare()
......
......@@ -49,7 +49,6 @@ from .decode import beam_search_decode #DEFINE_ALIAS
# from .decode import ctc_greedy_decoder #DEFINE_ALIAS
# from .decode import dynamic_decode #DEFINE_ALIAS
from .decode import gather_tree #DEFINE_ALIAS
from .input import data #DEFINE_ALIAS
# from .input import Input #DEFINE_ALIAS
from .layer.activation import ELU
from .layer.activation import GELU
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define input placeholders of neural network
from ..fluid import data #DEFINE_ALIAS
__all__ = [
'data',
# 'Input'
]
......@@ -17,9 +17,12 @@ __all__ = [
'append_backward', 'gradients', 'Executor', 'global_scope', 'scope_guard',
'BuildStrategy', 'CompiledProgram', 'Print', 'py_func', 'ExecutionStrategy',
'name_scope', 'ParallelExecutor', 'program_guard', 'WeightNormParamAttr',
'default_main_program', 'default_startup_program', 'Program', 'save', 'load'
'default_main_program', 'default_startup_program', 'Program', 'save',
'load', 'data', 'InputSpec'
]
from .input import data #DEFINE_ALIAS
from .input import InputSpec #DEFINE_ALIAS
from ..fluid.executor import Executor #DEFINE_ALIAS
from ..fluid.executor import global_scope #DEFINE_ALIAS
from ..fluid.executor import scope_guard #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..fluid.data import data
__all__ = ['data', 'InputSpec']
class InputSpec(object):
"""
Define input specification of the model.
Args:
name (str): The name/alias of the variable, see :ref:`api_guide_Name`
for more details.
shape (tuple(integers)|list[integers]): List|Tuple of integers
declaring the shape. You can set "None" or -1 at a dimension
to indicate the dimension can be of any size. For example,
it is useful to set changeable batch size as "None" or -1.
dtype (np.dtype|VarType|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: float32.
Examples:
.. code-block:: python
from paddle.static import InputSpec
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
"""
def __init__(self, shape=None, dtype='float32', name=None):
self.shape = shape
self.dtype = dtype
self.name = name
def _create_feed_layer(self):
return data(self.name, shape=self.shape, dtype=self.dtype)
def __repr__(self):
return '{}(shape={}, dtype={}, name={})'.format(
type(self).__name__, self.shape, self.dtype, self.name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册