diff --git a/python/paddle/incubate/hapi/tests/test_metrics.py b/python/paddle/incubate/hapi/tests/test_metrics.py index 3d25a275d5f1c539ce959c5231a7af771b229836..19c94b73f61a29004ea41b573aa1dd0fb9c0c8e6 100644 --- a/python/paddle/incubate/hapi/tests/test_metrics.py +++ b/python/paddle/incubate/hapi/tests/test_metrics.py @@ -40,7 +40,8 @@ def accuracy(pred, label, topk=(1, )): def convert_to_one_hot(y, C): - oh = np.random.random((y.shape[0], C)).astype('float32') * .5 + oh = np.random.choice(np.arange(C), C, replace=False).astype('float32') / C + oh = np.tile(oh[np.newaxis, :], (y.shape[0], 1)) for i in range(y.shape[0]): oh[i, int(y[i])] = 1. return oh