Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
61343fbf
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
61343fbf
编写于
5月 10, 2018
作者:
W
Wu Yi
提交者:
GitHub
5月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10531 from typhoonzero/refine_grpc_serde_code
Refine serde code
上级
6d371e45
796a448c
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
172 addition
and
178 deletion
+172
-178
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+105
-113
paddle/fluid/operators/detail/serde_test.cc
paddle/fluid/operators/detail/serde_test.cc
+3
-3
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+20
-20
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+44
-42
未找到文件。
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
61343fbf
...
@@ -29,60 +29,26 @@ namespace paddle {
...
@@ -29,60 +29,26 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
using
VarMsg
=
sendrecv
::
VariableMessage
;
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
using
VarMsg
=
sendrecv
::
VariableMessage
;
// When using GPU, need to free the copied CPU buffer
// when the ByteBuffer destroies
// TODO(typhoonzero): add unref here, if we have dependent
// parallelism execution, need to know when to free the tensor.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
void
GetTensorPayload
(
framework
::
Variable
*
var
,
void
*
buf
=
buffer
.
get
();
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
)
{
void
*
payload
=
nullptr
;
size_t
payload_size
;
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if
(
platform
::
ShouldSendProfileState
())
{
e
.
WriteBool
(
VarMsg
::
kProfileFieldNumber
,
platform
::
IsProfileEnabled
());
}
e
.
WriteString
(
VarMsg
::
kVarnameFieldNumber
,
name
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
0
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
1
);
}
if
(
!
out_name
.
empty
())
{
e
.
WriteString
(
VarMsg
::
kOutVarnameFieldNumber
,
out_name
);
}
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
case
framework
::
proto
::
VarType_Type_LOD_TENSOR
:
{
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
// FIXME(wuyi): data types in send_recv.proto is copied from
framework
::
ToDataType
(
tensor
.
type
()));
// framework.proto
request
->
set_data_type
(
static_cast
<
VarMsg
::
Type
>
(
framework
::
ToDataType
(
tensor
.
type
())));
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
e
.
WriteUint64
(
VarMsg
::
kDimsFieldNumber
,
dim
);
request
->
add_dims
(
dim
);
}
}
auto
lod
=
tensor
.
lod
();
// std::vector<Vector<size_t>>
const
framework
::
LoD
lod
=
tensor
.
lod
();
if
(
lod
.
size
()
>
0
)
{
if
(
lod
.
size
()
>
0
)
{
e
.
WriteUint64
(
VarMsg
::
kLodLevelFieldNumber
,
lod
.
size
());
request
->
set_lod_level
(
lod
.
size
());
for
(
auto
&
each
:
lod
)
{
for
(
auto
&
each
:
lod
)
{
e
.
WriteVarlengthBeginning
(
VarMsg
::
kLodFieldNumber
,
VarMsg
::
LodData
*
lod_inner
=
request
->
add_lod
();
2
+
// tag + varintlength of submessage
1
+
// kLodDataFieldNumber
each
.
size
());
// auto copied from GPU
for
(
auto
&
d
:
each
)
{
for
(
auto
&
d
:
each
)
{
e
.
WriteUint64
(
VarMsg
::
LodData
::
kLodDataFieldNumber
,
d
);
lod_inner
->
add_lod_data
(
d
);
}
}
}
}
}
}
...
@@ -90,68 +56,100 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -90,68 +56,100 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
*
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
memory
::
Copy
(
cpu
,
*
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
copy_size
,
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
gpu_dev_ctx
.
stream
());
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
ctx
.
Wait
();
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
#endif
#endif
}
else
{
}
else
{
payload
=
tensor
.
data
<
void
>
();
*
payload
=
tensor
.
data
<
void
>
();
}
}
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
*
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
}
break
;
case
framework
::
proto
::
VarType_Type_SELECTED_ROWS
:
{
void
GetSelectedRowsPayload
(
framework
::
Variable
*
var
,
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
request
->
set_data_type
(
framework
::
ToDataType
(
slr
->
value
().
type
()));
static_cast
<
VarMsg
::
Type
>
(
framework
::
ToDataType
(
slr
->
value
().
type
())));
request
->
set_lod_level
(
0
);
request
->
set_slr_height
(
slr
->
height
());
for
(
auto
&
dim
:
framework
::
vectorize
(
slr
->
value
().
dims
()))
{
for
(
auto
&
dim
:
framework
::
vectorize
(
slr
->
value
().
dims
()))
{
e
.
WriteUint64
(
VarMsg
::
kDimsFieldNumber
,
dim
);
request
->
add_dims
(
dim
);
}
}
e
.
WriteUint64
(
VarMsg
::
kLodLevelFieldNumber
,
0
);
e
.
WriteUint64
(
VarMsg
::
kSlrHeightFieldNumber
,
slr
->
height
());
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
tensor
=
slr
->
mutable_value
();
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
auto
copy_size
=
*
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
memory
::
Copy
(
cpu
,
*
payload
,
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
())
,
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
()),
copy_size
,
copy_size
,
gpu_dev_ctx
.
stream
());
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
ctx
.
Wait
();
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
#endif
#endif
}
else
{
}
else
{
payload
=
slr
->
mutable_value
()
->
data
<
void
>
();
*
payload
=
slr
->
mutable_value
()
->
data
<
void
>
();
}
}
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
*
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
}
break
;
default:
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
VarMsg
request
;
void
*
payload
=
nullptr
;
size_t
payload_size
;
request
.
set_varname
(
name
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
request
.
set_profile
(
platform
::
IsProfileEnabled
());
if
(
!
out_name
.
empty
())
{
request
.
set_out_varname
(
out_name
);
}
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
request
.
set_type
(
::
sendrecv
::
LOD_TENSOR
);
GetTensorPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
request
.
set_type
(
::
sendrecv
::
SELECTED_ROWS
);
GetSelectedRowsPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
{
PADDLE_THROW
(
"Serialize does not support type: %s"
,
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
typeid
(
var
->
Type
()).
name
());
break
;
}
}
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
}
std
::
string
header
;
request
.
AppendToString
(
&
header
);
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
void
*
buf
=
buffer
.
get
();
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
e
.
WriteRawBytes
(
std
::
string
(
header
.
data
(),
header
.
size
()));
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
// steal reference of tensor data
// steal reference of tensor data
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
int
num_slices
=
2
;
// only SelectedRows have rows buffer
int
num_slices
=
2
;
// only SelectedRows have rows buffer
...
@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast
<
char
*>
(
payload
)),
static_cast
<
char
*>
(
payload
)),
::
grpc
::
Slice
::
STEAL_REF
);
::
grpc
::
Slice
::
STEAL_REF
);
if
(
framework
::
ToVarType
(
var
->
Type
())
==
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
// NOTE: rows is of type int64_t
size_t
rows_memory_size
=
size_t
rows_memory_size
=
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
...
@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
grpc_slice_new_with_user_data
(
grpc_slice_new_with_user_data
(
const_cast
<
void
*>
(
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
rows_memory_size
,
rows_memory_size
,
[](
void
*
backing
)
{},
[](
void
*
backing
)
{
// TODO(typhoonzero): add unref here, same as above.
},
const_cast
<
char
*>
(
const_cast
<
char
*>
(
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
::
grpc
::
Slice
::
STEAL_REF
);
::
grpc
::
Slice
::
STEAL_REF
);
...
...
paddle/fluid/operators/detail/serde_test.cc
浏览文件 @
61343fbf
...
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
...
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer
// serialize var to ByteBuffer
framework
::
Variable
var
;
framework
::
Variable
var
;
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
framework
::
make_ddim
({
4
,
8
,
4
,
2
}));
tensor
->
Resize
(
framework
::
make_ddim
({
512
,
8
,
4
,
2
}));
framework
::
LoD
lod
;
framework
::
LoD
lod
;
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
tensor
->
set_lod
(
lod
);
tensor
->
set_lod
(
lod
);
int
tensor_numel
=
4
*
8
*
4
*
2
;
int
tensor_numel
=
512
*
8
*
4
*
2
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
...
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
...
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
512
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
...
...
paddle/fluid/operators/detail/variable_response.cc
浏览文件 @
61343fbf
...
@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
...
@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
}
}
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
int
num_bytes
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
length
))
{
if
(
!
input
->
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
int
start_pos
=
input
->
CurrentPosition
();
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
while
(
input
->
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
uint64_t
v
;
if
(
!
input
->
ReadVarint64
(
&
v
))
{
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
return
tag
;
}
}
lod
->
push_back
(
v
);
lod
->
push_back
(
v
);
}
}
...
@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
...
@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
break
;
break
;
}
}
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
uint
64
_t
v
;
uint
32
_t
v
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
64
(
&
v
))
{
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
32
(
&
v
))
{
return
tag
;
return
tag
;
}
}
...
@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
...
@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
break
;
break
;
}
}
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
uint
64
_t
v
=
0
;
uint
32
_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
64
(
&
v
))
{
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
32
(
&
v
))
{
return
tag
;
return
tag
;
}
}
...
@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
...
@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
// packed
// packed
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
int
num_bytes
=
0
;
if
(
!
input
.
ReadVarintSizeAsInt
(
&
length
))
{
if
(
!
input
.
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
int
start_pos
=
input
.
CurrentPosition
();
while
(
input
.
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
return
tag
;
...
@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
...
@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
}
}
break
;
break
;
}
}
return
tag
;
return
tag
;
}
}
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
...
@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
...
@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
meta_
.
varname
()
!=
""
,
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
"meta info should be got first!"
);
int
length
=
0
;
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
...
@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
...
@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
"lod info should be got first!"
);
"lod info should be got first!"
);
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
}
}
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
...
@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
...
@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
meta_
.
varname
()
!=
""
,
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
"meta info should be got first!"
);
int
length
=
0
;
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
length
))
{
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
61343fbf
...
@@ -18,7 +18,9 @@ import math
...
@@ -18,7 +18,9 @@ import math
import
distributed_splitter
as
splitter
import
distributed_splitter
as
splitter
from
..
import
core
from
..
import
core
from
..framework
import
Program
,
default_main_program
,
Variable
,
Parameter
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
\
Variable
,
Parameter
,
grad_var_name
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
...
@@ -244,7 +246,7 @@ class DistributeTranspiler:
...
@@ -244,7 +246,7 @@ class DistributeTranspiler:
]
]
grad_list
=
[
grad_list
=
[
grad
for
grad
in
grad_list
grad
for
grad
in
grad_list
if
grad
.
name
!=
framework
.
grad_var_name
(
self
.
table_name
)
if
grad
.
name
!=
grad_var_name
(
self
.
table_name
)
]
]
self
.
table_param_grad
=
[
self
.
table_param_grad
=
[
param_grad
for
param_grad
in
params_grads
param_grad
for
param_grad
in
params_grads
...
@@ -494,7 +496,7 @@ class DistributeTranspiler:
...
@@ -494,7 +496,7 @@ class DistributeTranspiler:
were split to several blocks.
were split to several blocks.
"""
"""
s_prog
=
Program
()
s_prog
=
Program
()
orig_s_prog
=
framework
.
default_startup_program
()
orig_s_prog
=
default_startup_program
()
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
def
_get_splited_name_and_shape
(
varname
):
def
_get_splited_name_and_shape
(
varname
):
...
@@ -619,7 +621,7 @@ class DistributeTranspiler:
...
@@ -619,7 +621,7 @@ class DistributeTranspiler:
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
# there should only be one table_name
all_ops
=
program
.
global_block
().
ops
all_ops
=
program
.
global_block
().
ops
table_grad_name
=
framework
.
grad_var_name
(
self
.
table_name
)
table_grad_name
=
grad_var_name
(
self
.
table_name
)
for
op
in
all_ops
:
for
op
in
all_ops
:
if
table_grad_name
in
op
.
output_arg_names
:
if
table_grad_name
in
op
.
output_arg_names
:
op_index
=
list
(
all_ops
).
index
(
op
)
op_index
=
list
(
all_ops
).
index
(
op
)
...
@@ -692,7 +694,7 @@ class DistributeTranspiler:
...
@@ -692,7 +694,7 @@ class DistributeTranspiler:
persistable
=
True
)
persistable
=
True
)
grad_var
=
_clone_var
(
grad_var
=
_clone_var
(
pserver_program
.
global_block
(),
pserver_program
.
global_block
(),
self
.
origin_program
.
global_block
().
vars
[
framework
.
grad_var_name
(
self
.
origin_program
.
global_block
().
vars
[
grad_var_name
(
self
.
table_name
)],
self
.
table_name
)],
persistable
=
False
)
persistable
=
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录