/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include "paddle/trainer/Trainer.h" #include "paddle/trainer/TrainerInternal.h" #include #include using namespace paddle; // NOLINT using namespace std; // NOLINT static const string& configFile1 = "trainer/tests/sample_trainer_config.conf"; static const string& configFile2 = "trainer/tests/sample_trainer_config_parallel.conf"; P_DECLARE_bool(use_gpu); P_DECLARE_string(config); P_DECLARE_int32(gpu_id); P_DECLARE_int32(seed); P_DECLARE_int32(num_passes); P_DECLARE_int32(saving_period); class TrainerForTest : public paddle::Trainer { public: inline const std::shared_ptr& getParameterUpdaterForTest() { return this->trainerInternal_.getParameterUpdater(); } }; int gNumDevices = 0; void trainerOnePassTest(const string& configFile, bool useGpu, bool parallel, int trainerCount = 1, double averageWindow = 0.0f, bool doAverageInCpu = false) { FLAGS_use_gpu = useGpu; FLAGS_parallel_nn = parallel; FLAGS_config = configFile; FLAGS_trainer_count = trainerCount; LOG(INFO) << " useGpu=" << useGpu << " trainerCount=" << trainerCount << " configFile=" << configFile; srand(FLAGS_seed); if (useGpu) { if (gNumDevices < trainerCount) { return; } } Trainer trainer; auto config = TrainerConfigHelper::createFromFlagConfig(); if (averageWindow > 0) { config->getOptConfig().set_average_window(averageWindow); config->getOptConfig().set_do_average_in_cpu(doAverageInCpu); } trainer.init(config); trainer.train(); } // 1. test trainer (cpu, gpu). TEST(trainerOnePass, cpu) { trainerOnePassTest(configFile1, false, false); } #ifndef PADDLE_ONLY_CPU TEST(trainerOnePass, gpu) { trainerOnePassTest(configFile1, true, false); } TEST(trainerOnePass, gpu2) { trainerOnePassTest(configFile1, true, false, 2); } TEST(trainerOnePass, gpu4) { trainerOnePassTest(configFile1, true, false, 4); } TEST(trainerOnePass, parallel) { if (hl_get_device_count() >= 2) { trainerOnePassTest(configFile2, true, true); } } #endif // 2. test average_window. #ifndef PADDLE_ONLY_CPU TEST(average_window, gpu) { trainerOnePassTest(configFile1, true, false, 4, 0.01); } TEST(average_window, gpu2) { FLAGS_num_passes = 100; trainerOnePassTest(configFile1, true, false, 2, 0.01); FLAGS_num_passes = 1; } TEST(average_window, gpu4) { FLAGS_num_passes = 100; trainerOnePassTest(configFile1, true, false, 4, 0.01); FLAGS_num_passes = 1; } TEST(average_window_cpu, gpu2) { FLAGS_num_passes = 100; trainerOnePassTest(configFile1, true, false, 2, 0.01, true); FLAGS_num_passes = 1; } TEST(average_window_cpu, gpu4) { FLAGS_num_passes = 100; trainerOnePassTest(configFile1, true, false, 4, 0.01, true); FLAGS_num_passes = 1; } #endif // 3. test trainer + pserver. P_DECLARE_int32(num_gradient_servers); P_DECLARE_int32(port); P_DECLARE_bool(local); P_DECLARE_bool(use_old_updater); double checkRemoteParameterUpdater(TrainerForTest& trainer) { auto gradientMachine = trainer.getGradientMachine(); auto parameterUpdater = trainer.getParameterUpdaterForTest(); auto dataProvider = trainer.getDataProvider(); auto& parameters = gradientMachine->getParameters(); const TrainerConfig& config = trainer.getConfig(); const string& alg = config.opt_config().algorithm(); vector parameterCheck; for (auto& parameter : parameters) { parameterCheck.emplace_back( new Parameter(parameter->getConfig(), /* useGpu= */ false)); parameterCheck.back() ->getBuf(PARAMETER_VALUE) ->copyFrom(*parameter->getBuf(PARAMETER_VALUE)); parameterCheck.back() ->getBuf(PARAMETER_GRADIENT) ->copyFrom(*parameter->getBuf(PARAMETER_GRADIENT)); } std::unique_ptr parameterUpdaterCheck; if (alg == TrainAlgorithm::SGD) { parameterUpdaterCheck.reset(new SgdLocalUpdater(config.opt_config())); } else { LOG(INFO) << "unsupported algorithm in remote parameter check: " << alg; return -1.0; } parameterUpdaterCheck->init(parameterCheck); // gradientMachine->start(config, *dataProvider); DataBatch dataBatch; int32_t batchSize = config.opt_config().batch_size(); dataProvider->getNextBatch(batchSize, &dataBatch); CHECK(dataBatch.getSize()) << "No data from data provider"; int64_t actualBatchSize = dataBatch.getSize(); const vector& inArgs = dataBatch.getStreams(); vector outArgs; UpdateCallback updateCallback = [parameterUpdater, parameterCheck](Parameter* para) { parameterCheck[para->getID()] ->getBuf(PARAMETER_GRADIENT) ->copyFrom(*para->getBuf(PARAMETER_GRADIENT)); parameterUpdater->update(para); }; parameterUpdater->startPass(); parameterUpdaterCheck->startPass(); for (int i = 0; i < config.opt_config().num_batches_per_get_parameter() * 2; ++i) { PassType passType = parameterUpdater->startBatch(actualBatchSize); gradientMachine->forwardBackward( inArgs, &outArgs, passType, updateCallback); parameterUpdater->finishBatch(0); parameterUpdaterCheck->startBatch(actualBatchSize); for (auto& para : parameterCheck) { parameterUpdaterCheck->update(para.get()); } parameterUpdaterCheck->finishBatch(0); } double sum = 0.0f; for (size_t i = 0; i != parameters.size(); ++i) { real *v1, *v2; CpuVector trainerPara(parameters[i]->getSize()); trainerPara.copyFrom(*parameters[i]->getBuf(PARAMETER_VALUE)); if (!FLAGS_use_gpu) { v1 = parameters[i]->getBuf(PARAMETER_VALUE)->getData(); } else { v1 = trainerPara.getData(); } v2 = parameterCheck[i]->getBuf(PARAMETER_VALUE)->getData(); size_t size = parameters[i]->getSize(); double diff = 0; for (size_t j = 0; j < size; ++j) { diff += fabs(v1[j] - v2[j]); } sum += diff; LOG(INFO) << setiosflags(ios::left) << setfill(' ') << setw(20) << parameters[i]->getName() << "diff=" << setw(15) << diff; } parameterUpdater->finishPass(); parameterUpdaterCheck->finishPass(); gradientMachine->finish(); return sum; } void checkRemoteParameterUpdaterTest(const string& configFile, bool useGpu, bool parallel, int trainerCount = 1, bool useOldUpdater = false, int num_batches_per_get_parameter = 1) { FLAGS_use_gpu = useGpu; FLAGS_parallel_nn = parallel; FLAGS_config = configFile; FLAGS_trainer_count = trainerCount; FLAGS_use_old_updater = useOldUpdater; LOG(INFO) << " useGpu=" << useGpu << " trainerCount=" << trainerCount << " configFile=" << configFile; srand(FLAGS_seed); if (useGpu) { if (gNumDevices < trainerCount) { return; } } FLAGS_local = 0; std::shared_ptr pserver; pserver.reset(new ParameterServer2(std::string(), FLAGS_port)); pserver->init(); pserver->start(); TrainerForTest trainer; auto config = TrainerConfigHelper::createFromFlagConfig(); config->getOptConfig().set_num_batches_per_get_parameter( num_batches_per_get_parameter); trainer.init(config); EXPECT_EQ(checkRemoteParameterUpdater(trainer), 0); FLAGS_local = 1; } TEST(checkRemoteUpdater, cpuTrainer) { checkRemoteParameterUpdaterTest(configFile1, false, false); } TEST(checkRemoteUpdater, cpuTrainerOldUpdater) { checkRemoteParameterUpdaterTest(configFile1, false, false, 1, true); } #ifndef PADDLE_ONLY_CPU TEST(checkRemoteUpdater, gpuTrainer) { checkRemoteParameterUpdaterTest(configFile1, true, false); } TEST(checkRemoteUpdater, gpu2Trainer) { checkRemoteParameterUpdaterTest(configFile1, true, false, 2); } TEST(checkRemoteUpdater, gpu4Trainer) { checkRemoteParameterUpdaterTest(configFile1, true, false, 4); } TEST(checkRemoteUpdater, gpuTrainerOldUpdater) { checkRemoteParameterUpdaterTest(configFile1, true, false, 1, true); } TEST(checkRemoteUpdater, gpu2TrainerOldUpdater) { checkRemoteParameterUpdaterTest(configFile1, true, false, 2, true); } TEST(checkRemoteUpdater, gpu4TrainerOldUpdater) { checkRemoteParameterUpdaterTest(configFile1, true, false, 4, true); } #endif TEST(checkRemoteUpdater, cpuDeltaTrainer) { checkRemoteParameterUpdaterTest(configFile1, false, false, 1, false, 10); } TEST(checkRemoteUpdater, cpuDeltaTrainerOldUpdater) { checkRemoteParameterUpdaterTest(configFile1, false, false, 1, true, 10); } int main(int argc, char** argv) { initMain(argc, argv); initPython(argc, argv); gNumDevices = hl_get_device_count(); testing::InitGoogleTest(&argc, argv); FLAGS_num_passes = 1; // train one pass FLAGS_saving_period = 100000; // do not save parameteres return RUN_ALL_TESTS(); }