未验证 提交 8eb7f3ad 编写于 作者: G guru4elephant 提交者: GitHub

Merge pull request #4 from colourful-tree/develop

merge colourful-tree's change
...@@ -216,6 +216,9 @@ include(external/warpctc) # download, build, install warpctc ...@@ -216,6 +216,9 @@ include(external/warpctc) # download, build, install warpctc
include(cupti) include(cupti)
include(external/gzstream) include(external/gzstream)
endif (NOT WIN32) endif (NOT WIN32)
include(external/libmct)
#include(external/pslib_brpc)
include(external/pslib)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
if(WITH_GRPC) if(WITH_GRPC)
...@@ -276,6 +279,9 @@ set(EXTERNAL_LIBS ...@@ -276,6 +279,9 @@ set(EXTERNAL_LIBS
protobuf protobuf
zlib zlib
${PYTHON_LIBRARIES} ${PYTHON_LIBRARIES}
pslib
#pslib_brpc
libmct
) )
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
......
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_LIBMCT})
return()
ENDIF(NOT ${WITH_LIBMCT})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with LIBMCT in Paddle yet."
"Force WITH_LIBMCT=OFF")
SET(WITH_LIBMCT OFF CACHE STRING "Disable LIBMCT package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(LIBMCT_PROJECT "extern_libmct")
IF((NOT DEFINED LIBMCT_VER) OR (NOT DEFINED LIBMCT_URL))
MESSAGE(STATUS "use pre defined download url")
SET(LIBMCT_VER "libmct" CACHE STRING "" FORCE) #todo libmct version
SET(LIBMCT_URL "http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/${LIBMCT_VER}.tar.gz" CACHE STRING "" FORCE) #todo libmct url
ENDIF()
MESSAGE(STATUS "LIBMCT_VER: ${LIBMCT_VER}, LIBMCT_URL: ${LIBMCT_URL}")
SET(LIBMCT_SOURCE_DIR "${THIRD_PARTY_PATH}/libmct")
SET(LIBMCT_DOWNLOAD_DIR "${LIBMCT_SOURCE_DIR}/src/${LIBMCT_PROJECT}")
SET(LIBMCT_DST_DIR "libmct")
SET(LIBMCT_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(LIBMCT_INSTALL_DIR ${LIBMCT_INSTALL_ROOT}/${LIBMCT_DST_DIR})
SET(LIBMCT_ROOT ${LIBMCT_INSTALL_DIR})
SET(LIBMCT_INC_DIR ${LIBMCT_ROOT}/include)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${LIBMCT_ROOT}/lib")
INCLUDE_DIRECTORIES(${LIBMCT_INC_DIR})
FILE(WRITE ${LIBMCT_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(LIBMCT)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${LIBMCT_VER}/include ${LIBMCT_VER}/lib \n"
" DESTINATION ${LIBMCT_DST_DIR})\n")
ExternalProject_Add(
${LIBMCT_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LIBMCT_SOURCE_DIR}
DOWNLOAD_DIR ${LIBMCT_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${LIBMCT_URL} -c -q -O ${LIBMCT_VER}.tar.gz
&& tar zxvf ${LIBMCT_VER}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${LIBMCT_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${LIBMCT_INSTALL_ROOT}
)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32)
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c)
file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";")
add_library(libmct STATIC ${dummyfile})
else()
add_library(libmct INTERFACE)
endif()
#ADD_LIBRARY(libmct SHARED IMPORTED GLOBAL)
ADD_DEPENDENCIES(libmct ${LIBMCT_PROJECT})
LIST(APPEND external_project_dependencies libmct)
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_PSLIB})
return()
ENDIF(NOT ${WITH_PSLIB})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB in Paddle yet."
"Force WITH_PSLIB=OFF")
SET(WITH_PSLIB OFF CACHE STRING "Disable PSLIB package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(PSLIB_PROJECT "extern_pslib")
IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_VER "pslib" CACHE STRING "" FORCE) #todo pslib version
SET(PSLIB_URL "http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/${PSLIB_VER}.tar.gz" CACHE STRING "" FORCE) #todo pslib url
ENDIF()
MESSAGE(STATUS "PSLIB_VER: ${PSLIB_VER}, PSLIB_URL: ${PSLIB_URL}")
SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib")
SET(PSLIB_DOWNLOAD_DIR "${PSLIB_SOURCE_DIR}/src/${PSLIB_PROJECT}")
SET(PSLIB_DST_DIR "pslib")
SET(PSLIB_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(PSLIB_INSTALL_DIR ${PSLIB_INSTALL_ROOT}/${PSLIB_DST_DIR})
SET(PSLIB_ROOT ${PSLIB_INSTALL_DIR})
SET(PSLIB_INC_DIR ${PSLIB_ROOT}/include)
SET(PSLIB_LIB_DIR ${PSLIB_ROOT}/lib)
SET(PSLIB_LIB ${PSLIB_LIB_DIR}/libps.so)
SET(PSLIB_IOMP_LIB ${PSLIB_LIB_DIR}/libiomp5.so) #todo what is this
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_ROOT}/lib")
INCLUDE_DIRECTORIES(${PSLIB_INC_DIR})
FILE(WRITE ${PSLIB_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_VER}/include ${PSLIB_VER}/lib \n"
" DESTINATION ${PSLIB_DST_DIR})\n")
ExternalProject_Add(
${PSLIB_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_VER}.tar.gz
&& tar zxvf ${PSLIB_VER}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_INSTALL_ROOT}
)
ADD_LIBRARY(pslib STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET pslib PROPERTY IMPORTED_LOCATION ${PSLIB_LIB})
ADD_DEPENDENCIES(pslib ${PSLIB_PROJECT})
LIST(APPEND external_project_dependencies pslib)
IF(WITH_C_API)
INSTALL(FILES ${PSLIB_LIB} ${PSLIB_IOMP_LIB} DESTINATION lib)
ENDIF()
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_PSLIB_BRPC})
return()
ENDIF(NOT ${WITH_PSLIB_BRPC})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with PSLIB_BRPC in Paddle yet."
"Force WITH_PSLIB_BRPC=OFF")
SET(WITH_PSLIB_BRPC OFF CACHE STRING "Disable PSLIB_BRPC package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(PSLIB_BRPC_PROJECT "extern_pslib_brpc")
IF((NOT DEFINED PSLIB_BRPC_VER) OR (NOT DEFINED PSLIB_BRPC_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_BRPC_VER "pslib_brpc" CACHE STRING "" FORCE) #todo pslib version
SET(PSLIB_BRPC_URL "http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/${PSLIB_BRPC_VER}.tar.gz" CACHE STRING "" FORCE) #todo pslib_brpc url
ENDIF()
MESSAGE(STATUS "PSLIB_BRPC_VER: ${PSLIB_BRPC_VER}, PSLIB_BRPC_URL: ${PSLIB_BRPC_URL}")
SET(PSLIB_BRPC_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib_brpc")
SET(PSLIB_BRPC_DOWNLOAD_DIR "${PSLIB_BRPC_SOURCE_DIR}/src/${PSLIB_BRPC_PROJECT}")
SET(PSLIB_BRPC_DST_DIR "pslib_brpc")
SET(PSLIB_BRPC_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(PSLIB_BRPC_INSTALL_DIR ${PSLIB_BRPC_INSTALL_ROOT}/${PSLIB_BRPC_DST_DIR})
SET(PSLIB_BRPC_ROOT ${PSLIB_BRPC_INSTALL_DIR})
SET(PSLIB_BRPC_INC_DIR ${PSLIB_BRPC_ROOT}/include)
SET(PSLIB_BRPC_LIB_DIR ${PSLIB_BRPC_ROOT}/lib)
SET(PSLIB_BRPC_LIB ${PSLIB_BRPC_LIB_DIR}/libps.so)
SET(PSLIB_BRPC_IOMP_LIB ${PSLIB_BRPC_LIB_DIR}/libiomp5.so) #todo what is this
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${PSLIB_BRPC_ROOT}/lib")
INCLUDE_DIRECTORIES(${PSLIB_BRPC_INC_DIR})
FILE(WRITE ${PSLIB_BRPC_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB_BRPC)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_BRPC_VER}/include ${PSLIB_BRPC_VER}/lib \n"
" DESTINATION ${PSLIB_BRPC_DST_DIR})\n")
ExternalProject_Add(
${PSLIB_BRPC_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_BRPC_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_BRPC_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_BRPC_URL} -c -q -O ${PSLIB_BRPC_VER}.tar.gz
&& tar zxvf ${PSLIB_BRPC_VER}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_BRPC_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_BRPC_INSTALL_ROOT}
)
ADD_LIBRARY(pslib_brpc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET pslib_brpc PROPERTY IMPORTED_LOCATION ${PSLIB_BRPC_LIB})
ADD_DEPENDENCIES(pslib_brpc ${PSLIB_BRPC_PROJECT})
LIST(APPEND external_project_dependencies pslib_brpc)
IF(WITH_C_API)
INSTALL(FILES ${PSLIB_BRPC_LIB} ${PSLIB_BRPC_IOMP_LIB} DESTINATION lib)
ENDIF()
...@@ -29,6 +29,7 @@ limitations under the License. */ ...@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#include "pslib.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -47,6 +48,10 @@ void AsyncExecutor::CreateThreads( ...@@ -47,6 +48,10 @@ void AsyncExecutor::CreateThreads(
worker->SetDataFeed(reader); worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names); worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory(); worker->BindingDataFeedMemory();
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->BindingSlotVariableMemory();
worker->SetParamConfig(&_param_config);
} }
void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
...@@ -60,6 +65,77 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT ...@@ -60,6 +65,77 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
readers[0]->SetFileList(filelist); readers[0]->SetFileList(filelist);
} }
void AsyncExecutor::ConfigPslib(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
_pslib_ptr->init_and_config(dist_desc, host_sign_list, node_num, index);//TODO
}
void AsyncExecutor::StartServer() {
_pslib_ptr->run_server();
}
void AsyncExecutor::InitModel() {
//TODO only rank = 0 do this
std::vector<int> all_dense_table_id; //TODO
all_dense_table_id.push_back(0);
for (auto table_id: all_dense_table_id) {
std::vector<paddle::ps::Region> regions;
std::vector<std::string> variables; //TODO
for (auto& t : variables) {
Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(local_random_engine()) * init_range;
}
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto push_status = _pslib_ptr->_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push dense param failed, status[" << status << "]";
exit(-1);
}
}
}
void AsyncExecutor::SaveModel(const std::string& path) {
auto ret = _pslib_ptr->_worker_ptr->flush();
ret.wait();
ret = _pslib_ptr->_worker_ptr->save(path, 0);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { // TODO should be feasign_cnt < 0, because server bug
LOG(FATAL) << "save model failed";
exit(-1);
}
}
void AsyncExecutor::PrepareDenseThread() {
DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;;
param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO
param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
_pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
}
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str, const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
...@@ -82,7 +158,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -82,7 +158,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc); &data_feed_desc);
int actual_thread_num = thread_num; actual_thread_num = thread_num;
int file_cnt = filelist.size(); int file_cnt = filelist.size();
PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty"); PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
...@@ -106,11 +182,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -106,11 +182,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
// todo: should be factory method for creating datafeed // todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers; std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist); PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
PrepareDenseThread();
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers; std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num); workers.resize(actual_thread_num);
for (auto& worker : workers) { for (auto& worker : workers) {
worker.reset(new ExecutorThreadWorker); worker.reset(new AsyncExecutorThreadWorker);
} }
// prepare thread resource here // prepare thread resource here
...@@ -128,7 +204,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -128,7 +204,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) { for (auto& th : threads) {
th.join(); th.join();
} }
_pull_dense_thread->stop();
root_scope_->DropKids(); root_scope_->DropKids();
return; return;
......
...@@ -22,6 +22,8 @@ limitations under the License. */ ...@@ -22,6 +22,8 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <typeinfo> #include <typeinfo>
#include <vector> #include <vector>
#include <random> //local_random_engine
#include <time.h> //local_random_engine
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h" #include "paddle/fluid/framework/executor_thread_worker.h"
...@@ -30,6 +32,26 @@ limitations under the License. */ ...@@ -30,6 +32,26 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
inline double current_realtime() {
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
return tp.tv_sec + tp.tv_nsec * 1e-9;
}
inline std::default_random_engine& local_random_engine() {
struct engine_wrapper_t {
std::default_random_engine engine;
engine_wrapper_t() {
static std::atomic<unsigned long> x(0);
std::seed_seq sseq = {x++, x++, x++, (unsigned long)(current_realtime() * 1000)};
engine.seed(sseq);
}
};
thread_local engine_wrapper_t r;
return r.engine;
}
class AsyncExecutor { class AsyncExecutor {
public: public:
AsyncExecutor(Scope* scope, const platform::Place& place); AsyncExecutor(Scope* scope, const platform::Place& place);
...@@ -40,6 +62,12 @@ class AsyncExecutor { ...@@ -40,6 +62,12 @@ class AsyncExecutor {
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_names,
const bool debug = false); const bool debug = false);
//void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
void ConfigPslib(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index);
//void ConfigWorker() {}
void StartServer();
void InitModel();
void SaveModel(const std::string& path);
private: private:
void CreateThreads(ExecutorThreadWorker* worker, void CreateThreads(ExecutorThreadWorker* worker,
...@@ -48,11 +76,19 @@ class AsyncExecutor { ...@@ -48,11 +76,19 @@ class AsyncExecutor {
const std::vector<std::string>& fetch_var_names, const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index, Scope* root_scope, const int thread_index,
const bool debug); const bool debug);
void PrepareDenseThread();
public: public:
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread;
Scope* root_scope_; Scope* root_scope_;
platform::Place place_; platform::Place place_;
AsyncWorkerParamConfig _param_config;
private:
int actual_thread_num;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -31,6 +31,85 @@ limitations under the License. */ ...@@ -31,6 +31,85 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
int DensePullThread::start() {
_running = true;
_t = std::thread(&DensePullThread::run, this);
return 0;
}
void DensePullThread::run() {
while (_running) {
_pull_dense_status.resize(0);
for (auto& t : _dense_variable_name) {
if (check_update_param(t.first)) {
auto status = pull_dense(t.first);
_pull_dense_status.emplace_back(std::move(status));
reset_thread_version(t.first);
}
}
if (_pull_dense_status.size() != 0) {
wait_all();
}
usleep(_sleep_time_ms * 1000);
}
}
bool DensePullThread::check_update_param(uint64_t table_id) {
{
std::lock_guard<std::mutex> lock(_mutex_for_version);
auto& version = _training_versions[table_id];
_current_version[table_id] = *(std::min_element(version.begin(), version.end()));
}
if (_current_version[table_id] - _last_versions[table_id] < _threshold) {
return false;
}
return true;
}
void DensePullThread::reset_thread_version(uint64_t table_id) {
std::lock_guard<std::mutex> lock(_mutex_for_version);
_last_versions[table_id] = _current_version[table_id];
}
std::future<int32_t> DensePullThread::pull_dense(uint64_t table_id) {
auto& regions = _regions[table_id];
regions.clear();
auto& variables = _dense_variable_name[table_id];
regions.resize(variables.size());
for (auto i = 0u; i < variables.size(); ++i) {
auto& t = variables[i];
Variable* var = _root_scope->FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel());
regions[i] = std::move(reg);
}
return _ps_client->pull_dense(regions.data(), regions.size(), table_id);
}
void DensePullThread::wait_all() {
for (auto& t : _pull_dense_status) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(WARNING) << "pull dense failed times:" << ++_pull_dense_fail_times;
}
}
if (_pull_dense_fail_times > 20) {
LOG(FATAL) << "pull dense failed times more than 20 times";
exit(-1);
}
_pull_dense_status.resize(0);
}
void DensePullThread::increase_thread_version(int thread_id, uint64_t table_id) {
std::lock_guard<std::mutex> lock(_mutex_for_version);
_training_versions[table_id][thread_id]++;
}
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) { void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0); auto& block = program.Block(0);
op_names_.clear(); op_names_.clear();
...@@ -90,6 +169,11 @@ void ExecutorThreadWorker::SetFetchVarNames( ...@@ -90,6 +169,11 @@ void ExecutorThreadWorker::SetFetchVarNames(
fetch_var_names.end()); fetch_var_names.end());
} }
void ExecutorThreadWorker::SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {
}
void ExecutorThreadWorker::SetDevice() { void ExecutorThreadWorker::SetDevice() {
#if defined _WIN32 || defined __APPLE__ #if defined _WIN32 || defined __APPLE__
return; return;
...@@ -219,5 +303,377 @@ void ExecutorThreadWorker::SetRootScope(Scope* g_scope) { ...@@ -219,5 +303,377 @@ void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope; root_scope_ = g_scope;
} }
//AsyncExecutor
void AsyncExecutorThreadWorker::TrainFiles() {
SetDevice();
int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear();
fetch_values_.resize(fetch_var_num);
thread_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) {
// executor run here
TrainOneNetwork();
++batch_cnt;
thread_scope_->DropKids();
if (debug_ == false || thread_id_ != 0) {
continue;
}
for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]);
} // end for (int i = 0...)
} // end while ()
}
void AsyncExecutorThreadWorker::SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {
_pslib_ptr = pslib_ptr;
}
void AsyncExecutorThreadWorker::SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {
_pull_dense_thread = dpt;
}
void AsyncExecutorThreadWorker::TrainOneNetwork() {
PrepareParams();
for (auto& op : ops_) {
if (op->Type().find("sgd") != std::string::npos) {
continue;
}
op->Run(*thread_scope_, place_);
}
UpdateParams();
}
void AsyncExecutorThreadWorker::BindingSlotVariableMemory() {
/*
std::vector<int> ins_slot_offset(batch_size + 1, 0);
for (auto i = 1u; i <= batch_size; ++i) {
ins_slot_offset[i] += ins_slot_offset[i - 1] + slot_dim;
}
std::vector<int> tensor_lod(batch_size + 1, 0);
for (auto i = 1u; i <= batch_size; ++i) {
tensor_lod[i] += tensor_lod[i - 1] + 1;
}
auto& used_slots = reader->get_use_slot_alias();
slot_input_vec.resize(used_slots.size() - 1);
for (auto slot_idx = 1u; slot_idx < used_slots.size(); ++slot_idx) {
auto var = slot_input_variable_name[slot_idx];
auto v = thread_scope->FindVar(var);
CHECK(v != nullptr) << "var[" << var << "] not found";
LoDTensor* tensor = v->GetMutable<LoDTensor>();
float* tensor_ptr = tensor->mutable_data<float>({batch_size, slot_dim}, platform::CPUPlace());
memset(tensor_ptr, 0, sizeof(float) * ins_slot_offset.back());
LoD data_lod{tensor_lod};
tensor->set_lod(data_lod);
slot_input_vec[slot_idx - 1].reset(tensor);
}
*/
}
void AsyncExecutorThreadWorker::SetParamConfig(AsyncWorkerParamConfig* pc) {
_param_config = pc;
}
void AsyncExecutorThreadWorker::PrepareParams() {
int table_id = 0; //TODO
PullSparse(table_id);
for (auto& t : _pull_sparse_status) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(ERROR) << "pull sparse failed, status[" << status << "]";
exit(-1);
}
}
_pull_sparse_status.resize(0);
FillSparse(table_id);
}
void AsyncExecutorThreadWorker::UpdateParams() {
//for (auto i = 0u; i < GlobalConfig::instance().dense_table_id.size(); ++i) {//TODO
for (int i = 0; i < 1; ++i) {
PushSparse(i);
}
//for (auto i = 0u; i < GlobalConfig::instance().dense_table_id.size(); ++i) {//TODO
for (int i = 1; i < 2; ++i) {
PushDense(i);
}
int32_t tmp_push_dense_wait_times = _param_config->tmp_push_dense_wait_times; //TODO
int32_t tmp_push_sparse_wait_times = _param_config->tmp_push_sparse_wait_times; //TODO
static uint32_t push_dense_wait_times = static_cast<uint32_t>(tmp_push_dense_wait_times);
static uint32_t push_sparse_wait_times = static_cast<uint32_t>(tmp_push_sparse_wait_times);
if (_push_dense_status.size() >= push_dense_wait_times) {
for (auto& t : _push_dense_status) {
t.wait();
}
_push_dense_status.resize(0);
}
if (tmp_push_dense_wait_times == -1) {
_push_dense_status.resize(0);
}
if (_push_sparse_status.size() >= push_sparse_wait_times) {
for (auto& t : _push_sparse_status) {
t.wait();
}
_push_sparse_status.resize(0);
}
if (tmp_push_sparse_wait_times == -1) {
_push_sparse_status.resize(0);
}
//for (auto dense_table_id : GlobalConfig::instance().dense_table_id) {//TODO
int dense_table_id = 1;
_pull_dense_thread->increase_thread_version(thread_id_, dense_table_id);
//}
}
void AsyncExecutorThreadWorker::PushDense(int table_id) {
//auto table_id = GlobalConfig::instance().dense_table_id[table_id_index]; TODO
std::vector<paddle::ps::Region> regions;
//auto& variables = GlobalConfig::instance().dense_gradient_variable_name[table_id];
std::vector<std::string> variables;
for (auto& t : variables) {
Variable* var = thread_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int count = tensor->numel();
float* g = tensor->data<float>();
paddle::ps::Region reg(g, count);
regions.emplace_back(std::move(reg));
}
auto status = _pslib_ptr->_worker_ptr->push_dense(regions.data(), regions.size(), table_id);
_push_dense_status.push_back(std::move(status));
}
void AsyncExecutorThreadWorker::PullSparse(int table_id) {
auto& features = _features[table_id];
auto& feature_value = _feature_value[table_id];
auto fea_dim = _param_config->fea_dim; //TODO
// slot id starts from 1
features.clear();
features.resize(0);
features.reserve(MAX_FEASIGN_NUM);
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
// slot_idx = 0 is label TODO
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
//todo: current trick - filter feasign=use_slot_mod(bug: datafeed fill use_slot_mod for empty slot)
if (ids[i] == 0u) {
continue;
}
features.push_back(static_cast<uint64_t>(ids[i]));
}
}
check_pull_push_memory(features, feature_value, fea_dim);
std::vector<float*> pull_feature_value;
for (auto i = 0u; i < features.size(); ++i) {
pull_feature_value.push_back(feature_value[i].data());
}
auto status = _pslib_ptr->_worker_ptr->pull_sparse(
pull_feature_value.data(), table_id, features.data(), features.size());
_pull_sparse_status.push_back(std::move(status));
//to save time
auto& push_g = _feature_push_value[table_id];
check_pull_push_memory(features, push_g, fea_dim);
//binding_slot_embed_with_concat(); TODO
collect_feasign_info(table_id); //TODO
}
void AsyncExecutorThreadWorker::FillSparse(int table_id) {
auto slot_dim = _param_config->slot_dim; // TODO
auto fea_dim = _param_config->fea_dim; //TODO
auto& features = _features[table_id];
auto& fea_value = _feature_value[table_id];
CHECK(features.size() > 0) << "feature size check failed";
auto fea_idx = 0u;
std::vector<float> init_value(fea_dim);
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
// slot_idx = 0 is label TODO
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(_param_config->slot_input_vec[slot_idx - 1]);
LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
float* ptr = tensor_emb->data<float>();
for (auto index = 0u; index < len; ++index){
//if (_current_train_job.use_cvm_feature()) {
// if (ids[index] == 0u) {
// memcpy(ptr + slot_dim * index, init_value.data(), sizeof(float) * slot_dim);
// continue;
// }
// memcpy(ptr + slot_dim * index, fea_value[fea_idx].data(), sizeof(float) * slot_dim);
// (ptr + slot_dim * index)[0] = log((ptr + slot_dim * index)[0] + 1);
// (ptr + slot_dim * index)[1] = log((ptr + slot_dim * index)[1] + 1) - (ptr + slot_dim * index)[0];
// fea_idx++;
//} else {
if (ids[index] == 0u) {
memcpy(ptr + slot_dim * index, init_value.data() + 2, sizeof(float) * slot_dim);
continue;
}
memcpy(ptr + slot_dim * index, fea_value[fea_idx].data() + 2, sizeof(float) * slot_dim);
fea_idx++;
//}
}
}
}
void AsyncExecutorThreadWorker::PushSparse(int table_id) {
auto slot_dim = _param_config->slot_dim; //TODO
auto fea_dim = _param_config->fea_dim;//_current_train_job.fea_dim();TODO
auto& features = _features[table_id];
//std::vector<std::string> gradient_var;
//auto& gradient_var = GlobalConfig::instance().input_gradient_variable_name; //TODO
auto& push_g = _feature_push_value[table_id];
check_pull_push_memory(features, push_g, fea_dim);
uint64_t fea_idx = 0u;
auto& fea_info = _fea_info[table_id]; //TODO
int offset = 0;
//if (!_current_train_job.use_cvm_feature()) { //TODO
offset = 2;
//}
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
// slot_idx = 0 is label TODO
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
if (_slot_alias_to_table[feed_vec[slot_idx]] != table_id) {
continue;
}
Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[slot_idx - 1]);
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
//int count = g_tensor->numel();
float* g = g_tensor->data<float>();
/*
if (FLAGS_scale_sparse_gradient_with_batch_size) {
Eigen::Map<Eigen::MatrixXf> g_mat(g, 1, tensor->numel());
g_mat *= _batch_size;
}
*/
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int len = tensor->lod()[0].back();
//assert(slot_dim * len == count);
int64_t* ids = tensor->data<int64_t>();
for (auto id_idx = 0u; id_idx < len; ++id_idx){
if (ids[id_idx] == 0) {
g += slot_dim;
continue;
}
memcpy(push_g[fea_idx].data() + offset, g, sizeof(float) * slot_dim);
push_g[fea_idx][0] = 1.0f;
push_g[fea_idx][1] = static_cast<float>(fea_info[fea_idx].label);
g += slot_dim;
fea_idx++;
}
}
assert(fea_idx == features.size());
CHECK(features.size() > 0);
std::vector<float*> push_g_vec;
for (auto i = 0u; i < features.size(); ++i) {
push_g_vec.push_back(push_g[i].data());
}
auto status = _pslib_ptr->_worker_ptr->push_sparse(
table_id, features.data(), (const float**)push_g_vec.data(), features.size());
_push_sparse_status.push_back(std::move(status));
}
void AsyncExecutorThreadWorker::collect_feasign_info(
int table_id) {
auto& fea_info = _fea_info[table_id];
auto& feature = _features[table_id];
fea_info.resize(feature.size());
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
Variable* var = thread_scope_->FindVar(feed_vec[0]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* label = tensor->data<int64_t>();
int global_index = 0;
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int fea_idx = 0;
for (auto ins_idx = 1u; ins_idx < tensor->lod()[0].size(); ++ins_idx) {
for (; fea_idx < tensor->lod()[0][ins_idx]; ++fea_idx) {
if (ids[fea_idx] == 0u) {
continue;
}
FeasignInfo info{slot_idx, ins_idx, label[ins_idx - 1]};
fea_info[global_index++] = std::move(info);
}
}
}
CHECK(global_index == feature.size()) << "expect fea info size:" << feature.size()
<< " real:" << global_index;
}
void AsyncExecutorThreadWorker::check_pull_push_memory(
std::vector<uint64_t>& features,
std::vector<std::vector<float>>& push_g,
int dim) {
push_g.resize(features.size() + 1);
for (auto& t : push_g) {
t.resize(dim);
}
}
void AsyncExecutorThreadWorker::check_pull_push_memory(
std::vector<uint64_t>& features,
std::vector<float*>& push_g,
int dim) {
if (features.size() > push_g.size()) {
push_g.reserve(features.size() + 1);
auto size = features.size() - push_g.size() + 1;
for (auto i = 0u; i < size; ++i) {
float* ptr = new float[dim];
push_g.push_back(ptr);
}
}
}
} // einit_modelnd namespace framework } // einit_modelnd namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -25,16 +25,107 @@ limitations under the License. */ ...@@ -25,16 +25,107 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "pslib.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
const static uint32_t MAX_FEASIGN_NUM = 1000 * 100 * 100;
void CreateTensor(Variable* var, proto::VarType::Type var_type); void CreateTensor(Variable* var, proto::VarType::Type var_type);
struct AsyncWorkerParamConfig {
int slot_dim;
int fea_dim;
int32_t tmp_push_dense_wait_times;
int32_t tmp_push_sparse_wait_times;
std::vector<std::string> slot_input_vec; //6048slot 6050slot //name
std::vector<std::string> gradient_var; //6048slot_embed
};
struct DensePullThreadParam {
std::shared_ptr<paddle::ps::PSClient> ps_client;
int threshold;
int training_thread_num;
Scope* root_scope;
std::map<uint64_t, std::vector<std::string>>* dense_params;
int sleep_time_ms = 2;
};
class DensePullThread {
public:
DensePullThread(DensePullThreadParam& param) :
_running(false) {
_ps_client = param.ps_client;
_threshold = param.threshold;
_thread_num = param.training_thread_num;
_root_scope = param.root_scope;
_sleep_time_ms = param.sleep_time_ms;
for (auto& t : *param.dense_params) {
_dense_variable_name[t.first].insert(
_dense_variable_name[t.first].end(),
t.second.begin(), t.second.end());
_training_versions[t.first].resize(_thread_num, 0);
_last_versions[t.first] = 0;
_current_version[t.first] = 0;
}
}
int start();
void stop() {
if (_running) {
_running = false;
_t.join();
}
}
void increase_thread_version(int thread_id, uint64_t table_id);
void reset_thread_version(uint64_t table_id);
std::future<int32_t> pull_dense(uint64_t table_id);
void pull_dense2(uint64_t table_id);
void wait_all();
private:
void run();
bool check_update_param(uint64_t table_id);
private:
std::shared_ptr<paddle::ps::PSClient> _ps_client;
int _thread_num;
int _threshold;
int _sleep_time_ms;
Scope* _root_scope;
bool _running;
std::map<uint64_t, uint64_t> _last_versions;
std::map<uint64_t, uint64_t> _current_version;
std::mutex _mutex_for_version;
std::map<uint64_t, std::vector<uint64_t>> _training_versions;
std::map<uint64_t, std::vector<std::string>> _dense_variable_name;
std::thread _t;
std::vector<::std::future<int32_t>> _pull_dense_status;
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
uint32_t _pull_dense_fail_times = 0;
std::vector<float> _base_norm_param;
std::vector<float> _mean;
std::vector<float> _scale;
float _squared_sum_epsilon = 1e-4;
std::mutex _mutex_for_mean_scale;
float _total_batch_num = 0;
};
class ExecutorThreadWorker { class ExecutorThreadWorker {
public: public:
ExecutorThreadWorker() ExecutorThreadWorker()
: thread_id_(-1), root_scope_(NULL), thread_scope_(NULL), debug_(false) {} : thread_id_(-1), root_scope_(NULL), thread_scope_(NULL), debug_(false) {}
~ExecutorThreadWorker() {} virtual ~ExecutorThreadWorker() {}
void CreateThreadResource(const framework::ProgramDesc& program, void CreateThreadResource(const framework::ProgramDesc& program,
const paddle::platform::Place& place); const paddle::platform::Place& place);
...@@ -51,10 +142,13 @@ class ExecutorThreadWorker { ...@@ -51,10 +142,13 @@ class ExecutorThreadWorker {
// set data feed declared in executor // set data feed declared in executor
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed); void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
// A multi-thread training function // A multi-thread training function
void TrainFiles(); virtual void TrainFiles();
// set fetch variable names from python interface assigned by users // set fetch variable names from python interface assigned by users
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names); void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
virtual void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
virtual void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {};
virtual void BindingSlotVariableMemory() {};
virtual void SetParamConfig(AsyncWorkerParamConfig* pc) {};
private: private:
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
void CreateThreadOperators(const framework::ProgramDesc& program); void CreateThreadOperators(const framework::ProgramDesc& program);
...@@ -77,12 +171,58 @@ class ExecutorThreadWorker { ...@@ -77,12 +171,58 @@ class ExecutorThreadWorker {
Scope* root_scope_; Scope* root_scope_;
// a thread scope, father scope is global score which is shared // a thread scope, father scope is global score which is shared
Scope* thread_scope_; Scope* thread_scope_;
//private:
private:
std::vector<std::string> fetch_var_names_; std::vector<std::string> fetch_var_names_;
std::vector<std::vector<float>> fetch_values_; std::vector<std::vector<float>> fetch_values_;
bool debug_; bool debug_;
}; };
class AsyncExecutorThreadWorker: public ExecutorThreadWorker {
public:
AsyncExecutorThreadWorker(){};
virtual ~AsyncExecutorThreadWorker() {}
void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt);
void BindingSlotVariableMemory();
void SetParamConfig(AsyncWorkerParamConfig* pc);
void TrainFiles();
void TrainOneNetwork();
void PrepareParams();
void UpdateParams();
void PullSparse(int table_id);
void FillSparse(int table_id);
void PushSparse(int table_id);
void PushDense(int table_id);
void check_pull_push_memory(std::vector<uint64_t>& features, std::vector<float*>& push_g, int dim);
void check_pull_push_memory(std::vector<uint64_t>& features, std::vector<std::vector<float>>& push_g, int dim);
void collect_feasign_info(int table_id);
private:
struct FeasignInfo {
uint32_t slot;
uint32_t ins;
int64_t label;
};
std::map<uint64_t, std::vector<uint64_t>> _features;
std::map<uint64_t, std::vector<FeasignInfo>> _fea_info;
std::map<uint64_t, std::vector<std::vector<float>>> _feature_value;
std::map<uint64_t, std::vector<std::vector<float>>> _feature_push_value;
std::unordered_map<std::string, uint64_t> _slot_alias_to_table; //TODO
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread;
std::vector<::std::future<int32_t>> _pull_sparse_status;
std::vector<::std::future<int32_t>> _pull_dense_status;
std::vector<::std::future<int32_t>> _push_sparse_status;
std::vector<::std::future<int32_t>> _push_dense_status;
AsyncWorkerParamConfig* _param_config;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -75,8 +75,13 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -75,8 +75,13 @@ class AucKernel : public framework::OpKernel<T> {
const auto *label_data = label->data<int64_t>(); const auto *label_data = label->data<int64_t>();
for (size_t i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; i++) {
uint32_t binIdx = static_cast<uint32_t>( auto predict_data = inference_data[i * inference_width + 1];
inference_data[i * inference_width + 1] * num_thresholds); PADDLE_ENFORCE_LE(predict_data, 1,
"The predict data must less or equal 1.");
PADDLE_ENFORCE_GE(predict_data, 0,
"The predict data must gather or equal 0.");
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) { if (label_data[i]) {
(*stat_pos)[binIdx] += 1.0; (*stat_pos)[binIdx] += 1.0;
} else { } else {
......
...@@ -47,7 +47,11 @@ void BindAsyncExecutor(py::module* m) { ...@@ -47,7 +47,11 @@ void BindAsyncExecutor(py::module* m) {
return std::unique_ptr<framework::AsyncExecutor>( return std::unique_ptr<framework::AsyncExecutor>(
new framework::AsyncExecutor(scope, place)); new framework::AsyncExecutor(scope, place));
})) }))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile); .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
.def("config_pslib", &framework::AsyncExecutor::ConfigPslib)
.def("start_server", &framework::AsyncExecutor::StartServer)
.def("init_model", &framework::AsyncExecutor::InitModel)
.def("save_model", &framework::AsyncExecutor::SaveModel);
} // end BindAsyncExecutor } // end BindAsyncExecutor
} // end namespace pybind } // end namespace pybind
} // end namespace paddle } // end namespace paddle
...@@ -149,3 +149,16 @@ class AsyncExecutor(object): ...@@ -149,3 +149,16 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc, self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, debug) fetch_var_names, debug)
def config_ps(self, dist_desc, host_sign_list, node_num, index):
self.executor.config_pslib(dist_desc, host_sign_list, node_num, index)
def start_server(self):
self.executor.start_server()
def init_model(self):
self.executor.init_model()
def save_model(self, save_path):
self.executor.save_model(save_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册