未验证 提交 a9577347 编写于 作者: G Guoxia Wang 提交者: GitHub

fix dim check of class center sample (#35733)

上级 78465703
......@@ -1701,14 +1701,14 @@ def class_center_sample(label, num_classes, num_samples, group=None):
label_size = 1
for dim in list(label.shape):
label_size *= dim
if label_size < 1:
if label_size != -1 and label_size < 1:
raise ValueError('Expected label_size > 0 \
(got label_size{})'.format(label_size))
(got label_size: {})'.format(label_size))
label_dims = len(list(label.shape))
if label_dims != 1:
raise ValueError('Expected label_dims == 1 \
(got label_dims{})'.format(label_dims))
(got label_dims: {})'.format(label_dims))
seed = None
if (seed is None or seed == 0) and default_main_program().random_seed != 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册