Unverified · commit 3976bbe2 authored by Tao Luo, committed by GitHub

add input type and dtype check template, and update some APIs check (#21161)

* add input type and dtype check template, and update some APIs check

* refine check template, and update some APIs check in nn.py

* update some APIs check in loss.py

test=develop
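For orientation, here is a minimal sketch of how a layer can call the new shared checkers instead of hand-written isinstance/convert_dtype branches. The function name my_layer and its arguments are made up for illustration; check_type_and_dtype and check_dtype are the helpers added to data_feeder.py in the first hunk below.

    from paddle.fluid.framework import Variable
    from paddle.fluid.data_feeder import check_type_and_dtype, check_dtype

    def my_layer(x, dtype='float32'):
        # raises TypeError if x is not a Variable or its dtype is not listed,
        # and warns when the dtype is float16 (GPU-only support)
        check_type_and_dtype(x, 'x', Variable,
                             ['float16', 'float32', 'float64'], 'my_layer')
        # attribute dtypes passed as plain strings go through check_dtype
        check_dtype(dtype, 'dtype', ['float32', 'float64'], 'my_layer')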
Parent b3a3e6f6
@@ -20,6 +20,7 @@ import os
 import six
 from six.moves import zip, range, xrange
 import multiprocessing
+import warnings
 from .framework import Variable, default_main_program, _current_expected_place
 from .framework import _cpu_num, _cuda_ids
@@ -64,6 +65,39 @@ def convert_dtype(dtype):
         "int32, int64, uint8]")
+
+
+def check_type_and_dtype(input,
+                         input_name,
+                         expected_type,
+                         expected_dtype,
+                         op_name,
+                         extra_message=''):
+    check_type(input, input_name, expected_type, op_name, extra_message)
+    check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
+
+
+def check_type(input, input_name, expected_type, op_name, extra_message=''):
+    if not isinstance(input, expected_type):
+        raise TypeError(
+            "The type of '%s' in %s must be %s, but received %s. %s" %
+            (input_name, op_name, expected_type, type(input), extra_message))
+
+
+def check_dtype(input_dtype,
+                input_name,
+                expected_dtype,
+                op_name,
+                extra_message=''):
+    if convert_dtype(input_dtype) in ['float16']:
+        warnings.warn(
+            "The data type of '%s' in %s only support float16 in GPU now. %s" %
+            (input_name, op_name, extra_message))
+    if convert_dtype(input_dtype) not in expected_dtype:
+        raise TypeError(
+            "The data type of '%s' in %s must be %s, but received %s. %s" %
+            (input_name, op_name, expected_dtype, convert_dtype(input_dtype),
+             extra_message))
+
+
 class DataToLoDTensorConverter(object):
     def __init__(self, place, lod_level, shape, dtype):
         self.place = place
...
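As a quick illustration of the hunk above, a failing check raises one uniform message built from the input name, the op name, and the expected dtypes; the op name 'my_op' here is hypothetical.

    from paddle.fluid.data_feeder import check_dtype

    # raises TypeError roughly of the form:
    #   The data type of 'x' in my_op must be ['float32', 'float64'],
    #   but received int64.
    check_dtype('int64', 'x', ['float32', 'float64'], 'my_op')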
@@ -16,7 +16,7 @@ from __future__ import print_function
 import warnings
 from .framework import Variable, in_dygraph_mode
 from .layer_helper import LayerHelper
-from .data_feeder import convert_dtype
+from .data_feeder import check_type_and_dtype, check_dtype
 __all__ = ['one_hot', 'embedding']
@@ -233,21 +233,9 @@ def embedding(input,
     """
     helper = LayerHelper('embedding', **locals())
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in fluid.embedding must be Variable, but received %s"
-            % (type(input)))
-    if convert_dtype(input.dtype) not in ['int64']:
-        raise TypeError(
-            "The data type of 'input' in fluid.embedding must be int64, but received %s."
-            % (convert_dtype(input.dtype)))
-    if convert_dtype(dtype) in ['float16']:
-        warnings.warn(
-            "The 'dtype' of fluid.embedding only support float16 in GPU now.")
-    if convert_dtype(dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The 'dtype' of fluid.embedding must be float16, float32 or float64, but received %s."
-            % (convert_dtype(dtype)))
+    check_type_and_dtype(input, 'input', Variable, ['int64'], 'fluid.embedding')
+    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
+                'fluid.embedding')
     remote_prefetch = is_sparse and (not is_distributed)
     if remote_prefetch:
         assert is_sparse is True and is_distributed is False
...
@@ -22,7 +22,7 @@ from six.moves import cStringIO
 from ..proto import framework_pb2
 from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_
 from ..layer_helper import LayerHelper
-from ..data_feeder import convert_dtype
+from ..data_feeder import check_type_and_dtype
 __all__ = [
     'deprecated', 'generate_layer_fn', 'generate_activation_fn', 'autodoc',
@@ -251,18 +251,8 @@ def generate_activation_fn(op_type):
     def func(x, name=None):
         helper = LayerHelper(op_type, **locals())
-        if not isinstance(x, Variable):
-            raise TypeError(
-                "The type of 'x' in %s must be Variable, but received %s" %
-                (op_type, type(x)))
-        if convert_dtype(x.dtype) in ['float16']:
-            warnings.warn(
-                "The data type of 'x' in %s only support float16 in GPU now." %
-                (op_type))
-        if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
-            raise TypeError(
-                "The data type of 'x' in %s must be float16 (only support on GPU), float32 or float64, but received %s."
-                % (op_type, convert_dtype(x.dtype)))
+        check_type_and_dtype(x, 'x', Variable,
+                             ['float16', 'float32', 'float64'], op_type)
         output = helper.create_variable_for_type_inference(dtype=x.dtype)
         helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output})
         return output
...
@@ -20,7 +20,7 @@ from . import nn
 from .layer_function_generator import templatedoc
 from ..layer_helper import LayerHelper
 from ..framework import Variable
-from ..data_feeder import convert_dtype
+from ..data_feeder import check_type_and_dtype
 from ..param_attr import ParamAttr
 from ..initializer import NumpyArrayInitializer
@@ -233,19 +233,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
         predict = fluid.layers.fc(input=x, size=class_num, act='softmax')
         cost = fluid.layers.cross_entropy(input=predict, label=label)
     """
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in cross_entropy must be Variable, but received %s"
-            % (type(input)))
-    if convert_dtype(input.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'input' in cross_entropy only support float16 on GPU now."
-        )
-    if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'input' in cross_entropy must be float16 or float32 or float64, but received %s."
-            % (convert_dtype(input.dtype)))
+    check_type_and_dtype(input, 'input', Variable,
+                         ['float16', 'float32', 'float64'], 'cross_entropy')
     if not soft_label:
         return cross_entropy2(input, label, ignore_index)
     helper = LayerHelper('cross_entropy', **locals())
@@ -674,23 +663,9 @@ def nce(input,
             custom_dist=dist)
     """
     helper = LayerHelper('nce', **locals())
-
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in nce layer must be Variable, but received %s"
-            % (type(input)))
-    if not isinstance(label, Variable):
-        raise TypeError(
-            "The type of 'label' in nce layer must be Variable, but received %s"
-            % (type(label)))
-    if convert_dtype(input.dtype) not in ['float32', 'float64']:
-        raise TypeError(
-            "The data type of 'input' in nce layer must be float32 or float64, but received %s."
-            % (convert_dtype(input.dtype)))
-    if convert_dtype(label.dtype) not in ['int64']:
-        raise TypeError(
-            "The data type of 'label' in nce layer must be int64, but received %s."
-            % (convert_dtype(label.dtype)))
+    check_type_and_dtype(input, 'input', Variable, ['float32', 'float64'],
+                         'nce')
+    check_type_and_dtype(label, 'label', Variable, ['int64'], 'nce')
     dim = input.shape[1]
     num_true_class = label.shape[1]
...
@@ -23,7 +23,7 @@ from ..initializer import Normal, Constant
 from ..framework import Variable
 from ..param_attr import ParamAttr
 from . import nn
-from ..data_feeder import convert_dtype
+from ..data_feeder import check_type_and_dtype
 __all__ = ['accuracy', 'auc']
@@ -72,18 +72,8 @@ def accuracy(input, label, k=1, correct=None, total=None):
         #[array([0.6666667], dtype=float32)]
     """
     helper = LayerHelper("accuracy", **locals())
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in accuracy must be Variable, but received %s"
-            % (type(input)))
-    if convert_dtype(input.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'input' in accuracy only support float16 in GPU now."
-        )
-    if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'input' in accuracy must be float16 or float32 or float64, but received %s."
-            % (convert_dtype(input.dtype)))
+    check_type_and_dtype(input, 'input', Variable,
+                         ['float16', 'float32', 'float64'], 'accuracy')
     topk_out, topk_indices = nn.topk(input, k=k)
     acc_out = helper.create_variable_for_type_inference(dtype="float32")
     if correct is None:
...
@@ -34,7 +34,7 @@ from .. import unique_name
 from functools import reduce
 from .. import core
 from ..dygraph import layers
-from ..data_feeder import convert_dtype
+from ..data_feeder import convert_dtype, check_type_and_dtype, check_type, check_dtype
 __all__ = [
     'fc',
@@ -301,26 +301,12 @@ def fc(input,
         fc = fluid.layers.fc(input=[data_1, data_2], size=1000, act="tanh")
     """
     helper = LayerHelper("fc", **locals())
+    check_type(input, 'input', (list, tuple, Variable), 'fc')
     if isinstance(input, (list, tuple)):
         for i, input_x in enumerate(input):
-            if not isinstance(input_x, Variable):
-                raise TypeError(
-                    "The type of input[%d] in fc must be Variable, but received %s"
-                    % (i, type(input_x)))
-    else:
-        if not isinstance(input, Variable):
-            raise TypeError(
-                "The type of 'input' in fc must be Variable, but received %s" %
-                (type(input)))
+            check_type(input_x, 'input[' + str(i) + ']', Variable, 'fc')
     dtype = helper.input_dtype()
-    if convert_dtype(dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'input' in fc only support float16 in GPU now.")
-    if convert_dtype(dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'input' in fc must be float16, float32 or float64, but received %s."
-            % (convert_dtype(dtype)))
+    check_dtype(dtype, 'input', ['float16', 'float32', 'float64'], 'fc')
     mul_results = []
     for input_var, param_attr in helper.iter_inputs_and_params():
         input_shape = input_var.shape
@@ -470,21 +456,10 @@ def embedding(input,
     """
     helper = LayerHelper('embedding', **locals())
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in layers.embedding must be Variable, but received %s"
-            % (type(input)))
-    if convert_dtype(input.dtype) not in ['int64']:
-        raise TypeError(
-            "The data type of 'input' in layers.embedding must be int64, but received %s."
-            % (convert_dtype(input.dtype)))
-    if convert_dtype(dtype) in ['float16']:
-        warnings.warn(
-            "The 'dtype' of layers.embedding only support float16 in GPU now.")
-    if convert_dtype(dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The 'dtype' of layers.embedding must be float16, float32 or float64, but received %s."
-            % (convert_dtype(dtype)))
+    check_type_and_dtype(input, 'input', Variable, ['int64'],
+                         'fluid.layers.embedding')
+    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
+                'fluid.layers.embedding')
     remote_prefetch = is_sparse and (not is_distributed)
     if remote_prefetch:
         assert is_sparse is True and is_distributed is False
@@ -830,20 +805,8 @@ def dropout(x,
     """
     helper = LayerHelper('dropout', **locals())
-
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of 'input' in dropout must be Variable, but received %s" %
-            (type(x)))
-    if convert_dtype(x.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'input' in dropout only support float16 on GPU now."
-        )
-    if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'input' in dropout must be float16 or float32 or float64, but received %s."
-            % (convert_dtype(x.dtype)))
+    check_type_and_dtype(x, 'x', Variable, ['float16', 'float32', 'float64'],
+                         'dropout')
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
     mask = helper.create_variable_for_type_inference(
         dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
@@ -1125,18 +1088,8 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
             print(output)
     """
     helper = LayerHelper('softmax', **locals())
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in softmax must be Variable, but received %s" %
-            (type(input)))
-    if convert_dtype(input.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'input' in softmax only support float16 in GPU now."
-        )
-    if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'input' in softmax must be float16, float32 or float64, but received %s."
-            % (convert_dtype(input.dtype)))
+    check_type_and_dtype(input, 'input', Variable,
+                         ['float16', 'float32', 'float64'], 'softmax')
     dtype = helper.input_dtype()
     softmax_out = helper.create_variable_for_type_inference(dtype)
@@ -1278,19 +1231,8 @@ def conv2d(input,
           conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu")
     """
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in conv2d must be Variable, but received %s" %
-            (type(input)))
-    if convert_dtype(input.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'input' in conv2d only support float16 on GPU now."
-        )
-    if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'input' in conv2d must be float16 or float32 or float64, but received %s."
-            % (convert_dtype(input.dtype)))
+    check_type_and_dtype(input, 'input', Variable,
+                         ['float16', 'float32', 'float64'], 'conv2d')
     num_channels = input.shape[1]
     if not isinstance(use_cudnn, bool):
         raise ValueError("Attr(use_cudnn) should be True or False. Received "
@@ -2518,19 +2460,8 @@ def batch_norm(input,
     assert bias_attr is not False, "bias_attr should not be False in batch_norm."
     helper = LayerHelper('batch_norm', **locals())
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in batch_norm must be Variable, but received %s"
-            % (type(input)))
-    if convert_dtype(input.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'input' in batch_norm only support float16 on GPU now."
-        )
-    if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'input' in batch_norm must be float16 or float32 or float64, but received %s."
-            % (convert_dtype(input.dtype)))
+    check_type_and_dtype(input, 'input', Variable,
+                         ['float16', 'float32', 'float64'], 'batch_norm')
     dtype = helper.input_dtype()
     # use fp32 for bn parameter
     if dtype == core.VarDesc.VarType.FP16:
@@ -3786,15 +3717,8 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
     """
     helper = LayerHelper('reduce_sum', **locals())
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in reduce_sum must be Variable, but received %s"
-            % (type(input)))
-    if convert_dtype(
-            input.dtype) not in ['float32', 'float64', 'int32', 'int64']:
-        raise TypeError(
-            "The data type of 'input' in reduce_sum must be float32 or float64 or int32 or int64, but received %s."
-            % (convert_dtype(input.dtype)))
+    check_type_and_dtype(input, 'input', Variable,
+                         ['float32', 'float64', 'int32', 'int64'], 'reduce_sum')
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     if dim is not None and not isinstance(dim, list):
         dim = [dim]
@@ -3860,15 +3784,9 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
         fluid.layers.reduce_mean(y, dim=[0, 1]) # [4.0, 5.0]
     """
     helper = LayerHelper('reduce_mean', **locals())
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in reduce_mean must be Variable, but received %s"
-            % (type(input)))
-    if convert_dtype(
-            input.dtype) not in ['float32', 'float64', 'int32', 'int64']:
-        raise TypeError(
-            "The data type of 'input' in reduce_mean must be float32 or float64 or int32 or int64, but received %s."
-            % (convert_dtype(input.dtype)))
+    check_type_and_dtype(input, 'input', Variable,
+                         ['float32', 'float64', 'int32', 'int64'],
+                         'reduce_mean')
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     if dim is not None and not isinstance(dim, list):
         dim = [dim]
@@ -4463,20 +4381,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
     def __check_input(x, y):
         var_names = {'x': x, 'y': y}
         for name, val in var_names.items():
-            if not isinstance(val, Variable):
-                raise TypeError(
-                    "The type of %s in matmul must be Variable, but received %s.\n"
-                    % (name, (type(val))))
-            if convert_dtype(val.dtype) in ['float16']:
-                warnings.warn(
-                    "The data type of %s in matmul only support float16 in GPU now."
-                    % name)
-            if convert_dtype(
-                    val.dtype) not in ['float16', 'float32', 'float64']:
-                raise TypeError(
-                    "The data type of %s in matmul must be float16 or float32 or float64, but received %s.\n"
-                    % (name, (convert_dtype(val.dtype))))
+            check_type_and_dtype(val, name, Variable,
+                                 ['float16', 'float32', 'float64'], 'matmul')
         x_shape = list(x.shape)
         y_shape = list(y.shape)
         if len(x_shape) == 1:
@@ -4826,20 +4732,10 @@ def transpose(x, perm, name=None):
             #(3L, 2L, 4L)
     """
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of Input(x) in transpose must be Variable, but received %s"
-            % (type(x)))
-    if convert_dtype(x.dtype) not in [
-            "float16", "float32", "float64", "int32", "int64"
-    ]:
-        raise TypeError(
-            "The data type of Input(x) in transpose must be one of [float16, float32, float64, int32, int64], but received %s."
-            % (convert_dtype(x.dtype)))
-    if not isinstance(perm, list):
-        raise TypeError(
-            "The type of Input(perm) in transpose must be list, but received %s"
-            % (type(perm)))
+    check_type_and_dtype(x, 'x', Variable,
+                         ['float16', 'float32', 'float64', 'int32', 'int64'],
+                         'transpose')
+    check_type(perm, 'perm', list, 'transpose')
     if len(perm) != len(x.shape):
         raise ValueError(
             "Input(perm) is the permutation of dimensions of Input(x), "
@@ -5431,32 +5327,11 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
             reshaped_2 = fluid.layers.reshape(data_2, shape=[dim, 10])
             # the shape of reshaped_2 is [5,10].
     """
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of 'x' in reshape must be Variable, but received %s." %
-            (type(x)))
-
-    if convert_dtype(x.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'x' in reshape only support float16 in GPU now.")
-
-    if convert_dtype(x.dtype) not in [
-            'float16', 'float32', 'float64', 'int32', 'int64'
-    ]:
-        raise TypeError(
-            "The data type of 'x' in reshape must be float16, float32, float64, int32 or int64, "
-            "but received %s." % (convert_dtype(x.dtype)))
-
-    if not isinstance(shape, (list, tuple, Variable)):
-        raise TypeError(
-            "The type of 'shape' in reshape must be Variable, list or tuple, but "
-            "received %s." % (type(shape)))
-
-    if not isinstance(actual_shape, Variable) and (actual_shape is not None):
-        raise TypeError(
-            "The type of 'actual_shape' in reshape must be Variable "
-            "or None, but received %s." % (type(actual_shape)))
+    check_type_and_dtype(x, 'x', Variable,
+                         ['float16', 'float32', 'float64', 'int32', 'int64'],
+                         'reshape')
+    check_type(shape, 'shape', (list, tuple, Variable), 'reshape')
+    check_type(actual_shape, 'actual_shape', (Variable, type(None)), 'reshape')
     helper = LayerHelper("reshape2", **locals())
     inputs = {"X": x}
     attrs = {}
@@ -5592,23 +5467,10 @@ def squeeze(input, axes, name=None):
     """
     helper = LayerHelper("squeeze", **locals())
-
-    if not isinstance(input, Variable):
-        raise TypeError(
-            "The type of 'input' in squeeze must be Variable, but received %s" %
-            (type(input)))
-
-    if convert_dtype(input.dtype
-                     ) not in ['float32', 'float64', 'int8', 'int32', 'int64']:
-        raise TypeError(
-            "The data type of 'input' in squeeze must be float32, float64, int8, int32,"
-            "int64, but received %s." % (convert_dtype(input.dtype)))
-
-    if not isinstance(axes, list):
-        raise TypeError(
-            "The type of 'axes' in squeeze must be list, but received %s" %
-            (type(axes)))
+    check_type_and_dtype(input, 'input', Variable,
+                         ['float32', 'float64', 'int8', 'int32', 'int64'],
+                         'squeeze')
+    check_type(axes, 'axes', list, 'squeeze')
     out = helper.create_variable_for_type_inference(dtype=input.dtype)
     x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
     helper.append_op(
@@ -8090,27 +7952,16 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
     """
     helper = LayerHelper('crop_tensor', **locals())
-
-    if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']:
-        raise TypeError(
-            "Input(x)'s dtype of Op(crop_tensor) must be float32, float64, int32 or int64, "
-            "but received %s." % (convert_dtype(x.dtype)))
-
-    if not (isinstance(shape, list) or isinstance(shape, tuple) or \
-            isinstance(shape, Variable)):
-        raise TypeError(
-            "Attr(shape) of Op(crop_tensor) should be a list, tuple or Variable."
-        )
+    check_type_and_dtype(x, 'x', Variable,
+                         ['float32', 'float64', 'int32', 'int64'],
+                         'crop_tensor')
+    check_type(shape, 'shape', (list, tuple, Variable), 'crop_tensor')
+    check_type(offsets, 'offsets', (list, tuple, Variable, type(None)),
+               'crop_tensor')
     if offsets is None:
         offsets = [0] * len(x.shape)
-
-    if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \
-            isinstance(offsets, Variable)):
-        raise TypeError(
-            "Attr(offsets) of Op(crop_tensor) should be a list, tuple or Variable."
-        )
     out = helper.create_variable_for_type_inference(x.dtype)
     ipts = {'X': x}
     attrs = {}
@@ -8399,17 +8250,8 @@ def elu(x, alpha=1.0, name=None):
             # [ 1.  15.6 ]]
     """
     helper = LayerHelper('elu', **locals())
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of 'x' in elu must be Variable, but received %s" %
-            (type(x)))
-    if convert_dtype(x.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'x' in elu only support float16 in GPU now.")
-    if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'x' in elu must be float16 (only support on GPU), float32 or float64, but received %s."
-            % (convert_dtype(x.dtype)))
+    check_type_and_dtype(x, 'x', Variable, ['float16', 'float32', 'float64'],
+                         'elu')
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(
         type='elu',
@@ -9206,18 +9048,10 @@ def expand(x, expand_times, name=None):
         expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times)
         # the shape of expanded_2 is [48, 56].
     """
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of 'input' in reduce_sum must be Variable, but received %s"
-            % (type(x)))
-    if not isinstance(expand_times, (list, tuple, Variable)):
-        raise ValueError(
-            "Input expand_times must be an Variable, python list or tuple.")
-    if convert_dtype(
-            x.dtype) not in ['bool', 'float32', 'float64', 'int32', 'int64']:
-        raise TypeError(
-            "The data type of input in expand must be one of bool float32, float64, int32 or int64, but received %s."
-            % (convert_dtype(x.dtype)))
+    check_type_and_dtype(x, 'x', Variable,
+                         ['bool', 'float32', 'float64', 'int32', 'int64'],
+                         'expand')
+    check_type(expand_times, 'expand_times', (list, tuple, Variable), 'expand')
     if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True:
         raise ValueError(
             "expand op bool date type must set the stop_gradient to be False")
@@ -10151,34 +9985,12 @@ def _elementwise_op(helper):
     assert x is not None, 'x cannot be None in {}'.format(op_type)
     assert y is not None, 'y cannot be None in {}'.format(op_type)
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of 'x' in %s must be Variable, but received %s" %
-            (op_type, type(x)))
-    if not isinstance(y, Variable):
-        raise TypeError(
-            "The type of 'y' in %s must be Variable, but received %s" %
-            (op_type, type(y)))
-    if convert_dtype(x.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'x' in %s only support float16 on GPU now." %
-            (op_type))
-    if convert_dtype(y.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'y' in %s only support float16 on GPU now." %
-            (op_type))
-    if convert_dtype(x.dtype) not in [
-            'float16', 'float32', 'float64', 'int32', 'int64'
-    ]:
-        raise TypeError(
-            "The data type of 'x' in %s must be float16 or float32 or float64 or int32 or int64, "
-            "but received %s." % (op_type, convert_dtype(x.dtype)))
-    if convert_dtype(y.dtype) not in [
-            'float16', 'float32', 'float64', 'int32', 'int64'
-    ]:
-        raise TypeError(
-            "The data type of 'y' in %s must be float16 or float32 or float64 or int32 or int64, "
-            "but received %s." % (op_type, convert_dtype(y.dtype)))
+    check_type_and_dtype(x, 'x', Variable,
+                         ['float16', 'float32', 'float64', 'int32', 'int64'],
+                         op_type)
+    check_type_and_dtype(y, 'y', Variable,
+                         ['float16', 'float32', 'float64', 'int32', 'int64'],
+                         op_type)
     axis = helper.kwargs.get('axis', -1)
     use_mkldnn = helper.kwargs.get('use_mkldnn', False)
@@ -11170,21 +10982,8 @@ def mean(x, name=None):
     """
     helper = LayerHelper("mean", **locals())
-
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of 'x' in mean must be Variable, but received %s.\n" %
-            (type(x)))
-
-    if convert_dtype(x.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'x' in mean only support float16 in GPU now.")
-
-    if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'x' in mean must be float16 or float32 or float64, but received %s.\n"
-            % (convert_dtype(x.dtype)))
+    check_type_and_dtype(x, 'x', Variable, ['float16', 'float32', 'float64'],
+                         'mean')
     if name is None:
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
     else:
@@ -11265,30 +11064,10 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
     """
     helper = LayerHelper("mul", **locals())
-
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of 'x' in mul must be Variable, but received %s" %
-            (type(x)))
-    if not isinstance(y, Variable):
-        raise TypeError(
-            "The type of 'y' in mul must be Variable, but received %s" %
-            (type(y)))
-    if convert_dtype(x.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'x' in mul only support float16 in GPU now.")
-    if convert_dtype(y.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'y' in mul only support float16 in GPU now.")
-    if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'x' in mul must be float16, float32 or float64, but received %s."
-            % (convert_dtype(x.dtype)))
-    if convert_dtype(y.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'y' in mul must be float16, float32 or float64, but received %s."
-            % (convert_dtype(y.dtype)))
+    check_type_and_dtype(x, 'x', Variable, ['float16', 'float32', 'float64'],
+                         'mul')
+    check_type_and_dtype(y, 'y', Variable, ['float16', 'float32', 'float64'],
+                         'mul')
     if name is None:
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
     else:
@@ -12746,23 +12525,10 @@ def sign(x):
     """
     helper = LayerHelper("sign", **locals())
-
-    if not isinstance(x, Variable):
-        if isinstance(x, np.ndarray):
-            x = assign(x)
-        else:
-            raise TypeError(
-                "The type of 'x' in sign_op must be Variable or numpy.ndarray, but received %s."
-                % (type(x)))
-    if convert_dtype(x.dtype) in ['float16']:
-        warnings.warn(
-            "The data type of 'x' in sign_op only support float16 in GPU now.")
-    if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
-        raise TypeError(
-            "The data type of 'x' in sign_op must be float16, float32 or float64, but received %s."
-            % (convert_dtype(x.dtype)))
+    check_type(x, 'x', (Variable, np.ndarray), 'sign')
+    if isinstance(x, np.ndarray):
+        x = assign(x)
+    check_dtype(x.dtype, 'x', ['float16', 'float32', 'float64'], 'sign')
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(type='sign', inputs={'X': [x]}, outputs={'Out': [out]})
@@ -13633,18 +13399,10 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
     """
-    if not (isinstance(shape, (list, tuple, Variable))):
-        raise TypeError(
-            "Input shape must be a python list,Variable or tuple. But received %s"
-            % (type(shape)))
+    check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random')
     if not isinstance(dtype, core.VarDesc.VarType):
         dtype = convert_np_dtype_to_dtype_(dtype)
-
-    if convert_dtype(dtype) not in ['float32', 'float64']:
-        raise TypeError(
-            "The attribute dtype in uniform_random op must be float32 or float64, but received %s."
-            % (convert_dtype(dtype)))
+    check_dtype(dtype, 'dtype', ['float32', 'float64'], 'uniform_random')
     def contain_var(one_list):
         for ele in one_list:
...
@@ -21,10 +21,9 @@ from ..framework import Variable
 from ..initializer import Constant, force_init_on_cpu
 from ..core import VarDesc
 from .layer_function_generator import templatedoc
-from ..data_feeder import convert_dtype
+from ..data_feeder import check_type_and_dtype, check_type, check_dtype, convert_dtype
 import numpy
 import warnings
-from ..data_feeder import convert_dtype
 __all__ = [
     'create_tensor', 'create_parameter', 'create_global_var', 'cast',
@@ -192,17 +191,10 @@ def cast(x, dtype):
             # [ 0  4]] int32
     """
     helper = LayerHelper('cast', **locals())
-    if not isinstance(x, Variable):
-        raise TypeError(
-            "The type of 'x' in cast must be Variable, but received %s" %
-            (type(x)))
-    if convert_dtype(x.dtype) not in [
-            'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'
-    ]:
-        raise TypeError(
-            "The data type of 'x' in cast must be one of [bool, float16, float32, float64, int32, int64, uint8], but received %s."
-            % (convert_dtype(x.dtype)))
+    check_type_and_dtype(
+        x, 'x', Variable,
+        ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
+        'cast')
     out = helper.create_variable_for_type_inference(dtype=dtype)
     helper.append_op(
         type='cast',
@@ -265,25 +257,11 @@ def concat(input, axis=0, name=None):
             "The type of input in concat should be list, but received %s." %
             (type(input)))
         input = [input]
-    for x in input:
-        if not isinstance(x, Variable):
-            raise TypeError(
-                "The type of x in 'input' in concat must be Variable, but received %s."
-                % (type(x)))
-        if convert_dtype(x.dtype) in ['float16']:
-            warnings.warn(
-                "The data type of x in 'input' in concat only support float16 on GPU now."
-            )
-        if convert_dtype(x.dtype) not in [
-                'float16', 'float32', 'float64', 'int32', 'int64'
-        ]:
-            raise TypeError(
-                "The data type of x in 'input' in concat must be float16(only support on GPU), float32, float64, int32, int64, but received %s."
-                % (convert_dtype(x.dtype)))
-    if not isinstance(axis, (int, Variable)):
-        raise TypeError(
-            "The type of 'axis' in concat must be int or Variable, but "
-            "received %s." % (type(axis)))
+    for id, x in enumerate(input):
+        check_type_and_dtype(
+            x, 'input[' + str(id) + ']', Variable,
+            ['float16', 'float32', 'float64', 'int32', 'int64'], 'concat')
+    check_type(axis, 'axis', (int, Variable), 'concat')
     inputs = {'X': input}
     attrs = {}
     if isinstance(axis, Variable):
@@ -478,14 +456,11 @@ def assign(input, output=None):
         result3 = fluid.layers.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
     """
     helper = LayerHelper('assign', **locals())
+    check_type(input, 'input', (Variable, numpy.ndarray), 'assign')
     if isinstance(input, Variable):
-        if convert_dtype(input.dtype) not in [
-                'float32', 'float64', 'int32', 'int64', 'bool'
-        ]:
-            raise TypeError(
-                "When the type of 'input' in assign is Variable, the data "
-                "type of 'input' must be float32, float64, int32, int64 or "
-                "bool, but received %s." % convert_dtype(input.dtype))
+        check_dtype(input.dtype, 'input',
+                    ['float32', 'float64', 'int32', 'int64', 'bool'], 'assign',
+                    '(When the type of input in assign is Variable.)')
         if output is None:
             output = helper.create_variable_for_type_inference(
                 dtype=input.dtype)
@@ -518,9 +493,6 @@ def assign(input, output=None):
                 'shape': list(input.shape),
                 value_name: values
             })
-    else:
-        raise TypeError("The type of 'input' in assign must be Variable or "
-                        "numpy.ndarray, but received %s" % type(input))
     return output
@@ -570,19 +542,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
         data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]]
     """
     helper = LayerHelper("fill_constant", **locals())
-    if convert_dtype(dtype) not in [
-            'bool', 'float16', 'float32', 'float64', 'int32', 'int64'
-    ]:
-        raise TypeError(
-            "The create data type in fill_constant must be one of 'bool', float16, float32,"
-            "float64, int32 or int64, but received %s." % convert_dtype(
-                (dtype)))
-    if not isinstance(shape, (list, tuple, Variable)):
-        raise TypeError(
-            "The type of 'shape' in fill_constant must be Variable, list or tuple, but "
-            "received %s." % (type(shape)))
+    check_dtype(dtype, 'create data type',
+                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'fill_constant')
+    check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
     inputs = {}
     attrs = {
         'value': float(value),
@@ -609,12 +572,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
         for idx, dim in enumerate(list_shape):
             if isinstance(dim, Variable):
                 dim.stop_gradient = True
-                if convert_dtype(dim.dtype) not in ['int32', 'int64']:
-                    raise TypeError(
-                        "When type of 'shape' in fill_constant is list or tuple, "
-                        "the data type of the element with type Variable must be int32 or int64, "
-                        "but received the data type of shape[%d] is %s." %
-                        (idx, convert_dtype(dim.dtype)))
+                check_dtype(
+                    dim.dtype, 'shape[' + str(idx) + ']', ['int32', 'int64'],
+                    'fill_constant',
+                    '(When type of shape in fill_constant is list or tuple.)')
                 if convert_dtype(dim.dtype) == 'int64':
                     dim = cast(x=dim, dtype='int32')
                 new_shape_tensor.append(dim)
@@ -626,10 +587,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
     if isinstance(shape, Variable):
         shape.stop_gradient = True
-        if convert_dtype(shape.dtype) not in ['int32', 'int64']:
-            raise TypeError(
-                "When type of 'shape' in fill_constant is Variable, the data type of 'shape' must be int32 or int64, "
-                "but received %s." % (convert_dtype(shape.dtype)))
+        check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant',
+                    '(When type of shape in fill_constant is Variable.)')
         if (convert_dtype(shape.dtype) == 'int64'):
             shape = cast(shape, 'int32')
         inputs["ShapeTensor"] = shape
@@ -644,11 +603,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
     if out is None:
         out = helper.create_variable_for_type_inference(dtype=dtype)
     else:
-        if not (convert_dtype(dtype) == convert_dtype(out.dtype)):
-            raise TypeError(
-                "The create data type in op must be same with out type"
-                "but received %s and out dtype %s." % (convert_dtype(
-                    (dtype), convert_dtype(out.dtype))))
+        check_dtype(
+            dtype, 'create data type',
+            convert_dtype(out.dtype), 'fill_constant',
+            '(The create data type in fill_constant must be the same with out data type.)'
+        )
         attrs['dtype'] = out.dtype
     helper.append_op(
         type='fill_constant',
@@ -970,13 +929,9 @@ def zeros(shape, dtype, force_cpu=False):
           import paddle.fluid as fluid
           data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]]
     """
-    if convert_dtype(dtype) not in [
-            'bool', 'float16', 'float32', 'float64', 'int32', 'int64'
-    ]:
-        raise TypeError(
-            "The create data type in zeros must be one of bool, float16, float32,"
-            " float64, int32 or int64, but received %s." % convert_dtype(
-                (dtype)))
+    check_dtype(dtype, 'create data type',
+                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'zeros')
     return fill_constant(value=0.0, **locals())
...
@@ -64,13 +64,13 @@ class TestAccuracyOpError(OpTest):
             # The input type of accuracy_op must be Variable.
             x1 = fluid.create_lod_tensor(
                 np.array([[-1]]), [[1]], fluid.CPUPlace())
-            self.assertRaises(TypeError, fluid.layers.accuracy, x1)
+            label = fluid.layers.data(
+                name='label', shape=[-1, 1], dtype="int32")
+            self.assertRaises(TypeError, fluid.layers.accuracy, x1, label)
             # The input dtype of accuracy_op must be float32 or float64.
             x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
-            self.assertRaises(TypeError, fluid.layers.accuracy, x2)
+            self.assertRaises(TypeError, fluid.layers.accuracy, x2, label)
             x3 = fluid.layers.data(name='input', shape=[-1, 2], dtype="float16")
-            label = fluid.layers.data(
-                name='label', shape=[-1, 1], dtype="int32")
             fluid.layers.accuracy(input=x3, label=label)
...
@@ -28,10 +28,9 @@ class TestZerosOpError(OpTest):
     def test_errors(self):
         with program_guard(Program(), Program()):
             # The input dtype of zeros_op must be bool, float16, float32, float64, int32, int64.
-            x1 = fluid.layers.data(name='x1', shape=[4], dtype="int8")
-            self.assertRaises(TypeError, fluid.layers.zeros, x1)
-            x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
-            self.assertRaises(TypeError, fluid.layers.zeros, x2)
+            shape = [4]
+            dtype = "int8"
+            self.assertRaises(TypeError, fluid.layers.zeros, shape, dtype)
 if __name__ == "__main__":
...