Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
61343fbf
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
61343fbf
编写于
5月 10, 2018
作者:
W
Wu Yi
提交者:
GitHub
5月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10531 from typhoonzero/refine_grpc_serde_code
Refine serde code
上级
6d371e45
796a448c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
172 addition
and
178 deletion
+172
-178
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+105
-113
paddle/fluid/operators/detail/serde_test.cc
paddle/fluid/operators/detail/serde_test.cc
+3
-3
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+20
-20
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+44
-42
未找到文件。
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
61343fbf
...
@@ -29,129 +29,127 @@ namespace paddle {
...
@@ -29,129 +29,127 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
using
VarMsg
=
sendrecv
::
VariableMessage
;
void
GetTensorPayload
(
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
)
{
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
request
->
set_data_type
(
static_cast
<
VarMsg
::
Type
>
(
framework
::
ToDataType
(
tensor
.
type
())));
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
request
->
add_dims
(
dim
);
}
const
framework
::
LoD
lod
=
tensor
.
lod
();
if
(
lod
.
size
()
>
0
)
{
request
->
set_lod_level
(
lod
.
size
());
for
(
auto
&
each
:
lod
)
{
VarMsg
::
LodData
*
lod_inner
=
request
->
add_lod
();
for
(
auto
&
d
:
each
)
{
lod_inner
->
add_lod_data
(
d
);
}
}
}
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
*
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
*
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
#endif
}
else
{
*
payload
=
tensor
.
data
<
void
>
();
}
*
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
}
void
GetSelectedRowsPayload
(
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
request
->
set_data_type
(
static_cast
<
VarMsg
::
Type
>
(
framework
::
ToDataType
(
slr
->
value
().
type
())));
request
->
set_lod_level
(
0
);
request
->
set_slr_height
(
slr
->
height
());
for
(
auto
&
dim
:
framework
::
vectorize
(
slr
->
value
().
dims
()))
{
request
->
add_dims
(
dim
);
}
auto
*
tensor
=
slr
->
mutable_value
();
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
*
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
*
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
()),
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
#endif
}
else
{
*
payload
=
slr
->
mutable_value
()
->
data
<
void
>
();
}
*
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
}
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
const
std
::
string
&
out_name
)
{
using
VarMsg
=
sendrecv
::
VariableMessage
;
// Default DestroyCallback does nothing, When using GPU
// When using GPU, need to free the copied CPU buffer
// the CPU buffer need to be freed.
// when the ByteBuffer destroies
// TODO(typhoonzero): add unref here, if we have dependent
// parallelism execution, need to know when to free the tensor.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
VarMsg
request
;
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
void
*
buf
=
buffer
.
get
();
void
*
payload
=
nullptr
;
void
*
payload
=
nullptr
;
size_t
payload_size
;
size_t
payload_size
;
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
request
.
set_varname
(
name
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// servers the trainer's profiling state so that PS can follow the
// trainer.
// trainer.
if
(
platform
::
ShouldSendProfileState
())
{
request
.
set_profile
(
platform
::
IsProfileEnabled
());
e
.
WriteBool
(
VarMsg
::
kProfileFieldNumber
,
platform
::
IsProfileEnabled
());
if
(
!
out_name
.
empty
())
{
request
.
set_out_varname
(
out_name
);
}
}
e
.
WriteString
(
VarMsg
::
kVarnameFieldNumber
,
name
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
0
);
request
.
set_type
(
::
sendrecv
::
LOD_TENSOR
);
GetTensorPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
1
);
request
.
set_type
(
::
sendrecv
::
SELECTED_ROWS
);
GetSelectedRowsPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
{
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
}
}
if
(
!
out_name
.
empty
())
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
e
.
WriteString
(
VarMsg
::
kOutVarnameFieldNumber
,
out_name
);
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
}
}
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
case
framework
::
proto
::
VarType_Type_LOD_TENSOR
:
{
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
framework
::
ToDataType
(
tensor
.
type
()));
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
e
.
WriteUint64
(
VarMsg
::
kDimsFieldNumber
,
dim
);
}
auto
lod
=
tensor
.
lod
();
// std::vector<Vector<size_t>>
if
(
lod
.
size
()
>
0
)
{
e
.
WriteUint64
(
VarMsg
::
kLodLevelFieldNumber
,
lod
.
size
());
for
(
auto
&
each
:
lod
)
{
e
.
WriteVarlengthBeginning
(
VarMsg
::
kLodFieldNumber
,
2
+
// tag + varintlength of submessage
1
+
// kLodDataFieldNumber
each
.
size
());
// auto copied from GPU
for
(
auto
&
d
:
each
)
{
e
.
WriteUint64
(
VarMsg
::
LodData
::
kLodDataFieldNumber
,
d
);
}
}
}
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
#endif
std
::
string
header
;
}
else
{
request
.
AppendToString
(
&
header
);
payload
=
tensor
.
data
<
void
>
();
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
}
void
*
buf
=
buffer
.
get
();
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
e
.
WriteRawBytes
(
std
::
string
(
header
.
data
(),
header
.
size
()));
}
break
;
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
case
framework
::
proto
::
VarType_Type_SELECTED_ROWS
:
{
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
framework
::
ToDataType
(
slr
->
value
().
type
()));
for
(
auto
&
dim
:
framework
::
vectorize
(
slr
->
value
().
dims
()))
{
e
.
WriteUint64
(
VarMsg
::
kDimsFieldNumber
,
dim
);
}
e
.
WriteUint64
(
VarMsg
::
kLodLevelFieldNumber
,
0
);
e
.
WriteUint64
(
VarMsg
::
kSlrHeightFieldNumber
,
slr
->
height
());
auto
*
tensor
=
slr
->
mutable_value
();
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
()),
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
#endif
}
else
{
payload
=
slr
->
mutable_value
()
->
data
<
void
>
();
}
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
break
;
default:
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
break
;
}
// steal reference of tensor data
// steal reference of tensor data
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
int
num_slices
=
2
;
// only SelectedRows have rows buffer
int
num_slices
=
2
;
// only SelectedRows have rows buffer
...
@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast
<
char
*>
(
payload
)),
static_cast
<
char
*>
(
payload
)),
::
grpc
::
Slice
::
STEAL_REF
);
::
grpc
::
Slice
::
STEAL_REF
);
if
(
framework
::
ToVarType
(
var
->
Type
())
==
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
// NOTE: rows is of type int64_t
size_t
rows_memory_size
=
size_t
rows_memory_size
=
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
...
@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
grpc_slice_new_with_user_data
(
grpc_slice_new_with_user_data
(
const_cast
<
void
*>
(
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
rows_memory_size
,
rows_memory_size
,
[](
void
*
backing
)
{},
[](
void
*
backing
)
{
// TODO(typhoonzero): add unref here, same as above.
},
const_cast
<
char
*>
(
const_cast
<
char
*>
(
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
::
grpc
::
Slice
::
STEAL_REF
);
::
grpc
::
Slice
::
STEAL_REF
);
...
...
paddle/fluid/operators/detail/serde_test.cc
浏览文件 @
61343fbf
...
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
...
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer
// serialize var to ByteBuffer
framework
::
Variable
var
;
framework
::
Variable
var
;
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
framework
::
make_ddim
({
4
,
8
,
4
,
2
}));
tensor
->
Resize
(
framework
::
make_ddim
({
512
,
8
,
4
,
2
}));
framework
::
LoD
lod
;
framework
::
LoD
lod
;
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
tensor
->
set_lod
(
lod
);
tensor
->
set_lod
(
lod
);
int
tensor_numel
=
4
*
8
*
4
*
2
;
int
tensor_numel
=
512
*
8
*
4
*
2
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
...
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
...
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
512
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
...
...
paddle/fluid/operators/detail/variable_response.cc
浏览文件 @
61343fbf
...
@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
...
@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
}
}
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
int
num_bytes
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
length
))
{
if
(
!
input
->
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
int
start_pos
=
input
->
CurrentPosition
();
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
while
(
input
->
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
uint64_t
v
;
if
(
!
input
->
ReadVarint64
(
&
v
))
{
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
return
tag
;
}
}
lod
->
push_back
(
v
);
lod
->
push_back
(
v
);
}
}
...
@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
...
@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
break
;
break
;
}
}
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
uint
64
_t
v
;
uint
32
_t
v
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
64
(
&
v
))
{
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
32
(
&
v
))
{
return
tag
;
return
tag
;
}
}
...
@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
...
@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
break
;
break
;
}
}
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
uint
64
_t
v
=
0
;
uint
32
_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
64
(
&
v
))
{
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
32
(
&
v
))
{
return
tag
;
return
tag
;
}
}
...
@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
...
@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
// packed
// packed
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
int
num_bytes
=
0
;
if
(
!
input
.
ReadVarintSizeAsInt
(
&
length
))
{
if
(
!
input
.
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
int
start_pos
=
input
.
CurrentPosition
();
while
(
input
.
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
return
tag
;
...
@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
...
@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
}
}
break
;
break
;
}
}
return
tag
;
return
tag
;
}
}
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
...
@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
...
@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
meta_
.
varname
()
!=
""
,
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
"meta info should be got first!"
);
int
length
=
0
;
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
...
@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
...
@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
"lod info should be got first!"
);
"lod info should be got first!"
);
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
}
}
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
...
@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
...
@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
meta_
.
varname
()
!=
""
,
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
"meta info should be got first!"
);
int
length
=
0
;
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
length
))
{
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
61343fbf
...
@@ -18,7 +18,9 @@ import math
...
@@ -18,7 +18,9 @@ import math
import
distributed_splitter
as
splitter
import
distributed_splitter
as
splitter
from
..
import
core
from
..
import
core
from
..framework
import
Program
,
default_main_program
,
Variable
,
Parameter
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
\
Variable
,
Parameter
,
grad_var_name
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
...
@@ -153,43 +155,43 @@ class DistributeTranspiler:
...
@@ -153,43 +155,43 @@ class DistributeTranspiler:
split_method
=
splitter
.
round_robin
,
split_method
=
splitter
.
round_robin
,
sync_mode
=
True
):
sync_mode
=
True
):
"""
"""
Transpile the program to distributed data-parallelism programs.
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
The main_program will be transformed to use a remote parameter server
to do parameter optimization. And the optimization graph will be put
to do parameter optimization. And the optimization graph will be put
into a parameter server program.
into a parameter server program.
Use different methods to split trainable variables to different
Use different methods to split trainable variables to different
parameter servers.
parameter servers.
Steps to transpile trainer:
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
5. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
Steps to transpile pserver:
1. create new program for parameter server.
1. create new program for parameter server.
2. create params and grad variables that assigned to current server instance.
2. create params and grad variables that assigned to current server instance.
3. create a sub-block in the server side program
3. create a sub-block in the server side program
4. append ops that should run on current server instance.
4. append ops that should run on current server instance.
5. add listen_and_serv op
5. add listen_and_serv op
:param trainer_id: one unique id for each trainer in a job.
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:param program: program to transpile, default is default_main_program
:type program: Program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:type pservers: string
:param trainers: total number of workers/trainers in the job
:param trainers: total number of workers/trainers in the job
:type trainers: int
:type trainers: int
:param split_method: A function to determin how to split variables
:param split_method: A function to determin how to split variables
to different servers equally.
to different servers equally.
:type split_method: function
:type split_method: function
:type sync_mode: boolean default True
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
will transpile the program into sync_mode pserver and trainer program.
"""
"""
assert
(
callable
(
split_method
))
assert
(
callable
(
split_method
))
if
program
is
None
:
if
program
is
None
:
...
@@ -244,7 +246,7 @@ class DistributeTranspiler:
...
@@ -244,7 +246,7 @@ class DistributeTranspiler:
]
]
grad_list
=
[
grad_list
=
[
grad
for
grad
in
grad_list
grad
for
grad
in
grad_list
if
grad
.
name
!=
framework
.
grad_var_name
(
self
.
table_name
)
if
grad
.
name
!=
grad_var_name
(
self
.
table_name
)
]
]
self
.
table_param_grad
=
[
self
.
table_param_grad
=
[
param_grad
for
param_grad
in
params_grads
param_grad
for
param_grad
in
params_grads
...
@@ -494,7 +496,7 @@ class DistributeTranspiler:
...
@@ -494,7 +496,7 @@ class DistributeTranspiler:
were split to several blocks.
were split to several blocks.
"""
"""
s_prog
=
Program
()
s_prog
=
Program
()
orig_s_prog
=
framework
.
default_startup_program
()
orig_s_prog
=
default_startup_program
()
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
def
_get_splited_name_and_shape
(
varname
):
def
_get_splited_name_and_shape
(
varname
):
...
@@ -619,7 +621,7 @@ class DistributeTranspiler:
...
@@ -619,7 +621,7 @@ class DistributeTranspiler:
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
# there should only be one table_name
all_ops
=
program
.
global_block
().
ops
all_ops
=
program
.
global_block
().
ops
table_grad_name
=
framework
.
grad_var_name
(
self
.
table_name
)
table_grad_name
=
grad_var_name
(
self
.
table_name
)
for
op
in
all_ops
:
for
op
in
all_ops
:
if
table_grad_name
in
op
.
output_arg_names
:
if
table_grad_name
in
op
.
output_arg_names
:
op_index
=
list
(
all_ops
).
index
(
op
)
op_index
=
list
(
all_ops
).
index
(
op
)
...
@@ -692,7 +694,7 @@ class DistributeTranspiler:
...
@@ -692,7 +694,7 @@ class DistributeTranspiler:
persistable
=
True
)
persistable
=
True
)
grad_var
=
_clone_var
(
grad_var
=
_clone_var
(
pserver_program
.
global_block
(),
pserver_program
.
global_block
(),
self
.
origin_program
.
global_block
().
vars
[
framework
.
grad_var_name
(
self
.
origin_program
.
global_block
().
vars
[
grad_var_name
(
self
.
table_name
)],
self
.
table_name
)],
persistable
=
False
)
persistable
=
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录