diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ac5b2f1cfe2acacf2ac35001ce8640261c30265..b0da4bbec2f9a8fc2c23b92b0d4f4e94f148bee6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF) option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) option(WITH_PSLIB "Compile with pslib support" OFF) +option(WITH_BOX_PS "Compile with box_ps support" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) @@ -170,6 +171,9 @@ if(WITH_PSLIB) include(external/pslib_brpc) include(external/pslib) endif(WITH_PSLIB) +if(WITH_BOX_PS) + include(external/box_ps) +endif(WITH_BOX_PS) if(WITH_DISTRIBUTE) if(WITH_GRPC) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index ae39eeb99e7d6ef50ddec6630ff881a6e6f80c56..816314ddc6ece68540e01abe262dec3b7227dd07 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -62,6 +62,10 @@ if(WITH_PSLIB) add_definitions(-DPADDLE_WITH_PSLIB) endif() +if(WITH_BOX_PS) + add_definitions(-DPADDLE_WITH_BOX_PS) +endif() + if(WITH_GPU) add_definitions(-DPADDLE_WITH_CUDA) add_definitions(-DEIGEN_USE_GPU) diff --git a/cmake/external/box_ps.cmake b/cmake/external/box_ps.cmake new file mode 100644 index 0000000000000000000000000000000000000000..ddb4c82e1d4424c8c5305de8ba232d382b28def9 --- /dev/null +++ b/cmake/external/box_ps.cmake @@ -0,0 +1,68 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +IF(NOT ${WITH_BOX_PS}) + return() +ENDIF(NOT ${WITH_BOX_PS}) + +IF(WIN32 OR APPLE) + MESSAGE(WARNING + "Windows or Mac is not supported with BOX_PS in Paddle yet." + "Force WITH_BOX_PS=OFF") + SET(WITH_BOX_PS OFF CACHE STRING "Disable BOX_PS package in Windows and MacOS" FORCE) + return() +ENDIF() + +INCLUDE(ExternalProject) + +SET(BOX_PS_PROJECT "extern_box_ps") +IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL)) + MESSAGE(STATUS "use pre defined download url") + SET(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE) + SET(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE) + SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps_stub.tar.gz" CACHE STRING "" FORCE) +ENDIF() +MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}") +SET(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps") +SET(BOX_PS_DOWNLOAD_DIR "${BOX_PS_SOURCE_DIR}/src/${BOX_PS_PROJECT}") +SET(BOX_PS_DST_DIR "box_ps") +SET(BOX_PS_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") +SET(BOX_PS_INSTALL_DIR ${BOX_PS_INSTALL_ROOT}/${BOX_PS_DST_DIR}) +SET(BOX_PS_ROOT ${BOX_PS_INSTALL_DIR}) +SET(BOX_PS_INC_DIR ${BOX_PS_ROOT}/include) +SET(BOX_PS_LIB_DIR ${BOX_PS_ROOT}/lib) +SET(BOX_PS_LIB ${BOX_PS_LIB_DIR}/libbox_ps.so) +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${BOX_PS_ROOT}/lib") + +INCLUDE_DIRECTORIES(${BOX_PS_INC_DIR}) +FILE(WRITE ${BOX_PS_DOWNLOAD_DIR}/CMakeLists.txt + "PROJECT(BOX_PS)\n" + "cmake_minimum_required(VERSION 3.0)\n" + "install(DIRECTORY ${BOX_PS_NAME}/include ${BOX_PS_NAME}/lib \n" + " DESTINATION ${BOX_PS_DST_DIR})\n") +ExternalProject_Add( + ${BOX_PS_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${BOX_PS_SOURCE_DIR} + DOWNLOAD_DIR ${BOX_PS_DOWNLOAD_DIR} + DOWNLOAD_COMMAND wget --no-check-certificate ${BOX_PS_URL} -c -q -O ${BOX_PS_NAME}.tar.gz + && tar zxvf ${BOX_PS_NAME}.tar.gz + DOWNLOAD_NO_PROGRESS 1 + UPDATE_COMMAND "" + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${BOX_PS_INSTALL_ROOT} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${BOX_PS_INSTALL_ROOT} +) +ADD_LIBRARY(box_ps SHARED IMPORTED GLOBAL) +SET_PROPERTY(TARGET box_ps PROPERTY IMPORTED_LOCATION ${BOX_PS_LIB}) +ADD_DEPENDENCIES(box_ps ${BOX_PS_PROJECT}) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index da1e977a9d44974ec16599bd0da3d63e0892fa7a..3182f18cc8ec0521791c02eb14c4292fe6758dd2 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -123,8 +123,8 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) -cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog - shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope + glog box_wrapper shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) @@ -179,7 +179,7 @@ if(WITH_DISTRIBUTE) dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry - device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer + device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") @@ -190,7 +190,7 @@ else() data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto trainer_desc_proto glog - lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method + lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 047dc0116fea4d0cf0ed24e81d1befdeb310964c..8471616cd76cfbae82b1e5691b43d022e68cea9b 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -146,6 +146,7 @@ class DatasetImpl : public Dataset { virtual const std::vector& GetFileList() { return filelist_; } virtual int GetThreadNum() { return thread_num_; } virtual int GetTrainerNum() { return trainer_num_; } + virtual Channel GetInputChannel() { return input_channel_; } virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; } virtual std::pair GetHdfsConfig() { return std::make_pair(fs_name_, fs_ugi_); diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt index 12fc454fd262cdcf30f64757a6199c6a9331e1a2..424063970b7e394ca8142fc698b3936586246014 100644 --- a/paddle/fluid/framework/fleet/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/CMakeLists.txt @@ -5,3 +5,8 @@ else() endif(WITH_PSLIB) cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope) +if(WITH_BOX_PS) + cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor box_ps) +else() + cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor) +endif(WITH_BOX_PS) diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..935bcc722a3f8b762c480a46c24d8b9574150c89 --- /dev/null +++ b/paddle/fluid/framework/fleet/box_wrapper.cc @@ -0,0 +1,247 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/fleet/box_wrapper.h" +#include +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace framework { + +std::shared_ptr BoxWrapper::s_instance_ = nullptr; +#ifdef PADDLE_WITH_BOX_PS +std::shared_ptr BoxWrapper::boxps_ptr_ = nullptr; +#endif + +int BoxWrapper::GetDate() const { + time_t now = time(0); + tm t; +#ifdef _WIN32 + localtime_s(&t, &now); +#else + localtime_r(&now, &t); +#endif + char buf[10]; + snprintf(buf, sizeof(buf), "%04d%02d%02d", (1900 + t.tm_year), (1 + t.tm_mon), + t.tm_mday); + return atoi(buf); +} + +void BoxWrapper::FeedPass(const std::vector& feasgin_to_box) const { +#ifdef PADDLE_WITH_BOX_PS + int ret = boxps_ptr_->FeedPass(GetDate(), feasgin_to_box); + PADDLE_ENFORCE_EQ(ret, 0, "FeedPass failed in BoxPS."); +#endif +} + +void BoxWrapper::BeginPass() const { +#ifdef PADDLE_WITH_BOX_PS + int ret = boxps_ptr_->BeginPass(); + PADDLE_ENFORCE_EQ(ret, 0, "BeginPass failed in BoxPS."); +#endif +} + +void BoxWrapper::EndPass() const { +#ifdef PADDLE_WITH_BOX_PS + int ret = boxps_ptr_->EndPass(); + PADDLE_ENFORCE_EQ(ret, 0, "EndPass failed in BoxPS."); +#endif +} + +void BoxWrapper::PullSparse(const paddle::platform::Place& place, + const std::vector& keys, + const std::vector& values, + const std::vector& slot_lengths, + const int hidden_size) { +#ifdef PADDLE_WITH_BOX_PS + if (platform::is_cpu_place(place) || platform::is_gpu_place(place)) { + int64_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + LoDTensor total_keys_tensor; + int64_t* total_keys = + total_keys_tensor.mutable_data({total_length, 1}, place); + int64_t offset = 0; + for (size_t i = 0; i < keys.size(); ++i) { + if (platform::is_cpu_place(place)) { + memory::Copy(boost::get(place), total_keys + offset, + boost::get(place), keys[i], + slot_lengths[i] * sizeof(uint64_t)); + } else { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + memory::Copy(boost::get(place), + total_keys + offset, + boost::get(place), keys[i], + slot_lengths[i] * sizeof(uint64_t), nullptr); +#else + PADDLE_THROW( + "Please compile WITH_GPU option, and NCCL doesn't support " + "windows."); +#endif + } + offset += slot_lengths[i]; + } + PADDLE_ENFORCE_EQ(offset, total_length, + "BoxWrapper::PullSparse: total feasign keys length " + "should be equal to the sum of length of all input " + "tensors."); + + // Space allocation for FeatureValue is left for boxps + paddle::boxps::FeatureValue* total_values; + if (platform::is_cpu_place(place)) { + int ret = boxps_ptr_->PullSparseCPU( + reinterpret_cast(total_keys), &total_values, + static_cast(total_length)); + PADDLE_ENFORCE_EQ(ret, 0, "PullSparseCPU failed in BoxPS."); + } else { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + int ret = boxps_ptr_->PullSparseGPU( + reinterpret_cast(total_keys), &total_values, + static_cast(total_length), + boost::get(place).GetDeviceId()); + PADDLE_ENFORCE_EQ(ret, 0, "PullSparseGPU failed in BoxPS."); +#endif + } + + offset = 0; + for (size_t i = 0; i < values.size(); ++i) { + int64_t fea_num = slot_lengths[i]; + for (auto j = 0; j < fea_num; ++j) { + // Copy the emb from BoxPS to paddle tensor. Since 'show','click','emb' + // are continuous in memory, so we copy here using the 'show' address + if (platform::is_cpu_place(place)) { + memory::Copy( + boost::get(place), + values[i] + j * hidden_size, + boost::get(place), + reinterpret_cast(&((total_values + offset)->show)), + sizeof(float) * hidden_size); + } else { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + memory::Copy( + boost::get(place), + values[i] + j * hidden_size, + boost::get(place), + reinterpret_cast(&((total_values + offset)->show)), + sizeof(float) * hidden_size, nullptr); +#endif + } + ++offset; + } + } + PADDLE_ENFORCE_EQ(offset, total_length, + "BoxWrapper::PullSparse: total emb values length should " + "be equal to the sum of length of all input tensors."); + + } else { + PADDLE_THROW( + "PaddleBox: PullSparse Only Support CPUPlace and CUDAPlace Now."); + } +#endif +} + +void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place, + const std::vector& keys, + const std::vector& grad_values, + const std::vector& slot_lengths, + const int hidden_size) { +#ifdef PADDLE_WITH_BOX_PS + if (platform::is_cpu_place(place) || platform::is_gpu_place(place)) { + int64_t total_length = + std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); + LoDTensor total_keys_tensor; + int64_t* total_keys = + total_keys_tensor.mutable_data({total_length, 1}, place); + int64_t offset = 0; + for (size_t i = 0; i < keys.size(); ++i) { + if (platform::is_cpu_place(place)) { + memory::Copy(boost::get(place), total_keys + offset, + boost::get(place), keys[i], + slot_lengths[i] * sizeof(uint64_t)); + } else { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + memory::Copy(boost::get(place), + total_keys + offset, + boost::get(place), keys[i], + slot_lengths[i] * sizeof(uint64_t), nullptr); +#else + PADDLE_THROW( + "Please compile WITH_GPU option, and for now NCCL doesn't support " + "windows."); +#endif + } + offset += slot_lengths[i]; + } + PADDLE_ENFORCE_EQ(offset, total_length, + "BoxWrapper::PushSparseGrad: total feasign keys length " + "should be equal to the sum of length of all input " + "tensors."); + auto buf = memory::AllocShared( + place, total_length * sizeof(paddle::boxps::FeaturePushValue)); + paddle::boxps::FeaturePushValue* total_grad_values = + reinterpret_cast(buf->ptr()); + offset = 0; + for (size_t i = 0; i < grad_values.size(); ++i) { + int64_t fea_num = slot_lengths[i]; + for (auto j = 0; j < fea_num; ++j) { + // Copy the emb grad from paddle tensor to BoxPS. Since + // 'show','click','emb' are continuous in memory, so we copy here using + // the 'show' address + if (platform::is_cpu_place(place)) { + memory::Copy( + boost::get(place), + reinterpret_cast(&((total_grad_values + offset)->show)), + boost::get(place), + grad_values[i] + j * hidden_size, sizeof(float) * hidden_size); + } else { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + memory::Copy( + boost::get(place), + reinterpret_cast(&((total_grad_values + offset)->show)), + boost::get(place), + grad_values[i] + j * hidden_size, sizeof(float) * hidden_size, + nullptr); +#endif + } + ++offset; + } + } + PADDLE_ENFORCE_EQ(offset, total_length, + "BoxWrapper::PushSparseGrad: total emb grad values " + "length should be equal to the sum of length of all " + "input tensors."); + if (platform::is_cpu_place(place)) { + int ret = boxps_ptr_->PushSparseCPU( + reinterpret_cast(total_keys), total_grad_values, + static_cast(total_length)); + PADDLE_ENFORCE_EQ(ret, 0, "PushSparseCPU failed in BoxPS."); + } else { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + int ret = boxps_ptr_->PushSparseGPU( + reinterpret_cast(total_keys), total_grad_values, + static_cast(total_length), + boost::get(place).GetDeviceId()); + PADDLE_ENFORCE_EQ(ret, 0, "PushSparseGPU failed in BoxPS."); +#endif + } + } else { + PADDLE_THROW( + "PaddleBox: PushSparse Only Support CPUPlace and CUDAPlace Now."); + } +#endif +} +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..c650d9cb7a63242d9b8d42c41049545d534a0975 --- /dev/null +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -0,0 +1,126 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include "paddle/fluid/framework/data_set.h" +#ifdef PADDLE_WITH_BOX_PS +#include +#endif +#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { + +class BoxWrapper { + public: + virtual ~BoxWrapper() {} + BoxWrapper() {} + + void FeedPass(const std::vector& feasgin_to_box) const; + void BeginPass() const; + void EndPass() const; + void PullSparse(const paddle::platform::Place& place, + const std::vector& keys, + const std::vector& values, + const std::vector& slot_lengths, + const int hidden_size); + void PushSparseGrad(const paddle::platform::Place& place, + const std::vector& keys, + const std::vector& grad_values, + const std::vector& slot_lengths, + const int hidden_size); + + static std::shared_ptr GetInstance() { + if (nullptr == s_instance_) { + // If main thread is guaranteed to init this, this lock can be removed + static std::mutex mutex; + std::lock_guard lock(mutex); + if (nullptr == s_instance_) { + s_instance_.reset(new paddle::framework::BoxWrapper()); +#ifdef PADDLE_WITH_BOX_PS + s_instance_->boxps_ptr_.reset(new paddle::boxps::FakeBoxPS()); +#endif + } + } + return s_instance_; + } + + private: +#ifdef PADDLE_WITH_BOX_PS + static std::shared_ptr boxps_ptr_; +#endif + static std::shared_ptr s_instance_; + int GetDate() const; +}; + +class BoxHelper { + public: + explicit BoxHelper(paddle::framework::Dataset* dataset) : dataset_(dataset) {} + virtual ~BoxHelper() {} + + void BeginPass() { + auto box_ptr = BoxWrapper::GetInstance(); + box_ptr->BeginPass(); + } + + void EndPass() { + auto box_ptr = BoxWrapper::GetInstance(); + box_ptr->EndPass(); + } + void LoadIntoMemory() { + dataset_->LoadIntoMemory(); + FeedPass(); + } + void PreLoadIntoMemory() { + dataset_->PreLoadIntoMemory(); + feed_data_thread_.reset(new std::thread([&]() { + dataset_->WaitPreLoadDone(); + FeedPass(); + })); + } + void WaitFeedPassDone() { feed_data_thread_->join(); } + + private: + Dataset* dataset_; + std::shared_ptr feed_data_thread_; + // notify boxps to feed this pass feasigns from SSD to memory + void FeedPass() { + auto box_ptr = BoxWrapper::GetInstance(); + auto input_channel_ = + dynamic_cast(dataset_)->GetInputChannel(); + std::vector pass_data; + std::vector feasign_to_box; + input_channel_->ReadAll(pass_data); + for (const auto& ins : pass_data) { + const auto& feasign_v = ins.uint64_feasigns_; + for (const auto feasign : feasign_v) { + feasign_to_box.push_back(feasign.sign().uint64_feasign_); + } + } + input_channel_->Open(); + input_channel_->Write(pass_data); + input_channel_->Close(); + box_ptr->FeedPass(feasign_to_box); + } +}; + +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8532649614c867a860774378e4ffd9b251dd76d5 --- /dev/null +++ b/paddle/fluid/operators/pull_box_sparse_op.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/pull_box_sparse_op.h" + +namespace paddle { +namespace operators { + +class PullBoxSparseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL, + "Inputs(Ids) of PullBoxSparseOp should not be empty."); + PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, + "Outputs(Out) of PullBoxSparseOp should not be empty."); + auto hidden_size = static_cast(ctx->Attrs().Get("size")); + auto all_ids_dim = ctx->GetInputsDim("Ids"); + const size_t n_ids = all_ids_dim.size(); + std::vector outs_dims; + outs_dims.resize(n_ids); + for (size_t i = 0; i < n_ids; ++i) { + const auto ids_dims = all_ids_dim[i]; + int ids_rank = ids_dims.size(); + PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1, + "Shape error in %lu id, the last dimension of the " + "'Ids' tensor must be 1.", + i); + auto out_dim = framework::vectorize( + framework::slice_ddim(ids_dims, 0, ids_rank - 1)); + out_dim.push_back(hidden_size); + outs_dims[i] = framework::make_ddim(out_dim); + } + ctx->SetOutputsDim("Out", outs_dims); + for (size_t i = 0; i < n_ids; ++i) { + ctx->ShareLoD("Ids", "Out", i, i); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.device_context()); + } +}; + +class PullBoxSparseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Ids", + "Input tensors with type int32 or int64 " + "contains the ids to be looked up in BoxPS. " + "The last dimension size must be 1.") + .AsDuplicable(); + AddOutput("Out", "The lookup results tensors.").AsDuplicable(); + AddAttr("size", "(int, the embedding hidden size").SetDefault(1); + AddComment(R"DOC( +Pull Box Sparse Operator. + +This operator is used to perform lookups on the BoxPS, +then concatenated into a dense tensor. + +The input Ids can carry the LoD (Level of Details) information, +or not. And the output only shares the LoD information with input Ids. + +)DOC"); + } +}; + +class PushBoxSparseOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("push_box_sparse"); + op->SetInput("Ids", Input("Ids")); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +class PushBoxSparseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.MultiInput(framework::GradVarName("Out"))[0] + ->type(), + ctx.device_context()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(pull_box_sparse, ops::PullBoxSparseOp, + ops::PullBoxSparseOpMaker, ops::PushBoxSparseOpDescMaker); +REGISTER_OPERATOR(push_box_sparse, ops::PushBoxSparseOp); +REGISTER_OP_CPU_KERNEL(pull_box_sparse, ops::PullBoxSparseCPUKernel) +REGISTER_OP_CPU_KERNEL(push_box_sparse, ops::PushBoxSparseCPUKernel) diff --git a/paddle/fluid/operators/pull_box_sparse_op.cu b/paddle/fluid/operators/pull_box_sparse_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..8bba9db5426b7055dce03ee2f5e87c11a38aef1b --- /dev/null +++ b/paddle/fluid/operators/pull_box_sparse_op.cu @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/pull_box_sparse_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; +using LoDTensor = framework::LoDTensor; + +template +class PullBoxSparseCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PullBoxSparseFunctor(ctx); + } +}; + +template +class PushBoxSparseCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PushBoxSparseFunctor(ctx); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(pull_box_sparse, ops::PullBoxSparseCUDAKernel) +REGISTER_OP_CUDA_KERNEL(push_box_sparse, ops::PushBoxSparseCUDAKernel) diff --git a/paddle/fluid/operators/pull_box_sparse_op.h b/paddle/fluid/operators/pull_box_sparse_op.h new file mode 100644 index 0000000000000000000000000000000000000000..48a9e4d9313640b90d1ba7278703a217e31feb46 --- /dev/null +++ b/paddle/fluid/operators/pull_box_sparse_op.h @@ -0,0 +1,90 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/framework/fleet/box_wrapper.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { + +template +static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) { + auto inputs = ctx.MultiInput("Ids"); + auto outputs = ctx.MultiOutput("Out"); + auto hidden_size = ctx.Attr("size"); + const auto slot_size = inputs.size(); + std::vector all_keys(slot_size); + // BoxPS only supports float now + std::vector all_values(slot_size); + std::vector slot_lengths(slot_size); + for (size_t i = 0; i < slot_size; i++) { + const auto *slot = inputs[i]; + const uint64_t *single_slot_keys = + reinterpret_cast(slot->data()); + all_keys[i] = single_slot_keys; + slot_lengths[i] = slot->numel(); + auto *output = outputs[i]->mutable_data(ctx.GetPlace()); + all_values[i] = output; + } + auto box_ptr = paddle::framework::BoxWrapper::GetInstance(); + box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths, + hidden_size); +} + +template +static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) { + auto inputs = ctx.MultiInput("Ids"); + auto d_output = + ctx.MultiInput(framework::GradVarName("Out")); + auto hidden_size = ctx.Attr("size"); + const auto slot_size = inputs.size(); + std::vector all_keys(slot_size); + std::vector all_grad_values(slot_size); + std::vector slot_lengths(slot_size); + for (size_t i = 0; i < slot_size; i++) { + const auto *slot = inputs[i]; + const uint64_t *single_slot_keys = + reinterpret_cast(slot->data()); + all_keys[i] = single_slot_keys; + slot_lengths[i] = slot->numel(); + const float *grad_value = d_output[i]->data(); + all_grad_values[i] = grad_value; + } + auto box_ptr = paddle::framework::BoxWrapper::GetInstance(); + box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values, + slot_lengths, hidden_size); +} + +using LoDTensor = framework::LoDTensor; +template +class PullBoxSparseCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PullBoxSparseFunctor(ctx); + } +}; + +template +class PushBoxSparseCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PushBoxSparseFunctor(ctx); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index ff35ca6ca4ee2c958fe5f3250763c68ec1fe925d..b721ebe81719bfb833af56038065f91ce5fb795f 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,4 +1,4 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper nccl_wrapper prune +set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper nccl_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer scope_pool tracer analysis_predictor imperative_profiler nccl_context) @@ -17,6 +17,7 @@ set(PYBIND_SRCS const_value.cc reader_py.cc fleet_wrapper_py.cc + box_helper_py.cc nccl_wrapper_py.cc data_set_py.cc imperative.cc diff --git a/paddle/fluid/pybind/box_helper_py.cc b/paddle/fluid/pybind/box_helper_py.cc new file mode 100644 index 0000000000000000000000000000000000000000..13aec9aa9234c9109299136dba79c9e66ce535b0 --- /dev/null +++ b/paddle/fluid/pybind/box_helper_py.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include + +#ifdef _POSIX_C_SOURCE +#undef _POSIX_C_SOURCE +#endif + +#ifdef _XOPEN_SOURCE +#undef _XOPEN_SOURCE +#endif + +#include +#include +#include + +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/data_feed.pb.h" +#include "paddle/fluid/framework/fleet/box_wrapper.h" +#include "paddle/fluid/pybind/box_helper_py.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { +void BindBoxHelper(py::module* m) { + py::class_>( + *m, "BoxPS") + .def(py::init([](paddle::framework::Dataset* dataset) { + return std::make_shared(dataset); + })) + .def("begin_pass", &framework::BoxHelper::BeginPass) + .def("end_pass", &framework::BoxHelper::EndPass) + .def("wait_feed_pass_done", &framework::BoxHelper::WaitFeedPassDone) + .def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory) + .def("load_into_memory", &framework::BoxHelper::LoadIntoMemory); +} // end BoxHelper +} // end namespace pybind +} // end namespace paddle diff --git a/paddle/fluid/pybind/box_helper_py.h b/paddle/fluid/pybind/box_helper_py.h new file mode 100644 index 0000000000000000000000000000000000000000..33072dd5a3a38b0a306056a7bd4b8aa5cf36b1df --- /dev/null +++ b/paddle/fluid/pybind/box_helper_py.h @@ -0,0 +1,28 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +void BindBoxHelper(py::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3a19effff23eef71361918fe285cd78c77a064a6..2b6ea4575aeb4cea6cce92c4fbbf89cec7865e5e 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -51,6 +51,7 @@ limitations under the License. */ #include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/pybind/box_helper_py.h" #include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/exception.h" @@ -1691,6 +1692,7 @@ All parameter, weight, gradient are variables in Paddle. }); BindFleetWrapper(&m); + BindBoxHelper(&m); #ifndef _WIN32 BindNCCLWrapper(&m); #endif diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 499dcdf359ebc6cbfa3bde3669fc80f8d9a5dd61..9e143954049dc94ccccf4b1c2476b37891d32b3c 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -720,3 +720,54 @@ class FileInstantDataset(DatasetBase): raise NotImplementedError( "FileInstantDataset does not support global shuffle, " "please use InMemoryDataset for global_shuffle") + + +class BoxPSDataset(InMemoryDataset): + """ + BoxPSDataset: derived from InMemoryDataset. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory.create_dataset("BoxPSDataset") + """ + + def __init__(self): + """ + Init + """ + super(BoxPSDataset, self).__init__() + self.boxps = core.BoxPS(self.dataset) + + def begin_pass(self): + """ + Notify BoxPS to begin next pass + """ + self.boxps.begin_pass() + + def end_pass(self): + """ + Notify BoxPS to end current pass + """ + self.boxps.end_pass() + + def wait_preload_done(self): + """ + Wait async proload done + """ + self.boxps.wait_feed_pass_done() + + def load_into_memory(self): + """ + Load next pass into memory and notify boxps to fetch its emb from SSD + """ + self._prepare_to_run() + self.boxps.load_into_memory() + + def preload_into_memory(self): + """ + begin async preload next pass while current pass may be training + """ + self._prepare_to_run() + self.boxps.preload_into_memory() diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 87f8454c62d02f66fbe6a55d3dafcad9f310ff5c..b671f63e86a109912627ad8160d62493aa2ed0b7 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -516,6 +516,54 @@ def embedding(input, return tmp +def _pull_box_sparse(input, size, dtype='float32'): + """ + **Pull Box Sparse Layer** + + This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in + BoxPS lookup table. The result of this lookup is the embedding of each ID in the + :attr:`input`. + + Args: + input(Variable|list of Variable): Input is a Tensor Variable, which + contains the IDs information. + size(int): The embedding size parameter, which indicates the size of + each embedding vector respectively. + dtype(str): The dtype refers to the data type of output tensor. Only supports + float32 now. + + Returns: + Variable|list of Variable: The tensor variable storing the embeddings of the \ + supplied inputs. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + data = fluid.layers.data(name='sequence', shape=[1], dtype='int64', lod_level=1) + emb = fluid.layers.pull_box_sparse(input=data, size=[11]) + """ + helper = LayerHelper('pull_box_sparse', **locals()) + if dtype != 'float32': + raise ValueError( + "BoxPS only support float type embedding now, and your type is: " + + dtype) + helper.input_dtype() + inputs = helper.multiple_input() + outs = [ + helper.create_variable_for_type_inference(dtype) + for i in range(len(inputs)) + ] + helper.append_op( + type='pull_box_sparse', + inputs={'Ids': inputs}, + outputs={'Out': outs}, + attrs={'size': size}) + if len(outs) == 1: + return outs[0] + return outs + + @templatedoc(op_type="lstm") def dynamic_lstm(input, size, diff --git a/python/paddle/fluid/tests/unittests/test_boxps.py b/python/paddle/fluid/tests/unittests/test_boxps.py new file mode 100644 index 0000000000000000000000000000000000000000..6a068d00776fc856fa5dcbf63dfe519bcf1acf90 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_boxps.py @@ -0,0 +1,103 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import numpy as np +import os +import paddle.fluid.core as core +import unittest +from paddle.fluid.layers.nn import _pull_box_sparse + + +class TestBoxPSPreload(unittest.TestCase): + """ TestCases for BoxPS Preload """ + + def test_boxps_cpu(self): + self.run_boxps_preload(True) + + def test_boxps_gpu(self): + self.run_boxps_preload(False) + + def run_boxps_preload(self, is_cpu=True): + x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0) + y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0) + emb_x, emb_y = _pull_box_sparse([x, y], size=2) + emb_xp = _pull_box_sparse(x, size=2) + layers.Print(emb_xp) + concat = layers.concat([emb_x, emb_y], axis=1) + fc = layers.fc(input=concat, + name="fc", + size=1, + num_flatten_dims=1, + bias_attr=False) + loss = layers.reduce_mean(fc) + layers.Print(loss) + place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda( + ) else fluid.CUDAPlace(0) + exe = fluid.Executor(place) + optimizer = fluid.optimizer.SGD(learning_rate=0.5) + batch_size = 2 + + def binary_print(slot, fout): + fout.write(str(len(slot)) + " ") + for e in slot: + fout.write(str(e) + " ") + + batch1 = np.ones( + (batch_size, 2, 1)).astype("int64").reshape(batch_size, 2, 1) + filelist = [] + for i in range(2): + filelist.append("test_hdfs_" + str(i)) + for f in filelist: + with open(f, "w") as fout: + for ins in batch1: + for slot in ins: + binary_print(slot, fout) + fout.write("\n") + + def create_dataset(): + dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset") + dataset.set_use_var([x, y]) + dataset.set_batch_size(2) + dataset.set_thread(1) + dataset.set_filelist(filelist) + return dataset + + datasets = [] + datasets.append(create_dataset()) + datasets.append(create_dataset()) + optimizer.minimize(loss) + exe.run(fluid.default_startup_program()) + datasets[0].load_into_memory() + datasets[0].begin_pass() + datasets[1].preload_into_memory() + exe.train_from_dataset( + program=fluid.default_main_program(), + dataset=datasets[0], + print_period=1) + datasets[0].end_pass() + datasets[1].wait_preload_done() + datasets[1].begin_pass() + exe.train_from_dataset( + program=fluid.default_main_program(), + dataset=datasets[1], + print_period=1) + datasets[1].end_pass() + for f in filelist: + os.remove(f) + + +if __name__ == '__main__': + unittest.main()