提交 8f16cff1 编写于 作者: Y yanghaitao

add para check for sampler

上级 a6a9f884
...@@ -218,6 +218,11 @@ class DistributedSampler(BuiltinSampler): ...@@ -218,6 +218,11 @@ class DistributedSampler(BuiltinSampler):
if not isinstance(shuffle, bool): if not isinstance(shuffle, bool):
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle)) raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
self.num_shards = num_shards self.num_shards = num_shards
self.shard_id = shard_id self.shard_id = shard_id
self.shuffle = shuffle self.shuffle = shuffle
...@@ -282,6 +287,11 @@ class PKSampler(BuiltinSampler): ...@@ -282,6 +287,11 @@ class PKSampler(BuiltinSampler):
if not isinstance(shuffle, bool): if not isinstance(shuffle, bool):
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle)) raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
self.num_val = num_val self.num_val = num_val
self.shuffle = shuffle self.shuffle = shuffle
self.class_column = class_column # work for minddataset self.class_column = class_column # work for minddataset
...@@ -385,6 +395,16 @@ class SequentialSampler(BuiltinSampler): ...@@ -385,6 +395,16 @@ class SequentialSampler(BuiltinSampler):
""" """
def __init__(self, start_index=None, num_samples=None): def __init__(self, start_index=None, num_samples=None):
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
if start_index is not None:
if start_index < 0:
raise ValueError("start_index should be a positive integer "
"value or 0, but got start_index={}".format(start_index))
self.start_index = start_index self.start_index = start_index
super().__init__(num_samples) super().__init__(num_samples)
...@@ -430,6 +450,11 @@ class SubsetRandomSampler(BuiltinSampler): ...@@ -430,6 +450,11 @@ class SubsetRandomSampler(BuiltinSampler):
""" """
def __init__(self, indices, num_samples=None): def __init__(self, indices, num_samples=None):
if num_samples is not None:
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
if not isinstance(indices, list): if not isinstance(indices, list):
indices = [indices] indices = [indices]
......
...@@ -43,7 +43,6 @@ def test_sequential_sampler(print_res=False): ...@@ -43,7 +43,6 @@ def test_sequential_sampler(print_res=False):
assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2] assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2]
assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2 assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2
assert test_config(num_samples=0, num_repeats=2) == [0, 1, 2, 3, 4] * 2
assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2 assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册