Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a2de156d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a2de156d
编写于
5月 09, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine serde code
上级
9a98a572
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
193 addition
and
178 deletion
+193
-178
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+126
-113
paddle/fluid/operators/detail/serde_test.cc
paddle/fluid/operators/detail/serde_test.cc
+3
-3
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+20
-20
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+44
-42
未找到文件。
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
a2de156d
...
@@ -29,60 +29,47 @@ namespace paddle {
...
@@ -29,60 +29,47 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
using
VarMsg
=
sendrecv
::
VariableMessage
;
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
VarMsg
::
Type
DataTypeToEnum
(
std
::
type_index
type
)
{
const
std
::
string
&
out_name
)
{
if
(
typeid
(
platform
::
float16
).
hash_code
()
==
type
.
hash_code
())
{
using
VarMsg
=
sendrecv
::
VariableMessage
;
return
VarMsg
::
FP16
;
// When using GPU, need to free the copied CPU buffer
}
else
if
(
typeid
(
const
float
).
hash_code
()
==
type
.
hash_code
())
{
// when the ByteBuffer destroies
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
// TODO(typhoonzero): add unref here, if we have dependent
// One fix to this is to replace float with const float because
// parallelism execution, need to know when to free the tensor.
// typeid(T) == typeid(const T)
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
// http://en.cppreference.com/w/cpp/language/typeid
return
VarMsg
::
FP32
;
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
}
else
if
(
typeid
(
const
double
).
hash_code
()
==
type
.
hash_code
())
{
void
*
buf
=
buffer
.
get
();
return
VarMsg
::
FP64
;
}
else
if
(
typeid
(
const
int
).
hash_code
()
==
type
.
hash_code
())
{
void
*
payload
=
nullptr
;
return
VarMsg
::
INT32
;
size_t
payload_size
;
}
else
if
(
typeid
(
const
int64_t
).
hash_code
()
==
type
.
hash_code
())
{
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
return
VarMsg
::
INT64
;
// Note: normally the profiler is enabled in 1 trainer, hence only
}
else
if
(
typeid
(
const
bool
).
hash_code
()
==
type
.
hash_code
())
{
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
return
VarMsg
::
BOOL
;
// servers the trainer's profiling state so that PS can follow the
}
else
{
// trainer.
PADDLE_THROW
(
"Not supported"
);
if
(
platform
::
ShouldSendProfileState
())
{
e
.
WriteBool
(
VarMsg
::
kProfileFieldNumber
,
platform
::
IsProfileEnabled
());
}
e
.
WriteString
(
VarMsg
::
kVarnameFieldNumber
,
name
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
0
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
1
);
}
}
}
if
(
!
out_name
.
empty
())
{
void
GetTensorPayload
(
framework
::
Variable
*
var
,
e
.
WriteString
(
VarMsg
::
kOutVarnameFieldNumber
,
out_name
);
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
}
void
**
payload
,
size_t
*
payload_size
)
{
switch
(
framework
::
ToVarType
(
var
->
Type
()))
{
case
framework
::
proto
::
VarType_Type_LOD_TENSOR
:
{
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
// FIXME(wuyi): data types in send_recv.proto is not synced with
framework
::
ToDataType
(
tensor
.
type
()));
// framework.proto
request
->
set_data_type
(
DataTypeToEnum
(
tensor
.
type
()));
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
e
.
WriteUint64
(
VarMsg
::
kDimsFieldNumber
,
dim
);
request
->
add_dims
(
dim
);
}
}
auto
lod
=
tensor
.
lod
();
// std::vector<Vector<size_t>>
const
framework
::
LoD
lod
=
tensor
.
lod
();
if
(
lod
.
size
()
>
0
)
{
if
(
lod
.
size
()
>
0
)
{
e
.
WriteUint64
(
VarMsg
::
kLodLevelFieldNumber
,
lod
.
size
());
request
->
set_lod_level
(
lod
.
size
());
for
(
auto
&
each
:
lod
)
{
for
(
auto
&
each
:
lod
)
{
e
.
WriteVarlengthBeginning
(
VarMsg
::
kLodFieldNumber
,
VarMsg
::
LodData
*
lod_inner
=
request
->
add_lod
();
2
+
// tag + varintlength of submessage
1
+
// kLodDataFieldNumber
each
.
size
());
// auto copied from GPU
for
(
auto
&
d
:
each
)
{
for
(
auto
&
d
:
each
)
{
e
.
WriteUint64
(
VarMsg
::
LodData
::
kLodDataFieldNumber
,
d
);
lod_inner
->
add_lod_data
(
d
);
}
}
}
}
}
}
...
@@ -90,68 +77,100 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -90,68 +77,100 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
*
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
memory
::
Copy
(
cpu
,
*
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
copy_size
,
reinterpret_cast
<
const
void
*>
(
tensor
.
data
<
void
>
()),
gpu_dev_ctx
.
stream
());
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
ctx
.
Wait
();
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
#endif
#endif
}
else
{
}
else
{
payload
=
tensor
.
data
<
void
>
();
*
payload
=
tensor
.
data
<
void
>
();
}
}
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
*
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
}
break
;
case
framework
::
proto
::
VarType_Type_SELECTED_ROWS
:
{
void
GetSelectedRowsPayload
(
framework
::
Variable
*
var
,
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
void
**
payload
,
size_t
*
payload_size
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
request
->
set_data_type
(
DataTypeToEnum
(
slr
->
value
().
type
()));
framework
::
ToDataType
(
slr
->
value
().
type
()));
request
->
set_lod_level
(
0
);
request
->
set_slr_height
(
slr
->
height
());
for
(
auto
&
dim
:
framework
::
vectorize
(
slr
->
value
().
dims
()))
{
for
(
auto
&
dim
:
framework
::
vectorize
(
slr
->
value
().
dims
()))
{
e
.
WriteUint64
(
VarMsg
::
kDimsFieldNumber
,
dim
);
request
->
add_dims
(
dim
);
}
}
e
.
WriteUint64
(
VarMsg
::
kLodLevelFieldNumber
,
0
);
e
.
WriteUint64
(
VarMsg
::
kSlrHeightFieldNumber
,
slr
->
height
());
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
tensor
=
slr
->
mutable_value
();
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
auto
copy_size
=
*
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
memory
::
Copy
(
cpu
,
*
payload
,
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
())
,
reinterpret_cast
<
const
void
*>
(
tensor
->
data
<
void
>
()),
copy_size
,
copy_size
,
gpu_dev_ctx
.
stream
());
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
ctx
.
Wait
();
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
#endif
#endif
}
else
{
}
else
{
payload
=
slr
->
mutable_value
()
->
data
<
void
>
();
*
payload
=
slr
->
mutable_value
()
->
data
<
void
>
();
}
}
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
*
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
}
break
;
default:
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
VarMsg
request
;
void
*
payload
=
nullptr
;
size_t
payload_size
;
request
.
set_varname
(
name
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
request
.
set_profile
(
platform
::
IsProfileEnabled
());
if
(
!
out_name
.
empty
())
{
request
.
set_out_varname
(
out_name
);
}
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
request
.
set_type
(
::
sendrecv
::
LOD_TENSOR
);
GetTensorPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
request
.
set_type
(
::
sendrecv
::
SELECTED_ROWS
);
GetSelectedRowsPayload
(
var
,
ctx
,
&
request
,
&
payload
,
&
payload_size
);
}
else
{
PADDLE_THROW
(
"Serialize does not support type: %s"
,
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
typeid
(
var
->
Type
()).
name
());
break
;
}
}
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback
=
[](
void
*
backing
)
{
platform
::
CPUPlace
cpu
;
memory
::
Free
(
cpu
,
backing
);
};
}
std
::
string
header
;
request
.
AppendToString
(
&
header
);
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
void
*
buf
=
buffer
.
get
();
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
),
1024
);
e
.
WriteRawBytes
(
std
::
string
(
header
.
data
(),
header
.
size
()));
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
// steal reference of tensor data
// steal reference of tensor data
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
::
grpc
::
Slice
slices
[
4
];
// metadata, tensor, rows meta, rows
int
num_slices
=
2
;
// only SelectedRows have rows buffer
int
num_slices
=
2
;
// only SelectedRows have rows buffer
...
@@ -162,12 +181,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -162,12 +181,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast
<
char
*>
(
payload
)),
static_cast
<
char
*>
(
payload
)),
::
grpc
::
Slice
::
STEAL_REF
);
::
grpc
::
Slice
::
STEAL_REF
);
if
(
framework
::
ToVarType
(
var
->
Type
())
==
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
// NOTE: rows is of type int64_t
size_t
rows_memory_size
=
size_t
rows_memory_size
=
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
...
@@ -178,10 +194,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -178,10 +194,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
grpc_slice_new_with_user_data
(
grpc_slice_new_with_user_data
(
const_cast
<
void
*>
(
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
reinterpret_cast
<
const
void
*>
(
slr
->
rows
().
data
())),
rows_memory_size
,
rows_memory_size
,
[](
void
*
backing
)
{},
[](
void
*
backing
)
{
// TODO(typhoonzero): add unref here, same as above.
},
const_cast
<
char
*>
(
const_cast
<
char
*>
(
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
reinterpret_cast
<
const
char
*>
(
slr
->
rows
().
data
()))),
::
grpc
::
Slice
::
STEAL_REF
);
::
grpc
::
Slice
::
STEAL_REF
);
...
...
paddle/fluid/operators/detail/serde_test.cc
浏览文件 @
a2de156d
...
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
...
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer
// serialize var to ByteBuffer
framework
::
Variable
var
;
framework
::
Variable
var
;
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
var
.
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
framework
::
make_ddim
({
4
,
8
,
4
,
2
}));
tensor
->
Resize
(
framework
::
make_ddim
({
512
,
8
,
4
,
2
}));
framework
::
LoD
lod
;
framework
::
LoD
lod
;
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
lod
.
push_back
(
framework
::
Vector
<
size_t
>
({
1
,
3
,
8
}));
tensor
->
set_lod
(
lod
);
tensor
->
set_lod
(
lod
);
int
tensor_numel
=
4
*
8
*
4
*
2
;
int
tensor_numel
=
512
*
8
*
4
*
2
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
tensor
->
mutable_data
<
float
>
(
place
);
...
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
...
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_TRUE
(
varmsg
.
ParseFromString
(
tmp
));
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
varname
(),
"myvar"
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
type
(),
0
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
0
],
512
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
1
],
8
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
2
],
4
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
EXPECT_EQ
(
varmsg
.
dims
()[
3
],
2
);
...
...
paddle/fluid/operators/detail/variable_response.cc
浏览文件 @
a2de156d
...
@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
...
@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
}
}
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
int
num_bytes
=
0
;
if
(
!
input
->
ReadVarintSizeAsInt
(
&
length
))
{
if
(
!
input
->
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
int
start_pos
=
input
->
CurrentPosition
();
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
while
(
input
->
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
uint64_t
v
;
if
(
!
input
->
ReadVarint64
(
&
v
))
{
if
(
!
input
->
ReadVarint64
(
&
v
))
{
return
false
;
return
tag
;
}
}
lod
->
push_back
(
v
);
lod
->
push_back
(
v
);
}
}
...
@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
...
@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
break
;
break
;
}
}
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kTypeFieldNumber
:
{
uint
64
_t
v
;
uint
32
_t
v
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
64
(
&
v
))
{
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
32
(
&
v
))
{
return
tag
;
return
tag
;
}
}
...
@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
...
@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
break
;
break
;
}
}
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kDataTypeFieldNumber
:
{
uint
64
_t
v
=
0
;
uint
32
_t
v
=
0
;
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
64
(
&
v
))
{
if
((
wt
!=
WIRETYPE_VARINT
)
||
!
input
.
ReadVarint
32
(
&
v
))
{
return
tag
;
return
tag
;
}
}
...
@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
...
@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
// packed
// packed
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
if
(
wt
==
WIRETYPE_LENGTH_DELIMITED
)
{
int
length
=
0
;
int
num_bytes
=
0
;
if
(
!
input
.
ReadVarintSizeAsInt
(
&
length
))
{
if
(
!
input
.
ReadVarintSizeAsInt
(
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
for
(
int
i
=
0
;
i
<
length
;
i
++
)
{
int
start_pos
=
input
.
CurrentPosition
();
while
(
input
.
CurrentPosition
()
-
start_pos
<
num_bytes
)
{
uint64_t
v
;
uint64_t
v
;
if
(
!
input
.
ReadVarint64
(
&
v
))
{
if
(
!
input
.
ReadVarint64
(
&
v
))
{
return
tag
;
return
tag
;
...
@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
...
@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
}
}
break
;
break
;
}
}
return
tag
;
return
tag
;
}
}
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
case
sendrecv
::
VariableMessage
::
kLodLevelFieldNumber
:
{
...
@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
...
@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
meta_
.
varname
()
!=
""
,
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
"meta info should be got first!"
);
int
length
=
0
;
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
...
@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
...
@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
if
(
meta_
.
type
()
==
sendrecv
::
LOD_TENSOR
)
{
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
PADDLE_ENFORCE
(
meta_
.
lod_size
()
>=
0
,
"lod info should be got first!"
);
"lod info should be got first!"
);
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
if
(
!
CopyLodTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
}
}
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
meta_
.
type
()
==
sendrecv
::
SELECTED_ROWS
)
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
length
))
{
if
(
!
CopySelectRowsTensorData
(
&
input
,
*
dev_ctx_
,
dims
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
...
@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
...
@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
meta_
.
varname
()
!=
""
,
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
"meta info should be got first!"
);
int
length
=
0
;
int
num_bytes
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
!
ReadVarintSizeAsInt
(
&
input
,
&
num_bytes
))
{
return
tag
;
return
tag
;
}
}
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
length
))
{
if
(
!
CopySelectRowsData
(
&
input
,
*
dev_ctx_
,
num_bytes
))
{
return
tag
;
return
tag
;
}
}
break
;
break
;
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
a2de156d
...
@@ -18,7 +18,9 @@ import math
...
@@ -18,7 +18,9 @@ import math
import
distributed_splitter
as
splitter
import
distributed_splitter
as
splitter
from
..
import
core
from
..
import
core
from
..framework
import
Program
,
default_main_program
,
Variable
,
Parameter
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
\
Variable
,
Parameter
,
grad_var_name
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_TYPE
=
"lookup_table"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
LOOKUP_TABLE_GRAD_TYPE
=
"lookup_table_grad"
...
@@ -244,7 +246,7 @@ class DistributeTranspiler:
...
@@ -244,7 +246,7 @@ class DistributeTranspiler:
]
]
grad_list
=
[
grad_list
=
[
grad
for
grad
in
grad_list
grad
for
grad
in
grad_list
if
grad
.
name
!=
framework
.
grad_var_name
(
self
.
table_name
)
if
grad
.
name
!=
grad_var_name
(
self
.
table_name
)
]
]
self
.
table_param_grad
=
[
self
.
table_param_grad
=
[
param_grad
for
param_grad
in
params_grads
param_grad
for
param_grad
in
params_grads
...
@@ -494,7 +496,7 @@ class DistributeTranspiler:
...
@@ -494,7 +496,7 @@ class DistributeTranspiler:
were split to several blocks.
were split to several blocks.
"""
"""
s_prog
=
Program
()
s_prog
=
Program
()
orig_s_prog
=
framework
.
default_startup_program
()
orig_s_prog
=
default_startup_program
()
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
def
_get_splited_name_and_shape
(
varname
):
def
_get_splited_name_and_shape
(
varname
):
...
@@ -619,7 +621,7 @@ class DistributeTranspiler:
...
@@ -619,7 +621,7 @@ class DistributeTranspiler:
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
# there should only be one table_name
all_ops
=
program
.
global_block
().
ops
all_ops
=
program
.
global_block
().
ops
table_grad_name
=
framework
.
grad_var_name
(
self
.
table_name
)
table_grad_name
=
grad_var_name
(
self
.
table_name
)
for
op
in
all_ops
:
for
op
in
all_ops
:
if
table_grad_name
in
op
.
output_arg_names
:
if
table_grad_name
in
op
.
output_arg_names
:
op_index
=
list
(
all_ops
).
index
(
op
)
op_index
=
list
(
all_ops
).
index
(
op
)
...
@@ -692,7 +694,7 @@ class DistributeTranspiler:
...
@@ -692,7 +694,7 @@ class DistributeTranspiler:
persistable
=
True
)
persistable
=
True
)
grad_var
=
_clone_var
(
grad_var
=
_clone_var
(
pserver_program
.
global_block
(),
pserver_program
.
global_block
(),
self
.
origin_program
.
global_block
().
vars
[
framework
.
grad_var_name
(
self
.
origin_program
.
global_block
().
vars
[
grad_var_name
(
self
.
table_name
)],
self
.
table_name
)],
persistable
=
False
)
persistable
=
False
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录