未验证 提交 72e068f1 编写于 作者: P pangyoki 提交者: GitHub

fix test_multinomial (#28558)

* fix test_multinomial

* fix test_multinomial add 0 prob
上级 b889a0ce
......@@ -22,6 +22,26 @@ from op_test import OpTest
import numpy as np
def sample_output_one_dimension(out, dim):
# count numbers of different categories
sample_prob = np.zeros(dim).astype("float32")
sample_index_prob = np.unique(out, return_counts=True)
sample_prob[sample_index_prob[0]] = sample_index_prob[1]
sample_prob /= sample_prob.sum()
return sample_prob
def sample_output_two_dimension(out, shape):
num_dist = shape[0]
out_list = np.split(out, num_dist, axis=0)
sample_prob = np.zeros(shape).astype("float32")
for i in range(num_dist):
sample_index_prob = np.unique(out_list[i], return_counts=True)
sample_prob[i][sample_index_prob[0]] = sample_index_prob[1]
sample_prob /= sample_prob.sum(axis=-1, keepdims=True)
return sample_prob
class TestMultinomialOp(OpTest):
def setUp(self):
paddle.enable_static()
......@@ -39,10 +59,7 @@ class TestMultinomialOp(OpTest):
self.check_output_customized(self.verify_output)
def sample_output(self, out):
# count numbers of different categories
sample_prob = np.unique(out, return_counts=True)[1].astype("float32")
sample_prob /= sample_prob.sum()
return sample_prob
return sample_output_one_dimension(out, 4)
def verify_output(self, outs):
# normalize the input to get the probability
......@@ -62,14 +79,7 @@ class TestMultinomialOp2(TestMultinomialOp):
self.attrs = {"num_samples": 100000, "replacement": True}
def sample_output(self, out):
out_list = np.split(out, 3, axis=0)
count_array = [0] * 3
for i in range(3):
count_array[i] = np.unique(
out_list[i], return_counts=True)[1].astype("float32")
sample_prob = np.stack(count_array, axis=0)
sample_prob /= sample_prob.sum(axis=-1, keepdims=True)
return sample_prob
return sample_output_two_dimension(out, [3, 4])
class TestMultinomialOp3(TestMultinomialOp):
......@@ -91,15 +101,12 @@ class TestMultinomialApi(unittest.TestCase):
def test_dygraph(self):
# input probability is a vector, and replacement is True
paddle.disable_static()
x = paddle.rand([4])
x_numpy = np.random.rand(4)
x = paddle.to_tensor(x_numpy)
out = paddle.multinomial(x, num_samples=100000, replacement=True)
x_numpy = x.numpy()
paddle.enable_static()
sample_prob = np.unique(
out.numpy(), return_counts=True)[1].astype("float32")
sample_prob /= sample_prob.sum()
sample_prob = sample_output_one_dimension(out.numpy(), 4)
prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
self.assertTrue(
np.allclose(
......@@ -109,18 +116,11 @@ class TestMultinomialApi(unittest.TestCase):
def test_dygraph2(self):
# input probability is a matrix, and replacement is True
paddle.disable_static()
x = paddle.rand([3, 4])
x_numpy = np.random.rand(3, 4)
x = paddle.to_tensor(x_numpy)
out = paddle.multinomial(x, num_samples=100000, replacement=True)
x_numpy = x.numpy()
out_list = np.split(out.numpy(), 3, axis=0)
count_array = [0] * 3
for i in range(3):
count_array[i] = np.unique(
out_list[i], return_counts=True)[1].astype("float32")
sample_prob = np.stack(count_array, axis=0)
sample_prob /= sample_prob.sum(axis=-1, keepdims=True)
sample_prob = sample_output_two_dimension(out.numpy(), [3, 4])
prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
self.assertTrue(
np.allclose(
......@@ -131,9 +131,9 @@ class TestMultinomialApi(unittest.TestCase):
def test_dygraph3(self):
# replacement is False. number of samples must be less than number of categories.
paddle.disable_static()
x = paddle.rand([1000])
x_numpy = np.random.rand(1000)
x = paddle.to_tensor(x_numpy)
out = paddle.multinomial(x, num_samples=100, replacement=False)
x_numpy = x.numpy()
unique_out = np.unique(out.numpy())
self.assertEqual(
......@@ -158,9 +158,7 @@ class TestMultinomialApi(unittest.TestCase):
x_np = np.random.rand(4).astype('float32')
out = exe.run(train_program, feed={'x': x_np}, fetch_list=[out])
sample_prob = np.unique(out, return_counts=True)[1].astype("float32")
sample_prob /= sample_prob.sum()
sample_prob = sample_output_one_dimension(out, 4)
prob = x_np / x_np.sum(axis=-1, keepdims=True)
self.assertTrue(
np.allclose(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册