未验证 提交 61343fbf 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #10531 from typhoonzero/refine_grpc_serde_code

Refine serde code
...@@ -29,129 +29,127 @@ namespace paddle { ...@@ -29,129 +29,127 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
using VarMsg = sendrecv::VariableMessage;
void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) {
auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim);
}
const framework::LoD lod = tensor.lod();
if (lod.size() > 0) {
request->set_lod_level(lod.size());
for (auto& each : lod) {
VarMsg::LodData* lod_inner = request->add_lod();
for (auto& d : each) {
lod_inner->add_lod_data(d);
}
}
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
*payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
*payload = tensor.data<void>();
}
*payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
}
void GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) {
auto* slr = var->GetMutable<framework::SelectedRows>();
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
request->set_lod_level(0);
request->set_slr_height(slr->height());
for (auto& dim : framework::vectorize(slr->value().dims())) {
request->add_dims(dim);
}
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
*payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, *payload,
boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
*payload = slr->mutable_value()->data<void>();
}
*payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, ::grpc::ByteBuffer* msg,
const std::string& out_name) { const std::string& out_name) {
using VarMsg = sendrecv::VariableMessage; // Default DestroyCallback does nothing, When using GPU
// When using GPU, need to free the copied CPU buffer // the CPU buffer need to be freed.
// when the ByteBuffer destroies
// TODO(typhoonzero): add unref here, if we have dependent
// parallelism execution, need to know when to free the tensor.
DestroyCallback destroy_callback = [](void* backing) {}; DestroyCallback destroy_callback = [](void* backing) {};
VarMsg request;
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
void* payload = nullptr; void* payload = nullptr;
size_t payload_size; size_t payload_size;
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
request.set_varname(name);
// Note: normally the profiler is enabled in 1 trainer, hence only // Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS // 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the // servers the trainer's profiling state so that PS can follow the
// trainer. // trainer.
if (platform::ShouldSendProfileState()) { request.set_profile(platform::IsProfileEnabled());
e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled()); if (!out_name.empty()) {
request.set_out_varname(out_name);
} }
e.WriteString(VarMsg::kVarnameFieldNumber, name);
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 0); request.set_type(::sendrecv::LOD_TENSOR);
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 1); request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
} }
if (!out_name.empty()) { if (platform::is_gpu_place(ctx.GetPlace())) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name); // GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
} }
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
auto tensor = var->Get<framework::LoDTensor>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(tensor.type()));
for (auto& dim : framework::vectorize(tensor.dims())) {
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
auto lod = tensor.lod(); // std::vector<Vector<size_t>>
if (lod.size() > 0) {
e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
for (auto& each : lod) {
e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
2 + // tag + varintlength of submessage
1 + // kLodDataFieldNumber
each.size());
// auto copied from GPU
for (auto& d : each) {
e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
}
}
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()),
copy_size, gpu_dev_ctx.stream());
ctx.Wait();
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
#endif std::string header;
} else { request.AppendToString(&header);
payload = tensor.data<void>(); auto buffer = std::unique_ptr<char[]>(new char[1024]);
} void* buf = buffer.get();
payload_size = tensor.numel() * framework::SizeOfType(tensor.type()); ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); e.WriteRawBytes(std::string(header.data(), header.size()));
} break; e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
case framework::proto::VarType_Type_SELECTED_ROWS: {
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
auto* slr = var->GetMutable<framework::SelectedRows>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(slr->value().type()));
for (auto& dim : framework::vectorize(slr->value().dims())) {
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size =
tensor->numel() * framework::SizeOfType(tensor->type());
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()),
copy_size, gpu_dev_ctx.stream());
ctx.Wait();
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
#endif
} else {
payload = slr->mutable_value()->data<void>();
}
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
default:
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
// steal reference of tensor data // steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer int num_slices = 2; // only SelectedRows have rows buffer
...@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast<char*>(payload)), static_cast<char*>(payload)),
::grpc::Slice::STEAL_REF); ::grpc::Slice::STEAL_REF);
if (framework::ToVarType(var->Type()) == if (var->IsType<framework::SelectedRows>()) {
framework::proto::VarType_Type_SELECTED_ROWS) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128); ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size = size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t)); slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
...@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
grpc_slice_new_with_user_data( grpc_slice_new_with_user_data(
const_cast<void*>( const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())), reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size, rows_memory_size, [](void* backing) {},
[](void* backing) {
// TODO(typhoonzero): add unref here, same as above.
},
const_cast<char*>( const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))), reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF); ::grpc::Slice::STEAL_REF);
......
...@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer // serialize var to ByteBuffer
framework::Variable var; framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>(); auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({4, 8, 4, 2})); tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod; framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8})); lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod); tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2; int tensor_numel = 512 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
tensor->mutable_data<float>(place); tensor->mutable_data<float>(place);
...@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
EXPECT_TRUE(varmsg.ParseFromString(tmp)); EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar"); EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 0); EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 4); EXPECT_EQ(varmsg.dims()[0], 512);
EXPECT_EQ(varmsg.dims()[1], 8); EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4); EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2); EXPECT_EQ(varmsg.dims()[3], 2);
......
...@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input, ...@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
} }
if (wt == WIRETYPE_LENGTH_DELIMITED) { if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0; int num_bytes = 0;
if (!input->ReadVarintSizeAsInt(&length)) { if (!input->ReadVarintSizeAsInt(&num_bytes)) {
return tag; return tag;
} }
int start_pos = input->CurrentPosition();
for (int i = 0; i < length; i++) { while (input->CurrentPosition() - start_pos < num_bytes) {
uint64_t v; uint64_t v;
if (!input->ReadVarint64(&v)) { if (!input->ReadVarint64(&v)) {
return false; return tag;
} }
lod->push_back(v); lod->push_back(v);
} }
...@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) { ...@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
break; break;
} }
case sendrecv::VariableMessage::kTypeFieldNumber: { case sendrecv::VariableMessage::kTypeFieldNumber: {
uint64_t v; uint32_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag; return tag;
} }
...@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) { ...@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
break; break;
} }
case sendrecv::VariableMessage::kDataTypeFieldNumber: { case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint64_t v = 0; uint32_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag; return tag;
} }
...@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) { ...@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
// packed // packed
if (wt == WIRETYPE_LENGTH_DELIMITED) { if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0; int num_bytes = 0;
if (!input.ReadVarintSizeAsInt(&length)) { if (!input.ReadVarintSizeAsInt(&num_bytes)) {
return tag; return tag;
} }
for (int i = 0; i < length; i++) { int start_pos = input.CurrentPosition();
while (input.CurrentPosition() - start_pos < num_bytes) {
uint64_t v; uint64_t v;
if (!input.ReadVarint64(&v)) { if (!input.ReadVarint64(&v)) {
return tag; return tag;
...@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) { ...@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
} }
break; break;
} }
return tag; return tag;
} }
case sendrecv::VariableMessage::kLodLevelFieldNumber: { case sendrecv::VariableMessage::kLodLevelFieldNumber: {
...@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) { ...@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
meta_.varname() != "", meta_.varname() != "",
"meta info should be got first!"); "meta info should be got first!");
int length = 0; int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED || if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) { !ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag; return tag;
} }
...@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) { ...@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
if (meta_.type() == sendrecv::LOD_TENSOR) { if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0, PADDLE_ENFORCE(meta_.lod_size() >= 0,
"lod info should be got first!"); "lod info should be got first!");
if (!CopyLodTensorData(&input, *dev_ctx_, dims, length)) { if (!CopyLodTensorData(&input, *dev_ctx_, dims, num_bytes)) {
return tag; return tag;
} }
break; break;
} }
if (meta_.type() == sendrecv::SELECTED_ROWS) { if (meta_.type() == sendrecv::SELECTED_ROWS) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, length)) { if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, num_bytes)) {
return tag; return tag;
} }
break; break;
...@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) { ...@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
meta_.varname() != "", meta_.varname() != "",
"meta info should be got first!"); "meta info should be got first!");
int length = 0; int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED || if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) { !ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag; return tag;
} }
if (!CopySelectRowsData(&input, *dev_ctx_, length)) { if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return tag; return tag;
} }
break; break;
......
...@@ -18,7 +18,9 @@ import math ...@@ -18,7 +18,9 @@ import math
import distributed_splitter as splitter import distributed_splitter as splitter
from .. import core from .. import core
from ..framework import Program, default_main_program, Variable, Parameter from ..framework import Program, default_main_program, \
default_startup_program, \
Variable, Parameter, grad_var_name
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
...@@ -153,43 +155,43 @@ class DistributeTranspiler: ...@@ -153,43 +155,43 @@ class DistributeTranspiler:
split_method=splitter.round_robin, split_method=splitter.round_robin,
sync_mode=True): sync_mode=True):
""" """
Transpile the program to distributed data-parallelism programs. Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put to do parameter optimization. And the optimization graph will be put
into a parameter server program. into a parameter server program.
Use different methods to split trainable variables to different Use different methods to split trainable variables to different
parameter servers. parameter servers.
Steps to transpile trainer: Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width). 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d". 2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable. 3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch 4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server. params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights. 5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver: Steps to transpile pserver:
1. create new program for parameter server. 1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance. 2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program 3. create a sub-block in the server side program
4. append ops that should run on current server instance. 4. append ops that should run on current server instance.
5. add listen_and_serv op 5. add listen_and_serv op
:param trainer_id: one unique id for each trainer in a job. :param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int :type trainer_id: int
:param program: program to transpile, default is default_main_program :param program: program to transpile, default is default_main_program
:type program: Program :type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174" :param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string :type pservers: string
:param trainers: total number of workers/trainers in the job :param trainers: total number of workers/trainers in the job
:type trainers: int :type trainers: int
:param split_method: A function to determin how to split variables :param split_method: A function to determin how to split variables
to different servers equally. to different servers equally.
:type split_method: function :type split_method: function
:type sync_mode: boolean default True :type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler :param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program. will transpile the program into sync_mode pserver and trainer program.
""" """
assert (callable(split_method)) assert (callable(split_method))
if program is None: if program is None:
...@@ -244,7 +246,7 @@ class DistributeTranspiler: ...@@ -244,7 +246,7 @@ class DistributeTranspiler:
] ]
grad_list = [ grad_list = [
grad for grad in grad_list grad for grad in grad_list
if grad.name != framework.grad_var_name(self.table_name) if grad.name != grad_var_name(self.table_name)
] ]
self.table_param_grad = [ self.table_param_grad = [
param_grad for param_grad in params_grads param_grad for param_grad in params_grads
...@@ -494,7 +496,7 @@ class DistributeTranspiler: ...@@ -494,7 +496,7 @@ class DistributeTranspiler:
were split to several blocks. were split to several blocks.
""" """
s_prog = Program() s_prog = Program()
orig_s_prog = framework.default_startup_program() orig_s_prog = default_startup_program()
params = self.param_grad_ep_mapping[endpoint]["params"] params = self.param_grad_ep_mapping[endpoint]["params"]
def _get_splited_name_and_shape(varname): def _get_splited_name_and_shape(varname):
...@@ -619,7 +621,7 @@ class DistributeTranspiler: ...@@ -619,7 +621,7 @@ class DistributeTranspiler:
# 2. add split_ids_op and send_vars_op to send gradient to pservers # 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name # there should only be one table_name
all_ops = program.global_block().ops all_ops = program.global_block().ops
table_grad_name = framework.grad_var_name(self.table_name) table_grad_name = grad_var_name(self.table_name)
for op in all_ops: for op in all_ops:
if table_grad_name in op.output_arg_names: if table_grad_name in op.output_arg_names:
op_index = list(all_ops).index(op) op_index = list(all_ops).index(op)
...@@ -692,7 +694,7 @@ class DistributeTranspiler: ...@@ -692,7 +694,7 @@ class DistributeTranspiler:
persistable=True) persistable=True)
grad_var = _clone_var( grad_var = _clone_var(
pserver_program.global_block(), pserver_program.global_block(),
self.origin_program.global_block().vars[framework.grad_var_name( self.origin_program.global_block().vars[grad_var_name(
self.table_name)], self.table_name)],
persistable=False) persistable=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册