未验证 提交 a6e935f4 编写于 作者: J juncaipeng 提交者: GitHub

Update op test framework (#21599)

* update op test framework
上级 7c386123
......@@ -33,6 +33,7 @@ from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, OpProtoHolder, Variable
from testsuite import create_op, set_input, append_input_output, append_loss_ops
from paddle.fluid import unique_name
from white_list import op_accuracy_white_list
def _set_use_system_allocator(value=None):
......@@ -141,7 +142,7 @@ def get_numeric_gradient(place,
return gradient_flat.reshape(tensor_to_check.shape())
class OpTest(unittest.TestCase):
class OpTestBase(unittest.TestCase):
@classmethod
def setUpClass(cls):
'''Fix random seeds to remove randomness from tests'''
......@@ -170,24 +171,47 @@ class OpTest(unittest.TestCase):
self.dtype = data_type
def infer_dtype_from_inputs_outputs(self, inputs, outputs):
def infer_dtype(numpy_dict):
def is_np_data(input):
return isinstance(input, (np.ndarray, np.generic))
def infer_dtype(numpy_dict, dtype_set):
assert isinstance(
numpy_dict,
dict), "self.inputs, self.outputs must be numpy_dict"
for var_name, var_value in six.iteritems(numpy_dict):
if isinstance(var_value, (np.ndarray, np.generic)):
self.try_call_once(var_value.dtype)
elif isinstance(var_value, (list, tuple)):
# the case of self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
if len(var_value) > 1 and isinstance(var_value[1], (
np.ndarray, np.generic)):
instance = var_value[1]
self.try_call_once(instance[1].dtype)
else:
self.try_call_once("float32")
infer_dtype(inputs)
infer_dtype(outputs)
# the inputs are as follows:
# case 1: inputs = {'X': x}
# case 2: inputs = {'X': (x, x_lod)}
# case 3: inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
# case 4: inputs = {'X': [("x1", (x1, [x1_lod1])), ("x2", (x2, [x2_.lod2]))]}
# TODO(juncaipeng) infer dtype from inputs maybe obtain wrong type.
for _, var_value in six.iteritems(numpy_dict):
if is_np_data(var_value): # case 1
dtype_set.add(var_value.dtype)
elif isinstance(var_value, (list, tuple)): # case 2, 3, 4
for sub_val_value in var_value:
if is_np_data(sub_val_value): # case 2
dtype_set.add(sub_val_value.dtype)
elif len(sub_val_value) > 1 and is_np_data(
sub_val_value[1]): # case 3
dtype_set.add(sub_val_value[1].dtype)
elif len(sub_val_value) > 1 and isinstance(sub_val_value[1], (list, tuple)) \
and is_np_data(sub_val_value[1][0]): # case 4
dtype_set.add(sub_val_value[1][0].dtype)
# infer dtype from inputs, and dtype means the precision of the test
# collect dtype of all inputs
dtype_set = set()
infer_dtype(inputs, dtype_set)
dtype_list = [
np.dtype(np.float64), np.dtype(np.float32), np.dtype(np.float16),
np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.int16),
np.dtype(np.int8)
]
# check the dtype in dtype_list in order, select the first dtype that in dtype_set
for dtype in dtype_list:
if dtype in dtype_set:
self.dtype = dtype
break
def feed_var(self, input_vars, place):
feed_map = {}
......@@ -214,6 +238,7 @@ class OpTest(unittest.TestCase):
return feed_map
def _append_ops(self, block):
self.__class__.op_type = self.op_type # for ci check, please not delete it for now
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
"infer datatype from inputs and outputs for this test case"
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
......@@ -352,6 +377,7 @@ class OpTest(unittest.TestCase):
return var_dict
def _calc_dygraph_output(self, place, parallel=False, no_check_set=None):
self.__class__.op_type = self.op_type # for ci check, please not delete it for now
with fluid.dygraph.base.guard(place=place):
block = fluid.default_main_program().global_block()
......@@ -1263,3 +1289,157 @@ class OpTest(unittest.TestCase):
return list(
map(np.array,
executor.run(prog, feed_dict, fetch_list, return_numpy=False)))
'''
The op test with int8 precision should inherit OpTestInt8.
'''
class OpTestInt8(OpTestBase):
pass
'''
The op test with float16 precision should inherit OpTestFp16,
which requires the test to call check_grad.
'''
class OpTestFp16(OpTestBase):
def check_output(self,
atol=1e-5,
no_check_set=None,
equal_nan=False,
check_dygraph=True,
inplace_atol=None,
check_compile_vs_runtime=False):
self.__class__.op_type = self.op_type
OpTestBase.check_output(self, atol, no_check_set, equal_nan,
check_dygraph, inplace_atol,
check_compile_vs_runtime)
def _check_grad_helper(self):
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
assert self.dtype == np.float16, "The dtype of this test should be float16."
self.__class__.op_type = self.op_type
self.__class__.exist_check_grad = True
def check_grad(self,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
self._check_grad_helper()
OpTestBase.check_grad(self, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, check_dygraph)
def check_grad_with_place(self,
place,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
self._check_grad_helper()
OpTestBase.check_grad_with_place(
self, place, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, check_dygraph)
@classmethod
def tearDownClass(cls):
"""Restore random seeds"""
np.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
if cls.__name__ not in op_accuracy_white_list.NO_NEED_FP16_CHECK_GRAD_CASES \
and not hasattr(cls, "exist_check_grad") \
and cls.op_type not in op_accuracy_white_list.NO_FP16_CHECK_GRAD_OP_LIST:
raise AssertionError("This test of %s op needs check_grad." %
cls.op_type)
'''
The op test with float32/64 precision should inherit OpTest,
which requires the test to call check_grad with float64 precision.
'''
class OpTest(OpTestBase):
def check_output(self,
atol=1e-5,
no_check_set=None,
equal_nan=False,
check_dygraph=True,
inplace_atol=None,
check_compile_vs_runtime=False):
self.__class__.op_type = self.op_type
OpTestBase.check_output(self, atol, no_check_set, equal_nan,
check_dygraph, inplace_atol,
check_compile_vs_runtime)
def _check_grad_helper(self):
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
assert self.dtype in [np.float16, np.float32, np.float64], \
"self.dtype = %s." % self.dtype
if self.dtype == np.float16 and \
self.op_type not in op_accuracy_white_list.FP16_CHECK_OP_LIST:
raise AssertionError("The dtype of this test should be float32 "
"or float64. op: %s dtype: %s." %
(self.op_type, self.dtype))
self.__class__.op_type = self.op_type
if self.dtype == np.float64:
self.__class__.exist_fp64_check_grad = True
def check_grad(self,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
self._check_grad_helper()
OpTestBase.check_grad(self, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, check_dygraph)
def check_grad_with_place(self,
place,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
self._check_grad_helper()
OpTestBase.check_grad_with_place(
self, place, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, check_dygraph)
@classmethod
def tearDownClass(cls):
"""Restore random seeds"""
np.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
if cls.__name__ not in op_accuracy_white_list.NO_NEED_FP64_CHECK_GRAD_CASES \
and not hasattr(cls, 'exist_fp64_check_grad') \
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST:
raise AssertionError("This test of %s op needs fp64 check_grad." %
cls.op_type)
......@@ -26,8 +26,8 @@ class TestSequenceConcat(OpTest):
self.out_lod = [19, 11]
def setUp(self):
x1 = np.random.random(size=(10, 80))
x2 = np.random.random(size=(20, 80))
x1 = np.random.random(size=(10, 80)).astype('float32')
x2 = np.random.random(size=(20, 80)).astype('float32')
self.setLoD()
out = np.concatenate((x1[0:self.lod1[0]], x2[0:self.lod2[0]],
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# For op in FP16_CHECK_OP_LIST, the op test of fp16 precision should inherit OpTestFp16
FP16_CHECK_OP_LIST = [
'abs', 'acos', 'asin', 'atan', 'brelu', 'concat', 'cos', 'elementwise_div',
'elementwise_mul', 'elu', 'exp', 'gelu', 'hard_shrink', 'hard_swish', 'log',
'logsigmoid', 'mean', 'mul', 'pad', 'pool2d', 'pow', 'reciprocal', 'relu',
'relu6', 'scale', 'sigmoid', 'sin', 'slice', 'soft_relu', 'softmax',
'softmax_with_cross_entropy', 'softshrink', 'softsign', 'sqrt', 'square',
'stanh', 'sum', 'swish', 'tanh', 'tanh_shrink', 'thresholded_relu'
]
# For op in NO_FP64_CHECK_GRAD_OP_LIST, the op test requires check_grad with fp64 precision
NO_FP64_CHECK_GRAD_OP_LIST = [
'abs', 'accuracy', 'acos', 'adadelta', 'adagrad', 'adam', 'adamax',
'add_position_encoding', 'affine_grid', 'anchor_generator', 'arg_max',
'arg_min', 'argsort', 'asin', 'assign_value', 'atan', 'attention_lstm',
'auc', 'bilinear_interp', 'bilinear_tensor_product', 'bipartite_match',
'box_clip', 'box_coder', 'box_decoder_and_assign', 'brelu', 'cast', 'ceil',
'center_loss', 'chunk_eval', 'clip', 'clip_by_norm', 'coalesce_tensor',
'collect_fpn_proposals', 'concat', 'conv2d', 'conv2d_fusion',
'conv2d_transpose', 'conv3d', 'conv3d_transpose', 'conv_shift', 'cos',
'cos_sim', 'crf_decoding', 'crop', 'crop_tensor', 'cross_entropy',
'cross_entropy2', 'ctc_align', 'cudnn_lstm', 'cvm', 'data_norm',
'decayed_adagrad', 'deformable_conv', 'deformable_conv_v1',
'deformable_psroi_pooling', 'density_prior_box', 'depthwise_conv2d',
'depthwise_conv2d_transpose', 'dequantize', 'dequantize_abs_max',
'detection_map', 'diag', 'distribute_fpn_proposals', 'dpsgd', 'dropout',
'edit_distance', 'elementwise_add', 'elementwise_div',
'elementwise_floordiv', 'elementwise_max', 'elementwise_min',
'elementwise_mod', 'elementwise_mul', 'elementwise_pow', 'elementwise_sub',
'elu', 'equal', 'exp', 'expand', 'eye',
'fake_channel_wise_dequantize_max_abs',
'fake_channel_wise_quantize_abs_max', 'fake_dequantize_max_abs',
'fake_quantize_abs_max', 'fake_quantize_dequantize_moving_average_abs_max',
'fake_quantize_moving_average_abs_max', 'fake_quantize_range_abs_max', 'fc',
'fill', 'fill_any_like', 'fill_constant', 'fill_constant_batch_size_like',
'fill_zeros_like', 'fill_zeros_like2', 'flatten', 'flatten2', 'floor',
'ftrl', 'fused_elemwise_activation', 'fused_embedding_fc_lstm',
'fused_embedding_seq_pool', 'fused_fc_elementwise_layernorm', 'fusion_gru',
'fusion_lstm', 'fusion_repeated_fc_relu', 'fusion_seqconv_eltadd_relu',
'fusion_seqexpand_concat_fc', 'fusion_seqpool_concat',
'fusion_seqpool_cvm_concat', 'fusion_squared_mat_sub',
'fusion_transpose_flatten_concat', 'gather', 'gather_nd', 'gather_tree',
'gaussian_random_batch_size_like', 'gelu', 'generate_mask_labels',
'generate_proposal_labels', 'generate_proposals', 'greater_equal',
'greater_than', 'grid_sampler', 'group_norm', 'hard_shrink', 'hard_sigmoid',
'hard_swish', 'hash', 'hierarchical_sigmoid', 'hinge_loss', 'huber_loss',
'im2sequence', 'increment', 'iou_similarity', 'is_empty', 'isfinite',
'isinf', 'isnan', 'kldiv_loss', 'l1_norm', 'lamb', 'lars_momentum',
'leaky_relu', 'less_equal', 'less_than', 'linspace', 'locality_aware_nms',
'lod_reset', 'log', 'log_loss', 'logical_and', 'logical_not', 'logical_or',
'logical_xor', 'logsigmoid', 'lookup_table', 'lookup_table_v2', 'lrn',
'margin_rank_loss', 'match_matrix_tensor', 'matmul',
'max_pool2d_with_index', 'max_pool3d_with_index', 'maxout', 'mean',
'mean_iou', 'merge_ids', 'mine_hard_examples', 'minus',
'modified_huber_loss', 'momentum', 'moving_average_abs_max_scale', 'mul',
'multiclass_nms', 'multiclass_nms2', 'multihead_matmul', 'multiplex', 'nce',
'nearest_interp', 'not_equal', 'one_hot', 'one_hot_v2', 'pad', 'pad2d',
'pad_constant_like', 'pixel_shuffle', 'polygon_box_transform', 'pool2d',
'pool3d', 'positive_negative_pair', 'pow', 'precision_recall', 'prelu',
'prior_box', 'proximal_adagrad', 'proximal_gd', 'prroi_pool', 'psroi_pool',
'quantize', 'random_crop', 'range', 'rank_loss', 'reciprocal', 'reduce_all',
'reduce_any', 'reduce_max', 'reduce_min', 'ref_by_trainer_id', 'relu',
'relu6', 'requantize', 'reshape2', 'retinanet_detection_output',
'retinanet_target_assign', 'reverse', 'roi_align',
'roi_perspective_transform', 'roi_pool', 'round', 'row_conv',
'rpn_target_assign', 'rsqrt', 'sampling_id', 'scale', 'scatter',
'scatter_nd_add', 'seed', 'selu', 'sequence_concat', 'sequence_conv',
'sequence_enumerate', 'sequence_erase', 'sequence_expand',
'sequence_expand_as', 'sequence_mask', 'sequence_pad', 'sequence_pool',
'sequence_reshape', 'sequence_reverse', 'sequence_scatter',
'sequence_slice', 'sequence_softmax', 'sequence_topk_avg_pooling',
'sequence_unpad', 'sgd', 'shape', 'shard_index', 'shuffle_channel',
'sigmoid', 'sigmoid_cross_entropy_with_logits', 'sigmoid_focal_loss',
'sign', 'similarity_focus', 'sin', 'size', 'slice', 'smooth_l1_loss',
'soft_relu', 'softmax', 'softshrink', 'softsign', 'space_to_depth',
'spectral_norm', 'split', 'split_ids', 'spp', 'sqrt', 'square',
'squared_l2_distance', 'squared_l2_norm', 'squeeze', 'squeeze2', 'stack',
'stanh', 'strided_slice', 'sum', 'swish', 'tanh', 'tanh_shrink',
'target_assign', 'teacher_student_sigmoid_loss', 'temporal_shift',
'thresholded_relu', 'top_k', 'transpose2', 'tree_conv', 'trilinear_interp',
'unfold', 'uniform_random', 'uniform_random_batch_size_like', 'unique',
'unique_with_counts', 'unpool', 'unsqueeze', 'unsqueeze2', 'unstack',
'var_conv_2d', 'warpctc', 'where', 'yolo_box', 'yolov3_loss'
]
NO_NEED_FP64_CHECK_GRAD_CASES = ['TestFSPOp']
NO_FP16_CHECK_GRAD_OP_LIST = []
NO_NEED_FP16_CHECK_GRAD_CASES = []
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册