From a9577347c482aa152d5ba285b94d7312dfc242ff Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Wed, 15 Sep 2021 11:33:40 +0800 Subject: [PATCH] fix dim check of class center sample (#35733) --- python/paddle/nn/functional/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 323a8af52bc..fcfbea438d7 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1701,14 +1701,14 @@ def class_center_sample(label, num_classes, num_samples, group=None): label_size = 1 for dim in list(label.shape): label_size *= dim - if label_size < 1: + if label_size != -1 and label_size < 1: raise ValueError('Expected label_size > 0 \ - (got label_size{})'.format(label_size)) + (got label_size: {})'.format(label_size)) label_dims = len(list(label.shape)) if label_dims != 1: raise ValueError('Expected label_dims == 1 \ - (got label_dims{})'.format(label_dims)) + (got label_dims: {})'.format(label_dims)) seed = None if (seed is None or seed == 0) and default_main_program().random_seed != 0: -- GitLab