未验证 提交 3976bbe2 编写于 作者: T Tao Luo 提交者: GitHub

add input type and dtype check template, and update some APIs check (#21161)

* add input type and dtype check template, and update some APIs check

* refine check template, and update some APIs check in nn.py

* update some APIs check in loss.py

test=develop
上级 b3a3e6f6
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import six import six
from six.moves import zip, range, xrange from six.moves import zip, range, xrange
import multiprocessing import multiprocessing
import warnings
from .framework import Variable, default_main_program, _current_expected_place from .framework import Variable, default_main_program, _current_expected_place
from .framework import _cpu_num, _cuda_ids from .framework import _cpu_num, _cuda_ids
...@@ -64,6 +65,39 @@ def convert_dtype(dtype): ...@@ -64,6 +65,39 @@ def convert_dtype(dtype):
"int32, int64, uint8]") "int32, int64, uint8]")
def check_type_and_dtype(input,
input_name,
expected_type,
expected_dtype,
op_name,
extra_message=''):
check_type(input, input_name, expected_type, op_name, extra_message)
check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
def check_type(input, input_name, expected_type, op_name, extra_message=''):
if not isinstance(input, expected_type):
raise TypeError(
"The type of '%s' in %s must be %s, but received %s. %s" %
(input_name, op_name, expected_type, type(input), extra_message))
def check_dtype(input_dtype,
input_name,
expected_dtype,
op_name,
extra_message=''):
if convert_dtype(input_dtype) in ['float16']:
warnings.warn(
"The data type of '%s' in %s only support float16 in GPU now. %s" %
(input_name, op_name, extra_message))
if convert_dtype(input_dtype) not in expected_dtype:
raise TypeError(
"The data type of '%s' in %s must be %s, but received %s. %s" %
(input_name, op_name, expected_dtype, convert_dtype(input_dtype),
extra_message))
class DataToLoDTensorConverter(object): class DataToLoDTensorConverter(object):
def __init__(self, place, lod_level, shape, dtype): def __init__(self, place, lod_level, shape, dtype):
self.place = place self.place = place
......
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import warnings import warnings
from .framework import Variable, in_dygraph_mode from .framework import Variable, in_dygraph_mode
from .layer_helper import LayerHelper from .layer_helper import LayerHelper
from .data_feeder import convert_dtype from .data_feeder import check_type_and_dtype, check_dtype
__all__ = ['one_hot', 'embedding'] __all__ = ['one_hot', 'embedding']
...@@ -233,21 +233,9 @@ def embedding(input, ...@@ -233,21 +233,9 @@ def embedding(input,
""" """
helper = LayerHelper('embedding', **locals()) helper = LayerHelper('embedding', **locals())
if not isinstance(input, Variable): check_type_and_dtype(input, 'input', Variable, ['int64'], 'fluid.embedding')
raise TypeError( check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
"The type of 'input' in fluid.embedding must be Variable, but received %s" 'fluid.embedding')
% (type(input)))
if convert_dtype(input.dtype) not in ['int64']:
raise TypeError(
"The data type of 'input' in fluid.embedding must be int64, but received %s."
% (convert_dtype(input.dtype)))
if convert_dtype(dtype) in ['float16']:
warnings.warn(
"The 'dtype' of fluid.embedding only support float16 in GPU now.")
if convert_dtype(dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The 'dtype' of fluid.embedding must be float16, float32 or float64, but received %s."
% (convert_dtype(dtype)))
remote_prefetch = is_sparse and (not is_distributed) remote_prefetch = is_sparse and (not is_distributed)
if remote_prefetch: if remote_prefetch:
assert is_sparse is True and is_distributed is False assert is_sparse is True and is_distributed is False
......
...@@ -22,7 +22,7 @@ from six.moves import cStringIO ...@@ -22,7 +22,7 @@ from six.moves import cStringIO
from ..proto import framework_pb2 from ..proto import framework_pb2
from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_ from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..data_feeder import convert_dtype from ..data_feeder import check_type_and_dtype
__all__ = [ __all__ = [
'deprecated', 'generate_layer_fn', 'generate_activation_fn', 'autodoc', 'deprecated', 'generate_layer_fn', 'generate_activation_fn', 'autodoc',
...@@ -251,18 +251,8 @@ def generate_activation_fn(op_type): ...@@ -251,18 +251,8 @@ def generate_activation_fn(op_type):
def func(x, name=None): def func(x, name=None):
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
if not isinstance(x, Variable): check_type_and_dtype(x, 'x', Variable,
raise TypeError( ['float16', 'float32', 'float64'], op_type)
"The type of 'x' in %s must be Variable, but received %s" %
(op_type, type(x)))
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in %s only support float16 in GPU now." %
(op_type))
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'x' in %s must be float16 (only support on GPU), float32 or float64, but received %s."
% (op_type, convert_dtype(x.dtype)))
output = helper.create_variable_for_type_inference(dtype=x.dtype) output = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output}) helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output})
return output return output
......
...@@ -20,7 +20,7 @@ from . import nn ...@@ -20,7 +20,7 @@ from . import nn
from .layer_function_generator import templatedoc from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..framework import Variable from ..framework import Variable
from ..data_feeder import convert_dtype from ..data_feeder import check_type_and_dtype
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from ..initializer import NumpyArrayInitializer from ..initializer import NumpyArrayInitializer
...@@ -233,19 +233,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): ...@@ -233,19 +233,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
predict = fluid.layers.fc(input=x, size=class_num, act='softmax') predict = fluid.layers.fc(input=x, size=class_num, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
""" """
if not isinstance(input, Variable): check_type_and_dtype(input, 'input', Variable,
raise TypeError( ['float16', 'float32', 'float64'], 'cross_entropy')
"The type of 'input' in cross_entropy must be Variable, but received %s"
% (type(input)))
if convert_dtype(input.dtype) in ['float16']:
warnings.warn(
"The data type of 'input' in cross_entropy only support float16 on GPU now."
)
if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'input' in cross_entropy must be float16 or float32 or float64, but received %s."
% (convert_dtype(input.dtype)))
if not soft_label: if not soft_label:
return cross_entropy2(input, label, ignore_index) return cross_entropy2(input, label, ignore_index)
helper = LayerHelper('cross_entropy', **locals()) helper = LayerHelper('cross_entropy', **locals())
...@@ -674,23 +663,9 @@ def nce(input, ...@@ -674,23 +663,9 @@ def nce(input,
custom_dist=dist) custom_dist=dist)
""" """
helper = LayerHelper('nce', **locals()) helper = LayerHelper('nce', **locals())
check_type_and_dtype(input, 'input', Variable, ['float32', 'float64'],
if not isinstance(input, Variable): 'nce')
raise TypeError( check_type_and_dtype(label, 'label', Variable, ['int64'], 'nce')
"The type of 'input' in nce layer must be Variable, but received %s"
% (type(input)))
if not isinstance(label, Variable):
raise TypeError(
"The type of 'label' in nce layer must be Variable, but received %s"
% (type(label)))
if convert_dtype(input.dtype) not in ['float32', 'float64']:
raise TypeError(
"The data type of 'input' in nce layer must be float32 or float64, but received %s."
% (convert_dtype(input.dtype)))
if convert_dtype(label.dtype) not in ['int64']:
raise TypeError(
"The data type of 'label' in nce layer must be int64, but received %s."
% (convert_dtype(label.dtype)))
dim = input.shape[1] dim = input.shape[1]
num_true_class = label.shape[1] num_true_class = label.shape[1]
......
...@@ -23,7 +23,7 @@ from ..initializer import Normal, Constant ...@@ -23,7 +23,7 @@ from ..initializer import Normal, Constant
from ..framework import Variable from ..framework import Variable
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from . import nn from . import nn
from ..data_feeder import convert_dtype from ..data_feeder import check_type_and_dtype
__all__ = ['accuracy', 'auc'] __all__ = ['accuracy', 'auc']
...@@ -72,18 +72,8 @@ def accuracy(input, label, k=1, correct=None, total=None): ...@@ -72,18 +72,8 @@ def accuracy(input, label, k=1, correct=None, total=None):
#[array([0.6666667], dtype=float32)] #[array([0.6666667], dtype=float32)]
""" """
helper = LayerHelper("accuracy", **locals()) helper = LayerHelper("accuracy", **locals())
if not isinstance(input, Variable): check_type_and_dtype(input, 'input', Variable,
raise TypeError( ['float16', 'float32', 'float64'], 'accuracy')
"The type of 'input' in accuracy must be Variable, but received %s"
% (type(input)))
if convert_dtype(input.dtype) in ['float16']:
warnings.warn(
"The data type of 'input' in accuracy only support float16 in GPU now."
)
if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'input' in accuracy must be float16 or float32 or float64, but received %s."
% (convert_dtype(input.dtype)))
topk_out, topk_indices = nn.topk(input, k=k) topk_out, topk_indices = nn.topk(input, k=k)
acc_out = helper.create_variable_for_type_inference(dtype="float32") acc_out = helper.create_variable_for_type_inference(dtype="float32")
if correct is None: if correct is None:
......
此差异已折叠。
...@@ -21,10 +21,9 @@ from ..framework import Variable ...@@ -21,10 +21,9 @@ from ..framework import Variable
from ..initializer import Constant, force_init_on_cpu from ..initializer import Constant, force_init_on_cpu
from ..core import VarDesc from ..core import VarDesc
from .layer_function_generator import templatedoc from .layer_function_generator import templatedoc
from ..data_feeder import convert_dtype from ..data_feeder import check_type_and_dtype, check_type, check_dtype, convert_dtype
import numpy import numpy
import warnings import warnings
from ..data_feeder import convert_dtype
__all__ = [ __all__ = [
'create_tensor', 'create_parameter', 'create_global_var', 'cast', 'create_tensor', 'create_parameter', 'create_global_var', 'cast',
...@@ -192,17 +191,10 @@ def cast(x, dtype): ...@@ -192,17 +191,10 @@ def cast(x, dtype):
# [ 0 4]] int32 # [ 0 4]] int32
""" """
helper = LayerHelper('cast', **locals()) helper = LayerHelper('cast', **locals())
if not isinstance(x, Variable): check_type_and_dtype(
raise TypeError( x, 'x', Variable,
"The type of 'x' in cast must be Variable, but received %s" % ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
(type(x))) 'cast')
if convert_dtype(x.dtype) not in [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'
]:
raise TypeError(
"The data type of 'x' in cast must be one of [bool, float16, float32, float64, int32, int64, uint8], but received %s."
% (convert_dtype(x.dtype)))
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op( helper.append_op(
type='cast', type='cast',
...@@ -265,25 +257,11 @@ def concat(input, axis=0, name=None): ...@@ -265,25 +257,11 @@ def concat(input, axis=0, name=None):
"The type of input in concat should be list, but received %s." % "The type of input in concat should be list, but received %s." %
(type(input))) (type(input)))
input = [input] input = [input]
for x in input: for id, x in enumerate(input):
if not isinstance(x, Variable): check_type_and_dtype(
raise TypeError( x, 'input[' + str(id) + ']', Variable,
"The type of x in 'input' in concat must be Variable, but received %s." ['float16', 'float32', 'float64', 'int32', 'int64'], 'concat')
% (type(x))) check_type(axis, 'axis', (int, Variable), 'concat')
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of x in 'input' in concat only support float16 on GPU now."
)
if convert_dtype(x.dtype) not in [
'float16', 'float32', 'float64', 'int32', 'int64'
]:
raise TypeError(
"The data type of x in 'input' in concat must be float16(only support on GPU), float32, float64, int32, int64, but received %s."
% (convert_dtype(x.dtype)))
if not isinstance(axis, (int, Variable)):
raise TypeError(
"The type of 'axis' in concat must be int or Variable, but "
"received %s." % (type(axis)))
inputs = {'X': input} inputs = {'X': input}
attrs = {} attrs = {}
if isinstance(axis, Variable): if isinstance(axis, Variable):
...@@ -478,14 +456,11 @@ def assign(input, output=None): ...@@ -478,14 +456,11 @@ def assign(input, output=None):
result3 = fluid.layers.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]] result3 = fluid.layers.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
""" """
helper = LayerHelper('assign', **locals()) helper = LayerHelper('assign', **locals())
check_type(input, 'input', (Variable, numpy.ndarray), 'assign')
if isinstance(input, Variable): if isinstance(input, Variable):
if convert_dtype(input.dtype) not in [ check_dtype(input.dtype, 'input',
'float32', 'float64', 'int32', 'int64', 'bool' ['float32', 'float64', 'int32', 'int64', 'bool'], 'assign',
]: '(When the type of input in assign is Variable.)')
raise TypeError(
"When the type of 'input' in assign is Variable, the data "
"type of 'input' must be float32, float64, int32, int64 or "
"bool, but received %s." % convert_dtype(input.dtype))
if output is None: if output is None:
output = helper.create_variable_for_type_inference( output = helper.create_variable_for_type_inference(
dtype=input.dtype) dtype=input.dtype)
...@@ -518,9 +493,6 @@ def assign(input, output=None): ...@@ -518,9 +493,6 @@ def assign(input, output=None):
'shape': list(input.shape), 'shape': list(input.shape),
value_name: values value_name: values
}) })
else:
raise TypeError("The type of 'input' in assign must be Variable or "
"numpy.ndarray, but received %s" % type(input))
return output return output
...@@ -570,19 +542,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -570,19 +542,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]] data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]]
""" """
helper = LayerHelper("fill_constant", **locals()) helper = LayerHelper("fill_constant", **locals())
if convert_dtype(dtype) not in [ check_dtype(dtype, 'create data type',
'bool', 'float16', 'float32', 'float64', 'int32', 'int64' ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
]: 'fill_constant')
raise TypeError( check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
"The create data type in fill_constant must be one of 'bool', float16, float32,"
"float64, int32 or int64, but received %s." % convert_dtype(
(dtype)))
if not isinstance(shape, (list, tuple, Variable)):
raise TypeError(
"The type of 'shape' in fill_constant must be Variable, list or tuple, but "
"received %s." % (type(shape)))
inputs = {} inputs = {}
attrs = { attrs = {
'value': float(value), 'value': float(value),
...@@ -609,12 +572,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -609,12 +572,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
for idx, dim in enumerate(list_shape): for idx, dim in enumerate(list_shape):
if isinstance(dim, Variable): if isinstance(dim, Variable):
dim.stop_gradient = True dim.stop_gradient = True
if convert_dtype(dim.dtype) not in ['int32', 'int64']: check_dtype(
raise TypeError( dim.dtype, 'shape[' + str(idx) + ']', ['int32', 'int64'],
"When type of 'shape' in fill_constant is list or tuple, " 'fill_constant',
"the data type of the element with type Variable must be int32 or int64, " '(When type of shape in fill_constant is list or tuple.)')
"but received the data type of shape[%d] is %s." %
(idx, convert_dtype(dim.dtype)))
if convert_dtype(dim.dtype) == 'int64': if convert_dtype(dim.dtype) == 'int64':
dim = cast(x=dim, dtype='int32') dim = cast(x=dim, dtype='int32')
new_shape_tensor.append(dim) new_shape_tensor.append(dim)
...@@ -626,10 +587,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -626,10 +587,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
if isinstance(shape, Variable): if isinstance(shape, Variable):
shape.stop_gradient = True shape.stop_gradient = True
if convert_dtype(shape.dtype) not in ['int32', 'int64']: check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant',
raise TypeError( '(When type of shape in fill_constant is Variable.)')
"When type of 'shape' in fill_constant is Variable, the data type of 'shape' must be int32 or int64, "
"but received %s." % (convert_dtype(shape.dtype)))
if (convert_dtype(shape.dtype) == 'int64'): if (convert_dtype(shape.dtype) == 'int64'):
shape = cast(shape, 'int32') shape = cast(shape, 'int32')
inputs["ShapeTensor"] = shape inputs["ShapeTensor"] = shape
...@@ -644,11 +603,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -644,11 +603,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
if out is None: if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
else: else:
if not (convert_dtype(dtype) == convert_dtype(out.dtype)): check_dtype(
raise TypeError( dtype, 'create data type',
"The create data type in op must be same with out type" convert_dtype(out.dtype), 'fill_constant',
"but received %s and out dtype %s." % (convert_dtype( '(The create data type in fill_constant must be the same with out data type.)'
(dtype), convert_dtype(out.dtype)))) )
attrs['dtype'] = out.dtype attrs['dtype'] = out.dtype
helper.append_op( helper.append_op(
type='fill_constant', type='fill_constant',
...@@ -970,13 +929,9 @@ def zeros(shape, dtype, force_cpu=False): ...@@ -970,13 +929,9 @@ def zeros(shape, dtype, force_cpu=False):
import paddle.fluid as fluid import paddle.fluid as fluid
data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]] data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]]
""" """
if convert_dtype(dtype) not in [ check_dtype(dtype, 'create data type',
'bool', 'float16', 'float32', 'float64', 'int32', 'int64' ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
]: 'zeros')
raise TypeError(
"The create data type in zeros must be one of bool, float16, float32,"
" float64, int32 or int64, but received %s." % convert_dtype(
(dtype)))
return fill_constant(value=0.0, **locals()) return fill_constant(value=0.0, **locals())
......
...@@ -64,13 +64,13 @@ class TestAccuracyOpError(OpTest): ...@@ -64,13 +64,13 @@ class TestAccuracyOpError(OpTest):
# The input type of accuracy_op must be Variable. # The input type of accuracy_op must be Variable.
x1 = fluid.create_lod_tensor( x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace()) np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.accuracy, x1) label = fluid.layers.data(
name='label', shape=[-1, 1], dtype="int32")
self.assertRaises(TypeError, fluid.layers.accuracy, x1, label)
# The input dtype of accuracy_op must be float32 or float64. # The input dtype of accuracy_op must be float32 or float64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32") x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.accuracy, x2) self.assertRaises(TypeError, fluid.layers.accuracy, x2, label)
x3 = fluid.layers.data(name='input', shape=[-1, 2], dtype="float16") x3 = fluid.layers.data(name='input', shape=[-1, 2], dtype="float16")
label = fluid.layers.data(
name='label', shape=[-1, 1], dtype="int32")
fluid.layers.accuracy(input=x3, label=label) fluid.layers.accuracy(input=x3, label=label)
......
...@@ -28,10 +28,9 @@ class TestZerosOpError(OpTest): ...@@ -28,10 +28,9 @@ class TestZerosOpError(OpTest):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
# The input dtype of zeros_op must be bool, float16, float32, float64, int32, int64. # The input dtype of zeros_op must be bool, float16, float32, float64, int32, int64.
x1 = fluid.layers.data(name='x1', shape=[4], dtype="int8") shape = [4]
self.assertRaises(TypeError, fluid.layers.zeros, x1) dtype = "int8"
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8") self.assertRaises(TypeError, fluid.layers.zeros, shape, dtype)
self.assertRaises(TypeError, fluid.layers.zeros, x2)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册