Commit 3e45a5a5 authored by Qiao Longfei

lookup_table gpu kernel support prefetch

test=develop
Parent 3a3cfc2d
@@ -23,7 +23,7 @@ if(WITH_GRPC)
   cc_test(rpc_server_test SRCS rpc_server_test.cc
     DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
   cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
-  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc)
+  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
 else()
   set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
     brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
@@ -33,7 +33,7 @@ else()
     PROTO send_recv.proto
     DEPS lod_tensor selected_rows memory)
-  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc)
+  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
   set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
...
@@ -102,7 +102,8 @@ static void MergeMultipleVarsIntoOneBySection(
     const std::string& out_name, const std::vector<std::string>& out_var_names,
     const std::vector<int>& height_section,
     const std::vector<std::vector<int64_t>>& splited_ids,
-    const framework::ExecutionContext& context, framework::Scope* scope) {
+    const framework::ExecutionContext& context, framework::Scope* scope,
+    platform::DeviceContext* actual_ctx) {
   PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), "");
   auto cpu_place = platform::CPUPlace();
@@ -151,10 +152,12 @@ static void MergeMultipleVarsIntoOneBySection(
 #ifndef PADDLE_WITH_CUDA
         PADDLE_THROW("paddle is not compiled with CUDA!");
 #else
+        auto stream =
+            static_cast<platform::CUDADeviceContext*>(actual_ctx)->stream();
         memory::Copy(boost::get<platform::CUDAPlace>(id_tensor.place()),
                      out_tensor_data + offset * row_numel, cpu_place,
                      out_var_data + i * row_numel,
-                     sizeof(float) * row_numel);
+                     sizeof(float) * row_numel, stream);
 #endif
       }
     }
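The hunk above threads the op's own CUDA stream into memory::Copy, so the host-to-device copy of the fetched rows is enqueued on the same stream as the kernels that will consume them, ordering the two without a blocking synchronization. A minimal standalone sketch of that pattern against the raw CUDA runtime API (hypothetical buffer names and sizes, error handling reduced to a macro; memory::Copy with a stream wraps the same idea):

// Sketch only: stream-ordered host-to-device copy with the plain CUDA runtime.
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

#define CUDA_CHECK(expr)                                                    \
  do {                                                                      \
    cudaError_t err_ = (expr);                                              \
    if (err_ != cudaSuccess)                                                \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_));   \
  } while (0)

int main() {
  const size_t row_numel = 64;  // hypothetical embedding width
  std::vector<float> host_row(row_numel, 1.0f);
  float* dev_row = nullptr;
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));
  CUDA_CHECK(cudaMalloc(&dev_row, sizeof(float) * row_numel));

  // Enqueue the copy on `stream`: it is ordered before any kernel launched
  // later on the same stream, and the call returns to the host immediately.
  // (True async H2D also needs pinned host memory, e.g. cudaHostAlloc;
  // from pageable memory the runtime falls back to a staged copy.)
  CUDA_CHECK(cudaMemcpyAsync(dev_row, host_row.data(),
                             sizeof(float) * row_numel,
                             cudaMemcpyHostToDevice, stream));

  // Only needed here because we tear down right away; a consumer kernel on
  // the same stream would not require an explicit synchronization.
  CUDA_CHECK(cudaStreamSynchronize(stream));
  CUDA_CHECK(cudaFree(dev_row));
  CUDA_CHECK(cudaStreamDestroy(stream));
  return 0;
}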
@@ -174,6 +177,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto& cpu_ctx = *pool.Get(platform::CPUPlace());
+  auto& actual_ctx = *pool.Get(context.GetPlace());
   distributed::RPCClient* rpc_client =
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(
@@ -201,11 +205,13 @@ void prefetch(const std::string& id_name, const std::string& out_name,
     framework::Tensor cpu_tensor;
     auto* cpu_tensor_data =
         cpu_tensor.mutable_data<int64_t>(id_tensor.dims(), cpu_place);
+    auto stream =
+        static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
     memory::Copy(cpu_place, cpu_tensor_data,
                  boost::get<platform::CUDAPlace>(id_tensor.place()),
-                 id_tensor.data<int64_t>(),
-                 sizeof(int64_t) * id_tensor.numel());
-    for (size_t i = 0; i < id_tensor.numel(); ++i) {
+                 id_tensor.data<int64_t>(), sizeof(int64_t) * id_tensor.numel(),
+                 stream);
+    for (size_t i = 0; i < cpu_tensor.numel(); ++i) {
       ids_vector.push_back(cpu_tensor_data[i]);
     }
 #endif
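One subtlety in the hunk above: the device-to-host copy of the ids is now enqueued on the compute stream, and the loop that fills ids_vector reads cpu_tensor_data on the host right afterwards, so correctness depends on that copy having completed. A standalone sketch of the staging step with the synchronization made explicit (hypothetical helper, not Paddle API):

// Hypothetical helper: stage device-resident int64 ids into host memory
// before handing them to CPU-side RPC code, as prefetch() does.
#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

std::vector<int64_t> StageIdsToHost(const int64_t* dev_ids, size_t numel,
                                    cudaStream_t stream) {
  std::vector<int64_t> host_ids(numel);
  // Device-to-host copy, ordered after earlier work on `stream`.
  cudaMemcpyAsync(host_ids.data(), dev_ids, sizeof(int64_t) * numel,
                  cudaMemcpyDeviceToHost, stream);
  // Without this, the caller could read host_ids before the copy lands;
  // D2H into pageable memory happens to block the host, but relying on
  // that behavior is fragile.
  cudaStreamSynchronize(stream);
  return host_ids;
}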
@@ -239,7 +245,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
   MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
                                     out_var_names, height_sections, splited_ids,
-                                    context, &local_scope);
+                                    context, &local_scope, &actual_ctx);
   context.scope().DeleteScope(&local_scope);
 }
...
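Taken together, the parameter_prefetch.cc hunks make prefetch() device-aware: ids living on the GPU are staged into a host buffer, split per endpoint, fetched over RPC, and the returned rows are copied back onto the op's device on its own CUDA stream via the new actual_ctx argument, rather than through the old stream-less (blocking) copy path.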
@@ -78,6 +78,27 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     auto *output_t = context.Output<LoDTensor>("Out");
     int64_t padding_idx = context.Attr<int64_t>("padding_idx");
+    auto id_name = context.Inputs("Ids").front();
+    auto out_name = context.Outputs("Out").front();
+
+    // for remote prefetch
+    auto epmap = context.Attr<std::vector<std::string>>("epmap");
+    auto height_sections = context.Attr<std::vector<int>>("height_sections");
+    auto table_names = context.Attr<std::vector<std::string>>("table_names");
+
+    if (!epmap.empty()) {
+      // if epmap is not empty, then the parameter will be fetched from remote
+      // parameter server
+#ifdef PADDLE_WITH_DISTRIBUTE
+      operators::distributed::prefetch(id_name, out_name, table_names, epmap,
+                                       height_sections, context);
+#else
+      PADDLE_THROW(
+          "paddle is not compiled with distribute support, can not do "
+          "parameter prefetch!");
+#endif
+    } else {
-    size_t N = table_t->dims()[0];
-    size_t D = table_t->dims()[1];
-    size_t K = ids_t->numel();
+      size_t N = table_t->dims()[0];
+      size_t D = table_t->dims()[1];
+      size_t K = ids_t->numel();
@@ -90,16 +111,15 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
       dim3 grids(8, 1);
-    if (padding_idx == -1)
-      LookupTable<
-          T, 128, 8, 8,
-          false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-          output, table, ids, N, K, D, padding_idx);
-    else
-      LookupTable<
-          T, 128, 8, 8,
-          true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-          output, table, ids, N, K, D, padding_idx);
+      if (padding_idx == -1)
+        LookupTable<T, 128, 8, 8, false><<<
+            grids, threads, 0, context.cuda_device_context().stream()>>>(
+            output, table, ids, N, K, D, padding_idx);
+      else
+        LookupTable<T, 128, 8, 8, true><<<
+            grids, threads, 0, context.cuda_device_context().stream()>>>(
+            output, table, ids, N, K, D, padding_idx);
+    }
   }
 };
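The new branch makes the GPU kernel mirror the CPU one: a non-empty epmap means the embedding table is sharded row-wise across parameter servers, so the op prefetches the needed rows instead of launching the local LookupTable kernel. A conceptual sketch (my reading of the height_sections attribute, not code from this repo) of how a global row id maps to a shard and a local row:

// Illustrative only: resolve a global row id to (shard index, local row)
// given the per-server row counts in height_sections.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

std::pair<size_t, int64_t> LocateRow(int64_t id,
                                     const std::vector<int>& height_sections) {
  assert(!height_sections.empty());
  int64_t offset = id;
  for (size_t shard = 0; shard < height_sections.size(); ++shard) {
    if (offset < height_sections[shard]) {
      return {shard, offset};  // row `offset` of the table on server `shard`
    }
    offset -= height_sections[shard];
  }
  // id beyond the summed sections; a real implementation would reject this.
  return {height_sections.size() - 1, offset};
}

For example, with height_sections = {5, 5}, id 7 would resolve to shard 1, row 2; splitting ids_vector into splited_ids per endpoint in prefetch() follows the same row partitioning.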
 template <typename T>
@@ -109,6 +129,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     auto &dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     bool is_sparse = context.Attr<bool>("is_sparse");
+
     // Since paddings are not trainable and fixed in forward, the gradient of
     // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
...