diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 0858ec6a226cc8760918e1e7427f75c7fa0b7660..36979de68f3abfdedfcc4a49cc312c1f849f5676 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -23,7 +23,7 @@ if(WITH_GRPC) cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL) cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler) - cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc) + cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory) else() set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) @@ -33,7 +33,7 @@ else() PROTO send_recv.proto DEPS lod_tensor selected_rows memory) - cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc) + cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory) set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy) diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index 36f4f0eefddc5eacf12d2e64d154370d61caf88b..cf14538b1c284d297242197088a66cc156b1762c 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -102,7 +102,8 @@ static void MergeMultipleVarsIntoOneBySection( const std::string& out_name, const std::vector& out_var_names, const std::vector& height_section, const std::vector>& splited_ids, - const framework::ExecutionContext& context, framework::Scope* scope) { + const framework::ExecutionContext& context, framework::Scope* scope, + platform::DeviceContext* actual_ctx) { PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), ""); auto cpu_place = platform::CPUPlace(); @@ -151,10 +152,12 @@ static void MergeMultipleVarsIntoOneBySection( #ifndef PADDLE_WITH_CUDA PADDLE_THROW("paddle is not compiled with CUDA!"); #else + auto stream = + static_cast(actual_ctx)->stream(); memory::Copy(boost::get(id_tensor.place()), out_tensor_data + offset * row_numel, cpu_place, out_var_data + i * row_numel, - sizeof(float) * row_numel); + sizeof(float) * row_numel, stream); #endif } } @@ -174,6 +177,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& cpu_ctx = *pool.Get(platform::CPUPlace()); + auto& actual_ctx = *pool.Get(context.GetPlace()); distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance( @@ -201,11 +205,13 @@ void prefetch(const std::string& id_name, const std::string& out_name, framework::Tensor cpu_tensor; auto* cpu_tensor_data = cpu_tensor.mutable_data(id_tensor.dims(), cpu_place); + auto stream = + static_cast(&actual_ctx)->stream(); memory::Copy(cpu_place, cpu_tensor_data, boost::get(id_tensor.place()), - id_tensor.data(), - sizeof(int64_t) * id_tensor.numel()); - for (size_t i = 0; i < id_tensor.numel(); ++i) { + id_tensor.data(), sizeof(int64_t) * id_tensor.numel(), + stream); + for (size_t i = 0; i < cpu_tensor.numel(); ++i) { ids_vector.push_back(cpu_tensor_data[i]); } #endif @@ -239,7 +245,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, out_var_names, height_sections, splited_ids, - context, &local_scope); + context, &local_scope, &actual_ctx); context.scope().DeleteScope(&local_scope); } diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index abd5dce8f7e7146a1671a387328c177e5e6e0a85..36156a1f6174631dd084c8dc63dc432f5275008e 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -78,27 +78,47 @@ class LookupTableCUDAKernel : public framework::OpKernel { auto *output_t = context.Output("Out"); int64_t padding_idx = context.Attr("padding_idx"); - size_t N = table_t->dims()[0]; - size_t D = table_t->dims()[1]; - size_t K = ids_t->numel(); - - auto *ids = ids_t->data(); - auto *table = table_t->data(); - auto *output = output_t->mutable_data(context.GetPlace()); - - dim3 threads(128, 8); - dim3 grids(8, 1); - - if (padding_idx == -1) - LookupTable< - T, 128, 8, 8, - false><<>>( - output, table, ids, N, K, D, padding_idx); - else - LookupTable< - T, 128, 8, 8, - true><<>>( - output, table, ids, N, K, D, padding_idx); + auto id_name = context.Inputs("Ids").front(); + auto out_name = context.Outputs("Out").front(); + + // for remote prefetch + auto epmap = context.Attr>("epmap"); + auto height_sections = context.Attr>("height_sections"); + auto table_names = context.Attr>("table_names"); + + if (!epmap.empty()) { +// if epmap is not empty, then the parameter will be fetched from remote +// parameter +// server +#ifdef PADDLE_WITH_DISTRIBUTE + operators::distributed::prefetch(id_name, out_name, table_names, epmap, + height_sections, context); +#else + PADDLE_THROW( + "paddle is not compiled with distribute support, can not do " + "parameter prefetch!"); +#endif + } else { + size_t N = table_t->dims()[0]; + size_t D = table_t->dims()[1]; + size_t K = ids_t->numel(); + + auto *ids = ids_t->data(); + auto *table = table_t->data(); + auto *output = output_t->mutable_data(context.GetPlace()); + + dim3 threads(128, 8); + dim3 grids(8, 1); + + if (padding_idx == -1) + LookupTable<<< + grids, threads, 0, context.cuda_device_context().stream()>>>( + output, table, ids, N, K, D, padding_idx); + else + LookupTable<<< + grids, threads, 0, context.cuda_device_context().stream()>>>( + output, table, ids, N, K, D, padding_idx); + } } }; @@ -109,6 +129,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { auto &dev_ctx = context.template device_context(); bool is_sparse = context.Attr("is_sparse"); + // Since paddings are not trainable and fixed in forward, the gradient of // paddings makes no sense and we don't deal with it in backward. if (is_sparse) {