未验证 提交 c756b5d2 编写于 作者: H hutuxian 提交者: GitHub

Paddlebox Framework (#18982)

* Support looking up embeddings from BoxPS.
* Add a _pull_box_sparse op, for now this op is not exposed to users.
* Add a BoxHelper class, providing 'BeginPass', 'EndPass', 'FeedPass' functions and so on.
* Add 'BoxPSDataset' in python code.
* Add a compile options WITH_BOX_PS and a MACRO PADDLE_WITH_BOX_PS.
* Add UT.
* More concrete information pls refer to: https://github.com/PaddlePaddle/Paddle/pull/18982
上级 5dce1da6
......@@ -74,6 +74,7 @@ option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools"
option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF)
option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
option(WITH_PSLIB "Compile with pslib support" OFF)
option(WITH_BOX_PS "Compile with box_ps support" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
......@@ -170,6 +171,9 @@ if(WITH_PSLIB)
include(external/pslib_brpc)
include(external/pslib)
endif(WITH_PSLIB)
if(WITH_BOX_PS)
include(external/box_ps)
endif(WITH_BOX_PS)
if(WITH_DISTRIBUTE)
if(WITH_GRPC)
......
......@@ -62,6 +62,10 @@ if(WITH_PSLIB)
add_definitions(-DPADDLE_WITH_PSLIB)
endif()
if(WITH_BOX_PS)
add_definitions(-DPADDLE_WITH_BOX_PS)
endif()
if(WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA)
add_definitions(-DEIGEN_USE_GPU)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_BOX_PS})
return()
ENDIF(NOT ${WITH_BOX_PS})
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with BOX_PS in Paddle yet."
"Force WITH_BOX_PS=OFF")
SET(WITH_BOX_PS OFF CACHE STRING "Disable BOX_PS package in Windows and MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject)
SET(BOX_PS_PROJECT "extern_box_ps")
IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL))
MESSAGE(STATUS "use pre defined download url")
SET(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE)
SET(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE)
SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps_stub.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}")
SET(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps")
SET(BOX_PS_DOWNLOAD_DIR "${BOX_PS_SOURCE_DIR}/src/${BOX_PS_PROJECT}")
SET(BOX_PS_DST_DIR "box_ps")
SET(BOX_PS_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(BOX_PS_INSTALL_DIR ${BOX_PS_INSTALL_ROOT}/${BOX_PS_DST_DIR})
SET(BOX_PS_ROOT ${BOX_PS_INSTALL_DIR})
SET(BOX_PS_INC_DIR ${BOX_PS_ROOT}/include)
SET(BOX_PS_LIB_DIR ${BOX_PS_ROOT}/lib)
SET(BOX_PS_LIB ${BOX_PS_LIB_DIR}/libbox_ps.so)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${BOX_PS_ROOT}/lib")
INCLUDE_DIRECTORIES(${BOX_PS_INC_DIR})
FILE(WRITE ${BOX_PS_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(BOX_PS)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${BOX_PS_NAME}/include ${BOX_PS_NAME}/lib \n"
" DESTINATION ${BOX_PS_DST_DIR})\n")
ExternalProject_Add(
${BOX_PS_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${BOX_PS_SOURCE_DIR}
DOWNLOAD_DIR ${BOX_PS_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${BOX_PS_URL} -c -q -O ${BOX_PS_NAME}.tar.gz
&& tar zxvf ${BOX_PS_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${BOX_PS_INSTALL_ROOT}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${BOX_PS_INSTALL_ROOT}
)
ADD_LIBRARY(box_ps SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET box_ps PROPERTY IMPORTED_LOCATION ${BOX_PS_LIB})
ADD_DEPENDENCIES(box_ps ${BOX_PS_PROJECT})
......@@ -123,8 +123,8 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope
glog box_wrapper shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
......@@ -179,7 +179,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......@@ -190,7 +190,7 @@ else()
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
......
......@@ -146,6 +146,7 @@ class DatasetImpl : public Dataset {
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual Channel<T> GetInputChannel() { return input_channel_; }
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
......
......@@ -5,3 +5,8 @@ else()
endif(WITH_PSLIB)
cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope)
if(WITH_BOX_PS)
cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor box_ps)
else()
cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor)
endif(WITH_BOX_PS)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace framework {
std::shared_ptr<BoxWrapper> BoxWrapper::s_instance_ = nullptr;
#ifdef PADDLE_WITH_BOX_PS
std::shared_ptr<paddle::boxps::BoxPSBase> BoxWrapper::boxps_ptr_ = nullptr;
#endif
int BoxWrapper::GetDate() const {
time_t now = time(0);
tm t;
#ifdef _WIN32
localtime_s(&t, &now);
#else
localtime_r(&now, &t);
#endif
char buf[10];
snprintf(buf, sizeof(buf), "%04d%02d%02d", (1900 + t.tm_year), (1 + t.tm_mon),
t.tm_mday);
return atoi(buf);
}
void BoxWrapper::FeedPass(const std::vector<uint64_t>& feasgin_to_box) const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->FeedPass(GetDate(), feasgin_to_box);
PADDLE_ENFORCE_EQ(ret, 0, "FeedPass failed in BoxPS.");
#endif
}
void BoxWrapper::BeginPass() const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->BeginPass();
PADDLE_ENFORCE_EQ(ret, 0, "BeginPass failed in BoxPS.");
#endif
}
void BoxWrapper::EndPass() const {
#ifdef PADDLE_WITH_BOX_PS
int ret = boxps_ptr_->EndPass();
PADDLE_ENFORCE_EQ(ret, 0, "EndPass failed in BoxPS.");
#endif
}
void BoxWrapper::PullSparse(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size) {
#ifdef PADDLE_WITH_BOX_PS
if (platform::is_cpu_place(place) || platform::is_gpu_place(place)) {
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
LoDTensor total_keys_tensor;
int64_t* total_keys =
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place);
int64_t offset = 0;
for (size_t i = 0; i < keys.size(); ++i) {
if (platform::is_cpu_place(place)) {
memory::Copy(boost::get<platform::CPUPlace>(place), total_keys + offset,
boost::get<platform::CPUPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t));
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
memory::Copy(boost::get<platform::CUDAPlace>(place),
total_keys + offset,
boost::get<platform::CUDAPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t), nullptr);
#else
PADDLE_THROW(
"Please compile WITH_GPU option, and NCCL doesn't support "
"windows.");
#endif
}
offset += slot_lengths[i];
}
PADDLE_ENFORCE_EQ(offset, total_length,
"BoxWrapper::PullSparse: total feasign keys length "
"should be equal to the sum of length of all input "
"tensors.");
// Space allocation for FeatureValue is left for boxps
paddle::boxps::FeatureValue* total_values;
if (platform::is_cpu_place(place)) {
int ret = boxps_ptr_->PullSparseCPU(
reinterpret_cast<uint64_t*>(total_keys), &total_values,
static_cast<int>(total_length));
PADDLE_ENFORCE_EQ(ret, 0, "PullSparseCPU failed in BoxPS.");
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int ret = boxps_ptr_->PullSparseGPU(
reinterpret_cast<uint64_t*>(total_keys), &total_values,
static_cast<int>(total_length),
boost::get<platform::CUDAPlace>(place).GetDeviceId());
PADDLE_ENFORCE_EQ(ret, 0, "PullSparseGPU failed in BoxPS.");
#endif
}
offset = 0;
for (size_t i = 0; i < values.size(); ++i) {
int64_t fea_num = slot_lengths[i];
for (auto j = 0; j < fea_num; ++j) {
// Copy the emb from BoxPS to paddle tensor. Since 'show','click','emb'
// are continuous in memory, so we copy here using the 'show' address
if (platform::is_cpu_place(place)) {
memory::Copy(
boost::get<platform::CPUPlace>(place),
values[i] + j * hidden_size,
boost::get<platform::CPUPlace>(place),
reinterpret_cast<float*>(&((total_values + offset)->show)),
sizeof(float) * hidden_size);
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
memory::Copy(
boost::get<platform::CUDAPlace>(place),
values[i] + j * hidden_size,
boost::get<platform::CUDAPlace>(place),
reinterpret_cast<float*>(&((total_values + offset)->show)),
sizeof(float) * hidden_size, nullptr);
#endif
}
++offset;
}
}
PADDLE_ENFORCE_EQ(offset, total_length,
"BoxWrapper::PullSparse: total emb values length should "
"be equal to the sum of length of all input tensors.");
} else {
PADDLE_THROW(
"PaddleBox: PullSparse Only Support CPUPlace and CUDAPlace Now.");
}
#endif
}
void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size) {
#ifdef PADDLE_WITH_BOX_PS
if (platform::is_cpu_place(place) || platform::is_gpu_place(place)) {
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
LoDTensor total_keys_tensor;
int64_t* total_keys =
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place);
int64_t offset = 0;
for (size_t i = 0; i < keys.size(); ++i) {
if (platform::is_cpu_place(place)) {
memory::Copy(boost::get<platform::CPUPlace>(place), total_keys + offset,
boost::get<platform::CPUPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t));
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
memory::Copy(boost::get<platform::CUDAPlace>(place),
total_keys + offset,
boost::get<platform::CUDAPlace>(place), keys[i],
slot_lengths[i] * sizeof(uint64_t), nullptr);
#else
PADDLE_THROW(
"Please compile WITH_GPU option, and for now NCCL doesn't support "
"windows.");
#endif
}
offset += slot_lengths[i];
}
PADDLE_ENFORCE_EQ(offset, total_length,
"BoxWrapper::PushSparseGrad: total feasign keys length "
"should be equal to the sum of length of all input "
"tensors.");
auto buf = memory::AllocShared(
place, total_length * sizeof(paddle::boxps::FeaturePushValue));
paddle::boxps::FeaturePushValue* total_grad_values =
reinterpret_cast<paddle::boxps::FeaturePushValue*>(buf->ptr());
offset = 0;
for (size_t i = 0; i < grad_values.size(); ++i) {
int64_t fea_num = slot_lengths[i];
for (auto j = 0; j < fea_num; ++j) {
// Copy the emb grad from paddle tensor to BoxPS. Since
// 'show','click','emb' are continuous in memory, so we copy here using
// the 'show' address
if (platform::is_cpu_place(place)) {
memory::Copy(
boost::get<platform::CPUPlace>(place),
reinterpret_cast<float*>(&((total_grad_values + offset)->show)),
boost::get<platform::CPUPlace>(place),
grad_values[i] + j * hidden_size, sizeof(float) * hidden_size);
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
memory::Copy(
boost::get<platform::CUDAPlace>(place),
reinterpret_cast<float*>(&((total_grad_values + offset)->show)),
boost::get<platform::CUDAPlace>(place),
grad_values[i] + j * hidden_size, sizeof(float) * hidden_size,
nullptr);
#endif
}
++offset;
}
}
PADDLE_ENFORCE_EQ(offset, total_length,
"BoxWrapper::PushSparseGrad: total emb grad values "
"length should be equal to the sum of length of all "
"input tensors.");
if (platform::is_cpu_place(place)) {
int ret = boxps_ptr_->PushSparseCPU(
reinterpret_cast<uint64_t*>(total_keys), total_grad_values,
static_cast<int>(total_length));
PADDLE_ENFORCE_EQ(ret, 0, "PushSparseCPU failed in BoxPS.");
} else {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int ret = boxps_ptr_->PushSparseGPU(
reinterpret_cast<uint64_t*>(total_keys), total_grad_values,
static_cast<int>(total_length),
boost::get<platform::CUDAPlace>(place).GetDeviceId());
PADDLE_ENFORCE_EQ(ret, 0, "PushSparseGPU failed in BoxPS.");
#endif
}
} else {
PADDLE_THROW(
"PaddleBox: PushSparse Only Support CPUPlace and CUDAPlace Now.");
}
#endif
}
} // end namespace framework
} // end namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#ifdef PADDLE_WITH_BOX_PS
#include <boxps.h>
#endif
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
class BoxWrapper {
public:
virtual ~BoxWrapper() {}
BoxWrapper() {}
void FeedPass(const std::vector<uint64_t>& feasgin_to_box) const;
void BeginPass() const;
void EndPass() const;
void PullSparse(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size);
void PushSparseGrad(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size);
static std::shared_ptr<BoxWrapper> GetInstance() {
if (nullptr == s_instance_) {
// If main thread is guaranteed to init this, this lock can be removed
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (nullptr == s_instance_) {
s_instance_.reset(new paddle::framework::BoxWrapper());
#ifdef PADDLE_WITH_BOX_PS
s_instance_->boxps_ptr_.reset(new paddle::boxps::FakeBoxPS());
#endif
}
}
return s_instance_;
}
private:
#ifdef PADDLE_WITH_BOX_PS
static std::shared_ptr<paddle::boxps::BoxPSBase> boxps_ptr_;
#endif
static std::shared_ptr<BoxWrapper> s_instance_;
int GetDate() const;
};
class BoxHelper {
public:
explicit BoxHelper(paddle::framework::Dataset* dataset) : dataset_(dataset) {}
virtual ~BoxHelper() {}
void BeginPass() {
auto box_ptr = BoxWrapper::GetInstance();
box_ptr->BeginPass();
}
void EndPass() {
auto box_ptr = BoxWrapper::GetInstance();
box_ptr->EndPass();
}
void LoadIntoMemory() {
dataset_->LoadIntoMemory();
FeedPass();
}
void PreLoadIntoMemory() {
dataset_->PreLoadIntoMemory();
feed_data_thread_.reset(new std::thread([&]() {
dataset_->WaitPreLoadDone();
FeedPass();
}));
}
void WaitFeedPassDone() { feed_data_thread_->join(); }
private:
Dataset* dataset_;
std::shared_ptr<std::thread> feed_data_thread_;
// notify boxps to feed this pass feasigns from SSD to memory
void FeedPass() {
auto box_ptr = BoxWrapper::GetInstance();
auto input_channel_ =
dynamic_cast<MultiSlotDataset*>(dataset_)->GetInputChannel();
std::vector<Record> pass_data;
std::vector<uint64_t> feasign_to_box;
input_channel_->ReadAll(pass_data);
for (const auto& ins : pass_data) {
const auto& feasign_v = ins.uint64_feasigns_;
for (const auto feasign : feasign_v) {
feasign_to_box.push_back(feasign.sign().uint64_feasign_);
}
}
input_channel_->Open();
input_channel_->Write(pass_data);
input_channel_->Close();
box_ptr->FeedPass(feasign_to_box);
}
};
} // end namespace framework
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_box_sparse_op.h"
namespace paddle {
namespace operators {
class PullBoxSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL,
"Inputs(Ids) of PullBoxSparseOp should not be empty.");
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
"Outputs(Out) of PullBoxSparseOp should not be empty.");
auto hidden_size = static_cast<int64_t>(ctx->Attrs().Get<int>("size"));
auto all_ids_dim = ctx->GetInputsDim("Ids");
const size_t n_ids = all_ids_dim.size();
std::vector<framework::DDim> outs_dims;
outs_dims.resize(n_ids);
for (size_t i = 0; i < n_ids; ++i) {
const auto ids_dims = all_ids_dim[i];
int ids_rank = ids_dims.size();
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
"Shape error in %lu id, the last dimension of the "
"'Ids' tensor must be 1.",
i);
auto out_dim = framework::vectorize(
framework::slice_ddim(ids_dims, 0, ids_rank - 1));
out_dim.push_back(hidden_size);
outs_dims[i] = framework::make_ddim(out_dim);
}
ctx->SetOutputsDim("Out", outs_dims);
for (size_t i = 0; i < n_ids; ++i) {
ctx->ShareLoD("Ids", "Out", i, i);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
};
class PullBoxSparseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"Input tensors with type int32 or int64 "
"contains the ids to be looked up in BoxPS. "
"The last dimension size must be 1.")
.AsDuplicable();
AddOutput("Out", "The lookup results tensors.").AsDuplicable();
AddAttr<int>("size", "(int, the embedding hidden size").SetDefault(1);
AddComment(R"DOC(
Pull Box Sparse Operator.
This operator is used to perform lookups on the BoxPS,
then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
};
class PushBoxSparseOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("push_box_sparse");
op->SetInput("Ids", Input("Ids"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetAttrMap(Attrs());
return op;
}
};
class PushBoxSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.MultiInput<framework::Tensor>(framework::GradVarName("Out"))[0]
->type(),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(pull_box_sparse, ops::PullBoxSparseOp,
ops::PullBoxSparseOpMaker, ops::PushBoxSparseOpDescMaker);
REGISTER_OPERATOR(push_box_sparse, ops::PushBoxSparseOp);
REGISTER_OP_CPU_KERNEL(pull_box_sparse, ops::PullBoxSparseCPUKernel<float>)
REGISTER_OP_CPU_KERNEL(push_box_sparse, ops::PushBoxSparseCPUKernel<float>)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_box_sparse_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor;
template <typename T>
class PullBoxSparseCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PullBoxSparseFunctor<T>(ctx);
}
};
template <typename T>
class PushBoxSparseCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PushBoxSparseFunctor<T>(ctx);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(pull_box_sparse, ops::PullBoxSparseCUDAKernel<float>)
REGISTER_OP_CUDA_KERNEL(push_box_sparse, ops::PushBoxSparseCUDAKernel<float>)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
template <typename T>
static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
auto hidden_size = ctx.Attr<int>("size");
const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size);
// BoxPS only supports float now
std::vector<float *> all_values(slot_size);
std::vector<int64_t> slot_lengths(slot_size);
for (size_t i = 0; i < slot_size; i++) {
const auto *slot = inputs[i];
const uint64_t *single_slot_keys =
reinterpret_cast<const uint64_t *>(slot->data<int64_t>());
all_keys[i] = single_slot_keys;
slot_lengths[i] = slot->numel();
auto *output = outputs[i]->mutable_data<T>(ctx.GetPlace());
all_values[i] = output;
}
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths,
hidden_size);
}
template <typename T>
static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
auto d_output =
ctx.MultiInput<framework::Tensor>(framework::GradVarName("Out"));
auto hidden_size = ctx.Attr<int>("size");
const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size);
std::vector<const float *> all_grad_values(slot_size);
std::vector<int64_t> slot_lengths(slot_size);
for (size_t i = 0; i < slot_size; i++) {
const auto *slot = inputs[i];
const uint64_t *single_slot_keys =
reinterpret_cast<const uint64_t *>(slot->data<int64_t>());
all_keys[i] = single_slot_keys;
slot_lengths[i] = slot->numel();
const float *grad_value = d_output[i]->data<float>();
all_grad_values[i] = grad_value;
}
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values,
slot_lengths, hidden_size);
}
using LoDTensor = framework::LoDTensor;
template <typename T>
class PullBoxSparseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PullBoxSparseFunctor<T>(ctx);
}
};
template <typename T>
class PushBoxSparseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PushBoxSparseFunctor<T>(ctx);
}
};
} // namespace operators
} // namespace paddle
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper nccl_wrapper prune
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper nccl_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer scope_pool
tracer analysis_predictor imperative_profiler nccl_context)
......@@ -17,6 +17,7 @@ set(PYBIND_SRCS
const_value.cc
reader_py.cc
fleet_wrapper_py.cc
box_helper_py.cc
nccl_wrapper_py.cc
data_set_py.cc
imperative.cc
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/pybind/box_helper_py.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindBoxHelper(py::module* m) {
py::class_<framework::BoxHelper, std::shared_ptr<framework::BoxHelper>>(
*m, "BoxPS")
.def(py::init([](paddle::framework::Dataset* dataset) {
return std::make_shared<paddle::framework::BoxHelper>(dataset);
}))
.def("begin_pass", &framework::BoxHelper::BeginPass)
.def("end_pass", &framework::BoxHelper::EndPass)
.def("wait_feed_pass_done", &framework::BoxHelper::WaitFeedPassDone)
.def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory)
.def("load_into_memory", &framework::BoxHelper::LoadIntoMemory);
} // end BoxHelper
} // end namespace pybind
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindBoxHelper(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -51,6 +51,7 @@ limitations under the License. */
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
......@@ -1691,6 +1692,7 @@ All parameter, weight, gradient are variables in Paddle.
});
BindFleetWrapper(&m);
BindBoxHelper(&m);
#ifndef _WIN32
BindNCCLWrapper(&m);
#endif
......
......@@ -720,3 +720,54 @@ class FileInstantDataset(DatasetBase):
raise NotImplementedError(
"FileInstantDataset does not support global shuffle, "
"please use InMemoryDataset for global_shuffle")
class BoxPSDataset(InMemoryDataset):
"""
BoxPSDataset: derived from InMemoryDataset.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory.create_dataset("BoxPSDataset")
"""
def __init__(self):
"""
Init
"""
super(BoxPSDataset, self).__init__()
self.boxps = core.BoxPS(self.dataset)
def begin_pass(self):
"""
Notify BoxPS to begin next pass
"""
self.boxps.begin_pass()
def end_pass(self):
"""
Notify BoxPS to end current pass
"""
self.boxps.end_pass()
def wait_preload_done(self):
"""
Wait async proload done
"""
self.boxps.wait_feed_pass_done()
def load_into_memory(self):
"""
Load next pass into memory and notify boxps to fetch its emb from SSD
"""
self._prepare_to_run()
self.boxps.load_into_memory()
def preload_into_memory(self):
"""
begin async preload next pass while current pass may be training
"""
self._prepare_to_run()
self.boxps.preload_into_memory()
......@@ -516,6 +516,54 @@ def embedding(input,
return tmp
def _pull_box_sparse(input, size, dtype='float32'):
"""
**Pull Box Sparse Layer**
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
BoxPS lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
Args:
input(Variable|list of Variable): Input is a Tensor<int64> Variable, which
contains the IDs information.
size(int): The embedding size parameter, which indicates the size of
each embedding vector respectively.
dtype(str): The dtype refers to the data type of output tensor. Only supports
float32 now.
Returns:
Variable|list of Variable: The tensor variable storing the embeddings of the \
supplied inputs.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(name='sequence', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.pull_box_sparse(input=data, size=[11])
"""
helper = LayerHelper('pull_box_sparse', **locals())
if dtype != 'float32':
raise ValueError(
"BoxPS only support float type embedding now, and your type is: " +
dtype)
helper.input_dtype()
inputs = helper.multiple_input()
outs = [
helper.create_variable_for_type_inference(dtype)
for i in range(len(inputs))
]
helper.append_op(
type='pull_box_sparse',
inputs={'Ids': inputs},
outputs={'Out': outs},
attrs={'size': size})
if len(outs) == 1:
return outs[0]
return outs
@templatedoc(op_type="lstm")
def dynamic_lstm(input,
size,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import numpy as np
import os
import paddle.fluid.core as core
import unittest
from paddle.fluid.layers.nn import _pull_box_sparse
class TestBoxPSPreload(unittest.TestCase):
""" TestCases for BoxPS Preload """
def test_boxps_cpu(self):
self.run_boxps_preload(True)
def test_boxps_gpu(self):
self.run_boxps_preload(False)
def run_boxps_preload(self, is_cpu=True):
x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
emb_x, emb_y = _pull_box_sparse([x, y], size=2)
emb_xp = _pull_box_sparse(x, size=2)
layers.Print(emb_xp)
concat = layers.concat([emb_x, emb_y], axis=1)
fc = layers.fc(input=concat,
name="fc",
size=1,
num_flatten_dims=1,
bias_attr=False)
loss = layers.reduce_mean(fc)
layers.Print(loss)
place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
optimizer = fluid.optimizer.SGD(learning_rate=0.5)
batch_size = 2
def binary_print(slot, fout):
fout.write(str(len(slot)) + " ")
for e in slot:
fout.write(str(e) + " ")
batch1 = np.ones(
(batch_size, 2, 1)).astype("int64").reshape(batch_size, 2, 1)
filelist = []
for i in range(2):
filelist.append("test_hdfs_" + str(i))
for f in filelist:
with open(f, "w") as fout:
for ins in batch1:
for slot in ins:
binary_print(slot, fout)
fout.write("\n")
def create_dataset():
dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
dataset.set_use_var([x, y])
dataset.set_batch_size(2)
dataset.set_thread(1)
dataset.set_filelist(filelist)
return dataset
datasets = []
datasets.append(create_dataset())
datasets.append(create_dataset())
optimizer.minimize(loss)
exe.run(fluid.default_startup_program())
datasets[0].load_into_memory()
datasets[0].begin_pass()
datasets[1].preload_into_memory()
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=datasets[0],
print_period=1)
datasets[0].end_pass()
datasets[1].wait_preload_done()
datasets[1].begin_pass()
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=datasets[1],
print_period=1)
datasets[1].end_pass()
for f in filelist:
os.remove(f)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册