diff --git a/paddle/gserver/layers/MultinomialSampler.cpp b/paddle/gserver/layers/MultinomialSampler.cpp index 710772c0cf476f3b2dee790dc3f8254ee2452b0c..518dc0c60cbdc2a95b7eb9c8ff33dd6a9fb87c98 100644 --- a/paddle/gserver/layers/MultinomialSampler.cpp +++ b/paddle/gserver/layers/MultinomialSampler.cpp @@ -19,7 +19,7 @@ namespace paddle { MultinomialSampler::MultinomialSampler(const real* prob, int size) : rand_(0.0, size) { - intervals_.reserve(size + 1); + intervals_.resize(size + 1); double sum = 0; for (int i = 0; i < size; ++i) { sum += prob[i]; @@ -50,12 +50,13 @@ MultinomialSampler::MultinomialSampler(const real* prob, int size) int bigPos = nextBigPos(0); auto fillIntervals = [&]() { - while (bigPos < size && smallPos < size) { + while (bigPos < size) { while (intervals_[bigPos].thresh > 1 && smallPos < size) { intervals_[smallPos].otherId = bigPos; intervals_[bigPos].thresh -= 1 - intervals_[smallPos].thresh; smallPos = nextSmallPos(smallPos + 1); } + if (smallPos >= size) break; bigPos = nextBigPos(bigPos + 1); // If intervals_[bigPos].thresh < 1, it becomes a small interval } diff --git a/paddle/gserver/tests/test_MultinomialSampler.cpp b/paddle/gserver/tests/test_MultinomialSampler.cpp index 39a90958331f6cc3a19c12342f9c280e467a066e..73b4d0b8b7110d4ab79809875e2481cd2b565a68 100644 --- a/paddle/gserver/tests/test_MultinomialSampler.cpp +++ b/paddle/gserver/tests/test_MultinomialSampler.cpp @@ -41,39 +41,42 @@ public: TEST(MultinomialSampler, gen) { int numGrids = 1024 * 1024; int size = 1024 * 4; - default_random_engine reng; - uniform_int_distribution rand(1, numGrids / size * 1.8); - vector prob; - int sum = 0; - for (int i = 0; i < size; ++i) { - prob.push_back(rand(reng)); - sum += prob.back(); - } - CHECK_LE(sum, numGrids); - prob.back() += numGrids - sum; - vector counts(size); - MultinomialSamplerTester sampler(&prob[0], size); - counts.assign(size, 0); - { - double s = (double)size / (double)numGrids; - REGISTER_TIMER("MultinomialSampler"); - for (double i = 0; i < numGrids; ++i) { - int ret = sampler.testGen([i, s]() { return s * i; }); - if (ret < 0 || ret >= size) { - EXPECT_GE(ret, 0); - EXPECT_LT(ret, size); - break; + for (size_t iter=0; iter < 256; ++iter) { + uniform_int_distribution rand(1, numGrids / size * 1.8); + vector prob; + int sum = 0; + for (int i = 0; i < size; ++i) { + prob.push_back(rand(reng)); + sum += prob.back(); + } + + CHECK_LE(sum, numGrids); + prob.back() += numGrids - sum; + + vector counts(size); + MultinomialSamplerTester sampler(&prob[0], size); + counts.assign(size, 0); + { + double s = (double)size / (double)numGrids; + REGISTER_TIMER("MultinomialSampler"); + for (double i = 0; i < numGrids; ++i) { + int ret = sampler.testGen([i, s]() { return s * i; }); + if (ret < 0 || ret >= size) { + EXPECT_GE(ret, 0); + EXPECT_LT(ret, size); + break; + } + ++counts[ret]; } - ++counts[ret]; } - } - for (int i = 0; i < size; ++i) { - if (prob[i] != counts[i]) { - EXPECT_EQ(prob[i], counts[i]); - LOG(INFO) << "i=" << i; - break; + for (int i = 0; i < size; ++i) { + if (prob[i] != counts[i]) { + EXPECT_EQ(prob[i], counts[i]); + LOG(INFO) << iter; + break; + } } } } @@ -135,6 +138,7 @@ void benchmarkRandom() { LOG(INFO) << "sum1=" << sum1; } + int main(int argc, char** argv) { initMain(argc, argv); testing::InitGoogleTest(&argc, argv);