提交 a8df4111 编写于 作者: L liaogang

Replace random_shuffle using shuffle.

* reduce trainer count for unit test on MAC OSX
上级 1d4bc478
...@@ -65,7 +65,8 @@ void DataProviderGroup<T>::reset() { ...@@ -65,7 +65,8 @@ void DataProviderGroup<T>::reset() {
provider_ = nullptr; provider_ = nullptr;
// shuffle file list // shuffle file list
std::random_shuffle(fileList_.begin(), fileList_.end()); std::shuffle(fileList_.begin(), fileList_.end(),
ThreadLocalRandomEngine::get());
startLoader(); startLoader();
DataProvider::reset(); DataProvider::reset();
......
...@@ -374,7 +374,8 @@ void ProtoDataProvider::reset() { ...@@ -374,7 +374,8 @@ void ProtoDataProvider::reset() {
} }
void ProtoDataProvider::shuffle() { void ProtoDataProvider::shuffle() {
std::random_shuffle(shuffledSequenceIds_.begin(), shuffledSequenceIds_.end()); std::shuffle(shuffledSequenceIds_.begin(), shuffledSequenceIds_.end(),
ThreadLocalRandomEngine::get());
} }
/* /*
......
...@@ -2514,7 +2514,8 @@ void SharedCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, ...@@ -2514,7 +2514,8 @@ void SharedCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB,
for (int k = 0; k < blockNum_; ++k) { for (int k = 0; k < blockNum_; ++k) {
blockSeq.push_back(k); blockSeq.push_back(k);
} }
std::random_shuffle(blockSeq.begin(), blockSeq.end()); std::shuffle(blockSeq.begin(), blockSeq.end(),
ThreadLocalRandomEngine::get());
} }
std::vector<int>& localBufRows = *localBufRows_; std::vector<int>& localBufRows = *localBufRows_;
int* cols = a->getCols(); int* cols = a->getCols();
......
...@@ -146,12 +146,12 @@ TEST(compareSparse, remote_cpu) { ...@@ -146,12 +146,12 @@ TEST(compareSparse, remote_cpu) {
TEST(compareSparse, cpu10_local_vs_remote) { TEST(compareSparse, cpu10_local_vs_remote) {
FLAGS_local = 1; // disable remote sparse update in parameter config FLAGS_local = 1; // disable remote sparse update in parameter config
std::vector<ParameterPtr> localParameters = std::vector<ParameterPtr> localParameters =
trainerOnePassTest(configFile1, true, 10); trainerOnePassTest(configFile1, true, 2);
FLAGS_local = 0; // will enable remote sparse update FLAGS_local = 0; // will enable remote sparse update
FLAGS_ports_num_for_sparse = 5; FLAGS_ports_num_for_sparse = 5;
std::vector<ParameterPtr> remoteParameters = std::vector<ParameterPtr> remoteParameters =
trainerOnePassTest(configFile1, true, 10); trainerOnePassTest(configFile1, true, 2);
compareValue(localParameters, remoteParameters); compareValue(localParameters, remoteParameters);
} }
...@@ -174,7 +174,7 @@ TEST(compareSparse, multiGradientMachine) { ...@@ -174,7 +174,7 @@ TEST(compareSparse, multiGradientMachine) {
FLAGS_parallel_nn = useGpu; FLAGS_parallel_nn = useGpu;
LOG(INFO) << " local=" << local LOG(INFO) << " local=" << local
<< " useGpu=" << useGpu; << " useGpu=" << useGpu;
int trainerCount = useGpu ? numGpu : 10; int trainerCount = useGpu ? numGpu : 2;
std::vector<ParameterPtr> parameters = std::vector<ParameterPtr> parameters =
trainerOnePassTest(configFile1, true, trainerCount, useGpu); trainerOnePassTest(configFile1, true, trainerCount, useGpu);
compareValue(getDenseParameters(), parameters, eps); compareValue(getDenseParameters(), parameters, eps);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册