未验证 提交 e133d8ef 编写于 作者: G Guoxia Wang 提交者: GitHub

fix bug (#35482)

上级 3c457a38
......@@ -323,6 +323,11 @@ class MarginCrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
T><<<NumBlocks(N), threads, 0, dev_ctx.stream()>>>(
logits_ptr, labels->data<LabelT>(), margin1, margin2, margin3, rank,
nranks, N, D, class_interval.data<int>());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"margin_cross_entropy label type noly support int32 and int64, "
"but got %s",
label_type));
}
// scale by s
......
......@@ -138,7 +138,7 @@ class TestClassCenterSampleV2(unittest.TestCase):
label = paddle.static.data(
name='label', shape=[self.batch_size], dtype=self.dtype)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples, seed=self.seed)
label, self.num_classes, self.num_samples)
remapped_label_np, sampled_class_center_np = class_center_sample_numpy(
label_np, [self.num_classes], self.num_samples)
......@@ -163,7 +163,7 @@ class TestClassCenterSampleV2(unittest.TestCase):
label = paddle.to_tensor(label_np, dtype=self.dtype)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples, seed=self.seed)
label, self.num_classes, self.num_samples)
remapped_label_np, sampled_class_center_np = class_center_sample_numpy(
label_np, [self.num_classes], self.num_samples)
......@@ -210,13 +210,41 @@ class TestClassCenterSampleAPIError(unittest.TestCase):
label = paddle.to_tensor(label_np)
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label,
self.num_classes,
self.num_samples,
seed=self.seed)
label, self.num_classes, self.num_samples)
self.assertRaises(ValueError, test_num_samples)
class TestClassCenterSampleAPIError1(unittest.TestCase):
def setUp(self):
self.initParams()
np.random.seed(self.seed)
self.places = [paddle.fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(paddle.fluid.CUDAPlace(0))
def initParams(self):
self.batch_size = 5
self.num_samples = 5
self.num_classes = 10
self.seed = 2021
self.init_dtype()
def init_dtype(self):
self.dtype = np.int64
def test_dynamic_errors(self):
def test_empty_label():
for place in self.places:
with paddle.fluid.dygraph.guard(place):
label = paddle.to_tensor(np.array([], dtype=self.dtype))
remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
label, self.num_classes, self.num_samples)
print(remapped_label, sampled_class_index)
self.assertRaises(ValueError, test_empty_label)
if __name__ == '__main__':
unittest.main()
......@@ -378,7 +378,30 @@ class TestMarginCrossEntropyOpAPIError(unittest.TestCase):
return_softmax=True,
reduction=None)
def test_label_type():
for place in self.places:
with paddle.fluid.dygraph.guard(place):
labels_np = np.random.uniform(
0, self.num_class,
(self.batch_dim, 1)).astype(self.dtype)
logits_np = np.random.uniform(
-0.99, 0.99,
[self.batch_dim, self.num_class]).astype(self.dtype)
labels = paddle.to_tensor(labels_np)
logits = paddle.to_tensor(logits_np)
loss, softmax = paddle.nn.functional.margin_cross_entropy(
logits,
labels,
margin1=self.margin1,
margin2=self.margin2,
margin3=self.margin3,
scale=self.scale,
return_softmax=True,
reduction=None)
self.assertRaises(ValueError, test_dim)
self.assertRaises(NotImplementedError, test_label_type)
if __name__ == '__main__':
......
......@@ -1584,7 +1584,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
return smooth_label
def class_center_sample(label, num_classes, num_samples, group=None, seed=None):
def class_center_sample(label, num_classes, num_samples, group=None):
"""
Class center sample method is proposed from the paper PartialFC that only sample a subset of the class centers.
The process of sampling subset class centers is straightforward:
......@@ -1611,7 +1611,6 @@ def class_center_sample(label, num_classes, num_samples, group=None, seed=None):
num_samples (int): A positive integer to specify the number of class center to sample.
group (Group, optional): The abstract representation of group.
See paddle.distributed.collective.Group. Default is ``None``.
seed (int, optional): Random seed. Default is ``None``.
Returns:
Tuple of two ``Tensor`` : (remapped_label, sampled_class_center), remapped label using sampled class center,
......@@ -1702,6 +1701,19 @@ def class_center_sample(label, num_classes, num_samples, group=None, seed=None):
'Expected num_samples less than or equal to {}, got num_samples {}'.
format(num_classes, num_samples))
label_size = 1
for dim in list(label.shape):
label_size *= dim
if label_size < 1:
raise ValueError('Expected label_size > 0 \
(got label_size{})'.format(label_size))
label_dims = len(list(label.shape))
if label_dims != 1:
raise ValueError('Expected label_dims == 1 \
(got label_dims{})'.format(label_dims))
seed = None
if (seed is None or seed == 0) and default_main_program().random_seed != 0:
seed = default_main_program().random_seed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册