提交 51feea03 编写于 作者: M ms_yan 提交者: 高东海

Repair parameter check problem in TFRecordDataset

上级 bd4a206a
......@@ -398,6 +398,7 @@ def check_tfrecorddataset(method):
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_list = ['columns_list']
nreq_param_bool = ['shard_equal_rows']
# check dataset_files; required argument
dataset_files = param_dict.get('dataset_files')
......@@ -410,6 +411,10 @@ def check_tfrecorddataset(method):
check_param_type(nreq_param_list, param_dict, list)
check_param_type(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
return method(*args, **kwargs)
return new_method
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册