Commit 5b871982 authored by Qiao Longfei

Merge pull request #14589 from jacquesqiao/refactor-prefetch

Refactor prefetch
test=release/1.2
Parent 25c2cdaf
......@@ -862,7 +862,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id =
GetVarDeviceID(*result, output->Name(), *sharded_var_device);
PADDLE_ENFORCE_NE(outvar_dev_id, -1);
PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
}
p = places_[outvar_dev_id];
ir::Node *new_node = nullptr;
......
......@@ -37,7 +37,13 @@ if (WITH_GPU)
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS})
SET(OP_PREFETCH_DEPS "")
if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32)
......
......@@ -9,36 +9,37 @@ else()
endif()
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
return()
endif()
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
else()
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
cc_test(brpc_server_test SRCS rpc_server_test.cc
DEPS ${brpc_test_depends} SERIAL)
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
DEPS ${brpc_test_depends} SERIAL)
endif()
......@@ -171,11 +171,13 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
......@@ -186,11 +188,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
s, method, h, this] {
s, method, h, table_name_val, this] {
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
0, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
......
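The prefetch request is issued from an AsyncIO callback, so the client first copies the endpoint, both variable names, and the new table name into `*_val` locals and captures those by value; capturing references into the caller's frame would dangle once AsyncPrefetchVar returns. A minimal Python sketch of the same capture-by-copy idea (names are illustrative; Python closures keep objects alive, so the dangling-reference hazard is specific to the C++ lambda):

import threading

def async_prefetch(ep, in_var_name, out_var_name, table_name):
    # copy every argument into a local before starting the worker, so the
    # background thread only ever reads values it owns
    ep_val, in_val, out_val, table_val = ep, in_var_name, out_var_name, table_name

    def worker():
        # stand-in for SerializeToByteBuffer + the gRPC call
        request = {"varname": in_val, "out_varname": out_val,
                   "table_name": table_val}
        print("prefetch ->", ep_val, request)

    t = threading.Thread(target=worker)
    t.start()
    return t

async_prefetch("127.0.0.1:6164", "ids@127.0.0.1:6164",
               "emb@127.0.0.1:6164", "emb.block0").join()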
......@@ -194,6 +194,7 @@ class GRPCClient : public RPCClient {
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendBatchBarrier(
......
......@@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name,
const int trainer_id) {
const int trainer_id,
const std::string& table_name) {
platform::RecordRPCEvent record_event("serial", &ctx);
VarMsg request;
TensorPayload* payload = nullptr;
......@@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (!table_name.empty()) {
request.set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
payload = new TensorPayload(GetTensorPayload(var, ctx, &request));
......
......@@ -40,7 +40,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string(),
const int trainer_id = 0);
const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
......
......@@ -130,7 +130,8 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg;
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg,
"outvar", 0, "table_name");
EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize
......
......@@ -183,6 +183,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process...
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
......@@ -193,7 +194,7 @@ class RequestPrefetch final : public RequestBase {
framework::Variable* outvar = scope->Var(out_var_name);
request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name);
out_var_name, table_name);
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
......
......@@ -301,6 +301,20 @@ int GRPCVariableResponse::Parse(Source* source) {
meta_.set_trainer_id(trainer_id);
break;
}
case sendrecv::VariableMessage::kTableNameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_table_name(temp);
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
......
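The new `kTableNameFieldNumber` case follows the protobuf wire format: a length-delimited field is encoded as a varint tag, a varint byte length, and then that many raw bytes. A hand-rolled Python sketch of the same decode step (illustrative, not the generated protobuf code):

def read_varint(buf, pos):
    # decode a base-128 varint, least-significant group first
    result, shift = 0, 0
    while True:
        b = buf[pos]
        pos += 1
        result |= (b & 0x7F) << shift
        if not b & 0x80:
            return result, pos
        shift += 7

def read_length_delimited(buf, pos):
    # wire type 2: varint length, then `length` raw bytes
    length, pos = read_varint(buf, pos)
    return buf[pos:pos + length], pos + length

msg = bytes([0x6A, 0x05]) + b"table"   # field 13, wire type 2: (13 << 3) | 2 = 0x6A
tag, pos = read_varint(msg, 0)
assert tag >> 3 == 13 and tag & 7 == 2
value, pos = read_length_delimited(msg, pos)
assert value == b"table"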
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
namespace paddle {
namespace operators {
namespace distributed {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
static size_t GetSectionIndex(int64_t id,
const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (id < abs_sections[i]) {
return i - 1;
}
}
return abs_sections.size() - 1;
}
static std::vector<int64_t> ToAbsoluteSection(
const std::vector<int>& height_sections) {
std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size());
abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) {
abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
}
return abs_sections;
}
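ToAbsoluteSection is a prefix sum over the per-server heights, and GetSectionIndex picks the last boundary not greater than a row id. A standalone Python sketch of the arithmetic, assuming three servers holding 4, 6, and 8 rows:

from itertools import accumulate

height_sections = [4, 6, 8]                      # rows held by each server
abs_sections = [0] + list(accumulate(height_sections[:-1]))
assert abs_sections == [0, 4, 10]                # first global row of each server

def get_section_index(row_id, abs_sections):
    # the owning server is the last boundary not greater than row_id
    return max(i for i, b in enumerate(abs_sections) if row_id >= b)

assert get_section_index(7, abs_sections) == 1   # global row 7 lives on server 1
assert 7 - abs_sections[1] == 3                  # ...as local row 3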
static std::vector<std::vector<int64_t>> SplitIds(
const std::vector<int64_t>& ids_vector,
const std::vector<int>& height_section, framework::Scope* scope) {
std::set<int64_t> all_ids;
for (auto id : ids_vector) {
all_ids.insert(id);
}
auto abs_sections = ToAbsoluteSection(height_section);
std::vector<std::vector<int64_t>> splited_ids;
splited_ids.resize(height_section.size() + 1);
for (auto& id : all_ids) {
auto section_index = GetSectionIndex(id, abs_sections);
splited_ids[section_index].push_back(id - abs_sections[section_index]);
}
return splited_ids;
}
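SplitIds then deduplicates the lookup ids (via the std::set) and rewrites each id as a server-local row index. A self-contained Python sketch of the same split:

from itertools import accumulate

def split_ids(ids, height_sections):
    bounds = [0] + list(accumulate(height_sections[:-1]))
    split = [[] for _ in height_sections]
    for row_id in sorted(set(ids)):              # dedup, like the std::set in C++
        s = max(i for i, b in enumerate(bounds) if row_id >= b)
        split[s].append(row_id - bounds[s])      # store the server-local index
    return split

# boundaries [0, 4, 10]: id 7 becomes local row 3 on server 1
assert split_ids([1, 7, 7, 12], [4, 6, 8]) == [[1], [3], [2]]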
static void SplitIdsIntoMultipleVarsBySection(
const std::vector<std::string>& in_var_names,
const std::vector<int>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), "");
auto place = platform::CPUPlace();
for (size_t i = 0; i < in_var_names.size(); ++i) {
auto* id_tensor =
scope->Var(in_var_names[i])->GetMutable<framework::LoDTensor>();
auto& ids = splited_ids[i];
if (!ids.empty()) {
auto* id_tensor_data = id_tensor->mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
}
}
}
static void MergeMultipleVarsIntoOneBySection(
const std::string& id_name, const std::vector<int64_t>& ids_vector,
const std::string& out_name, const std::vector<std::string>& out_var_names,
const std::vector<int>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids,
const framework::ExecutionContext& context, framework::Scope* scope,
platform::DeviceContext* actual_ctx) {
PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), "");
auto cpu_place = platform::CPUPlace();
auto abs_sections = ToAbsoluteSection(height_section);
std::unordered_map<int64_t, std::vector<size_t>> id_to_offset;
for (size_t i = 0; i < ids_vector.size(); ++i) {
id_to_offset[ids_vector[i]].push_back(i);
}
auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>();
auto* out_tensor =
scope->FindVar(out_name)->GetMutable<framework::LoDTensor>();
auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place());
bool is_on_cpu_place = true;
if (!platform::is_cpu_place(id_tensor.place())) {
is_on_cpu_place = false;
}
for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx];
if (!ids_in_this_section.empty()) {
auto& prefetch_out_var =
scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2, "");
PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]);
auto row_numel = dims[1];
for (size_t i = 0; i < dims[0]; ++i) {
auto id = ids_in_this_section[i];
auto origin_id = id + abs_sections[section_idx];
auto& offsets = id_to_offset[origin_id];
for (auto& offset : offsets) {
// should support GPU tensor
if (is_on_cpu_place) {
memory::Copy(cpu_place, out_tensor_data + offset * row_numel,
cpu_place, out_var_data + i * row_numel,
sizeof(float) * row_numel);
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
auto stream =
static_cast<platform::CUDADeviceContext*>(actual_ctx)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(id_tensor.place()),
out_tensor_data + offset * row_numel, cpu_place,
out_var_data + i * row_numel,
sizeof(float) * row_numel, stream);
#endif
}
}
}
} else {
VLOG(3) << "ids in this section is empty";
}
}
}
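MergeMultipleVarsIntoOneBySection inverts the split: it records every position at which each original id occurred (id_to_offset) and copies the row fetched for that id into all of those positions, so duplicate ids share one fetched row. A list-based Python sketch of the bookkeeping:

from itertools import accumulate

def merge_rows(ids, height_sections, split, fetched):
    # fetched[s][j] is the row returned for the j-th id sent to server s
    bounds = [0] + list(accumulate(height_sections[:-1]))
    id_to_offset = {}
    for pos, row_id in enumerate(ids):           # an id may occur several times
        id_to_offset.setdefault(row_id, []).append(pos)
    out = [None] * len(ids)
    for s, local_ids in enumerate(split):
        for j, local_id in enumerate(local_ids):
            for pos in id_to_offset[local_id + bounds[s]]:
                out[pos] = fetched[s][j]         # one fetched row, every position
    return out

# ids [1, 7, 7] were split into [[1], [3], []] for heights [4, 6, 8]
assert merge_rows([1, 7, 7], [4, 6, 8], [[1], [3], []],
                  [["row1"], ["row7"], []]) == ["row1", "row7", "row7"]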
void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap,
const std::vector<int>& height_sections,
const framework::ExecutionContext& context) {
auto& local_scope = context.scope().NewScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace());
auto& actual_ctx = *pool.Get(context.GetPlace());
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
for (size_t i = 0; i < epmap.size(); ++i) {
in_var_names.push_back(id_name + "@" + epmap[i]);
out_var_names.push_back(out_name + "@" + epmap[i]);
}
auto& id_tensor = local_scope.FindVar(id_name)->Get<framework::LoDTensor>();
std::vector<int64_t> ids_vector;
if (platform::is_cpu_place(id_tensor.place())) {
auto* id_data = id_tensor.data<int64_t>();
for (size_t i = 0; i < id_tensor.numel(); ++i) {
ids_vector.push_back(id_data[i]);
}
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
auto cpu_place = platform::CPUPlace();
framework::Tensor cpu_tensor;
auto* cpu_tensor_data =
cpu_tensor.mutable_data<int64_t>(id_tensor.dims(), cpu_place);
auto stream =
static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
memory::Copy(cpu_place, cpu_tensor_data,
boost::get<platform::CUDAPlace>(id_tensor.place()),
id_tensor.data<int64_t>(), sizeof(int64_t) * id_tensor.numel(),
stream);
for (size_t i = 0; i < cpu_tensor.numel(); ++i) {
ids_vector.push_back(cpu_tensor_data[i]);
}
#endif
}
auto splited_ids = SplitIds(ids_vector, height_sections, &local_scope);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
&local_scope);
// create output var in local scope
for (auto& name : out_var_names) {
local_scope.Var(name)->GetMutable<framework::LoDTensor>();
}
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(local_scope, in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], cpu_ctx, local_scope, in_var_names[i], out_var_names[i],
table_names[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
}
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids,
context, &local_scope, &actual_ctx);
context.scope().DeleteScope(&local_scope);
}
}  // namespace distributed
}  // namespace operators
}  // namespace paddle
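End to end, prefetch gathers the ids on CPU, splits them by section, sends one AsyncPrefetchVar per endpoint, waits on every handle, and merges the replies into the single output tensor. A compact local model of that round trip, with a dict per server standing in for the remote lookup_table:

from itertools import accumulate

def prefetch_sketch(ids, height_sections, tables):
    # tables[s] maps a server-local row index to row data, standing in for
    # the lookup_table op run remotely on parameter server s
    bounds = [0] + list(accumulate(height_sections[:-1]))
    section = lambda r: max(i for i, b in enumerate(bounds) if r >= b)
    uniq = sorted(set(ids))                      # one fetched row per unique id
    fetched = {r: tables[section(r)][r - bounds[section(r)]] for r in uniq}
    return [fetched[r] for r in ids]             # merge back, duplicates included

tables = [{1: "a"}, {3: "b"}, {2: "c"}]          # servers of height 4, 6, 8
assert prefetch_sketch([1, 7, 12, 7], [4, 6, 8], tables) == ["a", "b", "c", "b"]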
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
namespace distributed {
void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap,
const std::vector<int>& height_sections,
const framework::ExecutionContext& context);
}  // namespace distributed
}  // namespace operators
}  // namespace paddle
......@@ -191,7 +191,8 @@ class RequestHandler {
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") = 0;
const std::string& out_var_name = "",
const std::string& table_name = "") = 0;
protected:
const bool sync_mode_;
......
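Handle gains a trailing table_name parameter defaulted to the empty string, so existing call sites and subclasses keep compiling and only the prefetch path passes a value. The same evolution pattern sketched in Python (hypothetical names):

class RequestHandler:
    def handle(self, varname, scope, invar, trainer_id,
               out_var_name="", table_name=""):
        raise NotImplementedError

class PrefetchHandler(RequestHandler):
    def handle(self, varname, scope, invar, trainer_id,
               out_var_name="", table_name=""):
        # an empty table_name keeps the old prepared-program path
        if not table_name:
            return "run prepared prefetch program for " + varname
        return "run ad-hoc lookup_table on " + table_name

h = PrefetchHandler()
assert h.handle("ids", None, None, 0, "emb").startswith("run prepared")
assert "ad-hoc" in h.handle("ids", None, None, 0, "emb", "emb.block0")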
......@@ -37,7 +37,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestSendHandler:" << varname;
// Sync
......@@ -77,8 +78,10 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetHandler:" << varname;
if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
......@@ -113,14 +116,22 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
if (table_name.empty()) {
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
} else {
(*outvar)->GetMutable<framework::LoDTensor>();
auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place;
lookup_table_op->Run(*scope, cpu_place);
}
return true;
}
......@@ -129,7 +140,8 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name) {
const std::string& out_var_name,
const std::string& table_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
......@@ -43,8 +44,8 @@ class RequestSendHandler final : public RequestHandler {
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
bool enable_dc_asgd_;
......@@ -59,21 +60,44 @@ class RequestGetHandler final : public RequestHandler {
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
bool enable_dc_asgd_;
};
static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
class RequestPrefetchHandler final : public RequestHandler {
public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("lookup_table");
BuildVar("W", {table_name.data()}, op_desc.add_inputs());
BuildVar("Ids", {id_name.data()}, op_desc.add_inputs());
BuildVar("Out", {out_name.data()}, op_desc.add_outputs());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
};
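BuildLookupTableOp assembles a lookup_table OpDesc on the fly, wiring the requested table, id, and output variable names into the op's W, Ids, and Out slots so the server can run a single lookup without a prepared program. A dict-based Python sketch of that construction (illustrative names, not Paddle's actual proto API):

def build_lookup_table_op(table_name, id_name, out_name):
    # mirrors the OpDesc proto: an op type plus named argument lists
    return {
        "type": "lookup_table",
        "inputs": {"W": [table_name], "Ids": [id_name]},
        "outputs": {"Out": [out_name]},
    }

op_desc = build_lookup_table_op("emb.block0", "ids", "ids.out")
assert op_desc["inputs"]["W"] == ["emb.block0"]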
class RequestCheckpointHandler final : public RequestHandler {
......@@ -85,8 +109,8 @@ class RequestCheckpointHandler final : public RequestHandler {
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name = "") override;
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
int checkpoint_notify_id;
......
......@@ -48,7 +48,7 @@ class RPCClient {
virtual VarHandlePtr AsyncPrefetchVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& in_var_name,
const std::string& out_var_name,
const std::string& out_var_name, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendBatchBarrier(
......
......@@ -80,6 +80,7 @@ message VariableMessage {
// when profile switches from 1 to 2.
int64 profile = 11;
int64 trainer_id = 12;
string table_name = 13;
}
message VoidMessage {}
......@@ -85,6 +85,7 @@ class VariableResponse {
inline framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); }
inline std::string TableName() const { return meta_.table_name(); }
// should call parse first.
framework::Variable* GetVar() {
......
......@@ -87,6 +87,25 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"If the grad op reuse the input's variable.")
.SetDefault(false);
// for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int>({}));
AddAttr<std::vector<std::string>>(
"epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"table_names",
"(string vector, the splited table names that will be fetched from "
"parameter server)"
"in the order of input variables for mapping")
.SetDefault({});
AddComment(R"DOC(
Lookup Table Operator.
......
......@@ -78,27 +78,47 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
if (padding_idx == -1)
LookupTable<
T, 128, 8, 8,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<
T, 128, 8, 8,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
auto id_name = context.Inputs("Ids").front();
auto out_name = context.Outputs("Out").front();
// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections = context.Attr<std::vector<int>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) {
// if epmap is not empty, the parameter will be fetched from the remote
// parameter server
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap,
height_sections, context);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, cannot do "
"parameter prefetch!");
#endif
} else {
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
if (padding_idx == -1)
LookupTable<T, 128, 8, 8, false><<<
grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<T, 128, 8, 8, true><<<
grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
}
}
};
......@@ -109,6 +129,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
......
......@@ -23,6 +23,10 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle {
namespace operators {
......@@ -41,44 +45,66 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
auto id_name = context.Inputs("Ids").front();
auto out_name = context.Outputs("Out").front();
// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections = context.Attr<std::vector<int>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) {
// if epmap is not empty, the parameter will be fetched from the remote
// parameter server
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap,
height_sections, context);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, cannot do "
"parameter prefetch!");
#endif
} else {
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
}
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should exist.");
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should exist.");
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
}
}
}
......
......@@ -326,6 +326,11 @@ def embedding(input,
"""
helper = LayerHelper('embedding', **locals())
remote_prefetch = False
if os.environ.get('PADDLE_ENABLE_REMOTE_PREFETCH'):
remote_prefetch = True
if remote_prefetch:
assert is_sparse is True and is_distributed is False
w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
tmp = helper.create_variable_for_type_inference(dtype)
......@@ -339,6 +344,7 @@ def embedding(input,
attrs={
'is_sparse': is_sparse,
'is_distributed': is_distributed,
'remote_prefetch': remote_prefetch,
'padding_idx': padding_idx
})
return tmp
......
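On the Python side, remote prefetch is opt-in through an environment variable that is read while the layer is being built, and it requires a sparse, non-distributed embedding. A hedged usage sketch, assuming a Paddle Fluid 1.x install:

import os
# must be set before the layer is built; embedding() reads it at construction
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"

import paddle.fluid as fluid

ids = fluid.layers.data(name='ids', shape=[1], dtype='int64', lod_level=1)
# remote prefetch requires is_sparse=True and is_distributed=False
emb = fluid.layers.embedding(
    input=ids, size=[10000, 64], dtype='float32',
    is_sparse=True, is_distributed=False)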
......@@ -16,11 +16,13 @@ from __future__ import print_function
import paddle
import paddle.fluid as fluid
import os
import dist_ctr_reader
from test_dist_base import TestDistRunnerBase, runtime_main
IS_SPARSE = True
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
# Fix seed for test
fluid.default_startup_program().random_seed = 1
......
......@@ -782,5 +782,46 @@ class TestNCCL2Transpile(TranspilerTest):
pass
# test for remote prefetch
class TestRemoteLookupTable(TestDistLookupTableBase):
def net_conf(self):
import os
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
self.network_with_table(is_sparse=True, is_distributed=False)
def transpiler_test_impl(self):
pserver1, startup1 = self.get_pserver(self.pserver1_ep)
self.assertEqual(len(pserver1.blocks), 4)
# 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
["sum", "scale", "adam", "scale", "scale"])
# 2 optimize for table adam
# NOTE: if the param is not SelectedRows, the grad will be scaled to grad / trainer_num
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
["sum", "scale", "adam", "scale", "scale"])
# 3 optimize for table 2 adam
# NOTE: if the param is not SelectedRows, the grad will be scaled to grad / trainer_num
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["sum", "scale", "adam", "scale", "scale"])
trainer, _ = self.get_trainer()
self.assertEqual(len(trainer.blocks), 1)
ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
'recv', 'fetch_barrier'
]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import signal
import time
import unittest
from multiprocessing import Process
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard
def run_pserver(pserver_id, use_cuda, sync_mode):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
# create table parameter in scope
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
# create and initialize Param Variable
param = scope.var('table').get_tensor()
param_array = np.ones((10, 8)).astype("float32")
for i in range(len(param_array)):
param_array[i] *= param_array[i] * i + pserver_id * 10
param.set(param_array, place)
optimize_block = program._create_block(program.global_block().idx)
program.global_block().append_op(
type="listen_and_serv",
inputs={'X': []},
outputs={},
attrs={
"optimize_blocks": [optimize_block],
"endpoint": '127.0.0.1:0',
"Fanin": 1,
"sync_mode": True,
"grad_to_block_id": []
})
exe = fluid.Executor(place)
exe.run(program)
class TestListenAndServOp(unittest.TestCase):
def setUp(self):
self.ps_timeout = 5
def _start_pserver(self, pserver_id, use_cuda, sync_mode, pserver_func):
p = Process(target=pserver_func, args=(pserver_id, use_cuda, sync_mode))
p.daemon = True
p.start()
return p
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
sleep_time = 0.5
while True:
assert start_left_time >= 0, "wait ps ready failed"
time.sleep(sleep_time)
try:
# the listen_and_serv op writes a file containing its listen port
# under /tmp once it is ready to process all RPC calls
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
start_left_time -= sleep_time
def _get_pserver_port(self, pid):
with open("/tmp/paddle.%d.port" % pid, 'r') as f:
port = int(f.read().strip())
return port
def _run_lookup_table_op_one_pserver(self, place, port):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
# create and initialize Param Variable
param = scope.var('W').get_tensor()
param_array = np.full((10, 8), 1.0).astype("float32")
param.set(param_array, place)
ids = scope.var('Ids').get_tensor()
ids_array = np.array([[1], [2], [5]]).astype("int64")
ids.set(ids_array, place)
ids_lod = [[0, 1, 2, 3]]
ids.set_lod(ids_lod)
out = scope.var('Out').get_tensor()
emaps = ['127.0.0.1:' + str(port)]
table_names = ['table']
height_sections = [10]
# create and run sgd operator
lookup_table_op = Operator(
"lookup_table",
W='W',
Ids='Ids',
Out='Out',
remote_prefetch=True,
epmap=emaps,
table_names=table_names,
height_sections=height_sections)
lookup_table_op.run(scope, place)
# get and compare result
result_array = np.array(out)
self.assertEqual(out.lod(), ids_lod)
self.assertEqual(list(result_array.shape), [len(ids_array), 8])
for i in range(len(ids_array)):
id = ids_array[i][0]
self.assertTrue((result_array[i] == id).all())
def _run_lookup_table_op_two_pserver(self, place, port0, port1):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
# create and initialize Param Variable
param = scope.var('W').get_tensor()
param_array = np.full((10, 8), 1.0).astype("float32")
param.set(param_array, place)
ids = scope.var('Ids').get_tensor()
ids_array = np.array([[1], [2], [11], [13]]).astype("int64")
ids.set(ids_array, place)
ids_lod = [[0, 2, 3, 4]]
ids.set_lod(ids_lod)
out = scope.var('Out').get_tensor()
emaps = ['127.0.0.1:' + str(port0), '127.0.0.1:' + str(port1)]
table_names = ['table', 'table']
height_sections = [10, 20]
# create and run sgd operator
lookup_table_op = Operator(
"lookup_table",
W='W',
Ids='Ids',
Out='Out',
remote_prefetch=True,
epmap=emaps,
table_names=table_names,
height_sections=height_sections)
lookup_table_op.run(scope, place)
# get and compare result
result_array = np.array(out)
self.assertEqual(out.lod(), ids_lod)
self.assertEqual(list(result_array.shape), [len(ids_array), 8])
for i in range(len(ids_array)):
id = ids_array[i][0]
self.assertTrue((result_array[i] == id).all())
def test_lookup_remote_table(self):
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
# run pserver on CPU in sync mode
p0 = self._start_pserver(0, False, True, run_pserver)
self._wait_ps_ready(p0.pid)
port0 = self._get_pserver_port(p0.pid)
p1 = self._start_pserver(1, False, True, run_pserver)
self._wait_ps_ready(p1.pid)
port1 = self._get_pserver_port(p1.pid)
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self._run_lookup_table_op_one_pserver(place, port0)
self._run_lookup_table_op_two_pserver(place, port0, port1)
# raise SIGTERM to pserver
os.kill(p0.pid, signal.SIGINT)
p0.join()
os.kill(p1.pid, signal.SIGINT)
p1.join()
if __name__ == '__main__':
unittest.main()
......@@ -236,6 +236,31 @@ class DistributeTranspiler(object):
else:
raise ValueError("must set trainer_id > 0")
def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = []
sparse_update_op_types = ["lookup_table"]
for op in main_program.global_block().ops:
if op.type in sparse_update_op_types and op.attr(
'remote_prefetch') is True and not op.attr(
'is_distributed'):
sparse_update_ops.append(op)
return sparse_update_ops
def _update_remote_sparse_update_op(self, param_varname, height_sections,
endpoint_map, table_names):
for op in self.sparse_update_ops:
if param_varname in op.input_arg_names:
op._set_attr('epmap', endpoint_map)
op._set_attr('table_names', table_names)
op._set_attr('height_sections', height_sections)
op._set_attr('trainer_id', self.trainer_id)
def _is_input_of_remote_sparse_update_op(self, param_name):
for op in self.sparse_update_ops:
if param_name in op.input_arg_names:
return True
return False
def transpile(self,
trainer_id,
program=None,
......@@ -299,6 +324,12 @@ class DistributeTranspiler(object):
self.param_name_to_grad_name[param_var.name] = grad_var.name
self.grad_name_to_param_name[grad_var.name] = param_var.name
# get all sparse update ops
self.sparse_update_ops = self._get_all_remote_sparse_update_op(
self.origin_program)
# map: sparse update param name -> split height sections
self.sparse_param_to_height_sections = dict()
# add distributed attrs to program
self.origin_program._is_distributed = True
self.origin_program._endpoints = self.pserver_endpoints
......@@ -336,6 +367,13 @@ class DistributeTranspiler(object):
splited_grad_varname = splited_vars[0].name
index = find_op_by_output_arg(
program.global_block(), splited_grad_varname, reverse=True)
if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_param_name = self.grad_name_to_param_name[
grad_varname]
if self._is_input_of_remote_sparse_update_op(
sparse_param_name):
self.sparse_param_to_height_sections[
sparse_param_name] = [splited_vars[0].shape[0]]
elif len(splited_vars) > 1:
orig_var = program.global_block().vars[splited_grad_varname]
index = find_op_by_output_arg(
......@@ -406,16 +444,18 @@ class DistributeTranspiler(object):
all_recv_outputs = []
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
eps = []
table_names = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
eps.append(eplist[index])
table_names.append(var.name)
if self.sync_mode:
recv_dep_in = send_barrier_out
else:
# connect deps to send op in async mode
recv_dep_in = self.grad_name_to_send_dummy_out[
self.param_name_to_grad_name[param_varname]]
all_recv_outputs.extend(splited_var)
# get recv op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name. ParallelExecutor
# will use op_role_var to get expected device place to run this op.
......@@ -425,18 +465,25 @@ class DistributeTranspiler(object):
if len(splited_trainer_grad) == 1:
recv_op_role_var_name = splited_trainer_grad[0].name
program.global_block().append_op(
type="recv",
inputs={"X": [recv_dep_in]},
outputs={"Out": splited_var},
attrs={
"epmap": eps,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name],
"sync_mode": not self.sync_mode
})
if param_varname in self.sparse_param_to_height_sections:
height_sections = self.sparse_param_to_height_sections[
param_varname]
self._update_remote_sparse_update_op(
param_varname, height_sections, eps, table_names)
else:
all_recv_outputs.extend(splited_var)
program.global_block().append_op(
type="recv",
inputs={"X": [recv_dep_in]},
outputs={"Out": splited_var},
attrs={
"epmap": eps,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name],
"sync_mode": not self.sync_mode
})
if self.sync_mode:
# form a WAW dependency
......@@ -454,14 +501,15 @@ class DistributeTranspiler(object):
if len(splited_var) <= 1:
continue
orig_param = program.global_block().vars[param_varname]
program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
outputs={"Out": [orig_param]},
attrs={
"axis": 0,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
if param_varname not in self.sparse_param_to_height_sections:
program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
outputs={"Out": [orig_param]},
attrs={
"axis": 0,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
......@@ -1420,6 +1468,10 @@ to transpile() call.")
height_sections = []
for v in splited_vars:
height_sections.append(v.shape[0])
sparse_param_name = self.grad_name_to_param_name[orig_var.name]
if self._is_input_of_remote_sparse_update_op(sparse_param_name):
self.sparse_param_to_height_sections[
sparse_param_name] = height_sections
program.global_block()._insert_op(
index=index + 1,
type="split_selected_rows",
......