Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
299fb8a7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
299fb8a7
编写于
9月 18, 2020
作者:
C
chengmo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support ps profiler
上级
ac82baa8
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
81 addition
and
7 deletion
+81
-7
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
+4
-0
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+31
-0
paddle/fluid/operators/distributed/grpc/grpc_variable_response.cc
...luid/operators/distributed/grpc/grpc_variable_response.cc
+21
-5
paddle/fluid/operators/distributed/parameter_prefetch.cc
paddle/fluid/operators/distributed/parameter_prefetch.cc
+3
-1
paddle/fluid/operators/distributed/parameter_recv.cc
paddle/fluid/operators/distributed/parameter_recv.cc
+5
-1
paddle/fluid/operators/distributed/parameter_send.cc
paddle/fluid/operators/distributed/parameter_send.cc
+7
-0
paddle/fluid/operators/distributed/variable_response.cc
paddle/fluid/operators/distributed/variable_response.cc
+10
-0
未找到文件。
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
浏览文件 @
299fb8a7
...
@@ -40,6 +40,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -40,6 +40,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
table_name
)
{
const
std
::
string
&
table_name
)
{
platform
::
RecordRPCEvent
record_event
(
"serial"
);
platform
::
RecordRPCEvent
record_event
(
"serial"
);
platform
::
RecordEvent
record_event_grpc
(
"grpc::SerializeToByteBuffer"
,
platform
::
EventRole
::
kInnerOp
);
VarMsg
request
;
VarMsg
request
;
TensorPayload
*
payload
=
nullptr
;
TensorPayload
*
payload
=
nullptr
;
...
@@ -147,6 +149,8 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
...
@@ -147,6 +149,8 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const
framework
::
Scope
*
scope
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
,
int
*
trainer_id
)
{
framework
::
Variable
**
var
,
int
*
trainer_id
)
{
platform
::
RecordRPCEvent
record_event
(
"deserial"
);
platform
::
RecordRPCEvent
record_event
(
"deserial"
);
platform
::
RecordEvent
record_event_grpc
(
"grpc::DeserializeFromByteBuffer"
,
platform
::
EventRole
::
kInnerOp
);
operators
::
distributed
::
GRPCVariableResponse
resp
(
scope
,
&
ctx
);
operators
::
distributed
::
GRPCVariableResponse
resp
(
scope
,
&
ctx
);
PADDLE_ENFORCE
(
resp
.
Parse
(
msg
)
==
0
,
"parse bytebuffer to tensor error!"
);
PADDLE_ENFORCE
(
resp
.
Parse
(
msg
)
==
0
,
"parse bytebuffer to tensor error!"
);
*
var
=
resp
.
GetVar
();
*
var
=
resp
.
GetVar
();
...
...
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
299fb8a7
...
@@ -103,6 +103,7 @@ class RequestSend final : public RequestBase {
...
@@ -103,6 +103,7 @@ class RequestSend final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestSend::Process"
,
platform
::
EventRole
::
kInnerOp
);
std
::
string
varname
=
GetReqName
();
std
::
string
varname
=
GetReqName
();
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
scope
=
request_
->
GetMutableLocalScope
();
...
@@ -114,6 +115,7 @@ class RequestSend final : public RequestBase {
...
@@ -114,6 +115,7 @@ class RequestSend final : public RequestBase {
framework
::
Variable
*
outvar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestSend::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
@@ -139,6 +141,7 @@ class RequestGet final : public RequestBase {
...
@@ -139,6 +141,7 @@ class RequestGet final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestGet::Process"
,
platform
::
EventRole
::
kInnerOp
);
// proc request.
// proc request.
std
::
string
varname
=
request_
.
varname
();
std
::
string
varname
=
request_
.
varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
...
@@ -162,6 +165,7 @@ class RequestGet final : public RequestBase {
...
@@ -162,6 +165,7 @@ class RequestGet final : public RequestBase {
}
}
VLOG
(
1
)
<<
"after SerializeToByteBuffer"
;
VLOG
(
1
)
<<
"after SerializeToByteBuffer"
;
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestGet::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
@@ -189,6 +193,8 @@ class RequestGetNoBarrier final : public RequestBase {
...
@@ -189,6 +193,8 @@ class RequestGetNoBarrier final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestGetNoBarrier::Process"
,
platform
::
EventRole
::
kInnerOp
);
// proc request.
// proc request.
std
::
string
varname
=
request_
.
varname
();
std
::
string
varname
=
request_
.
varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
...
@@ -208,6 +214,8 @@ class RequestGetNoBarrier final : public RequestBase {
...
@@ -208,6 +214,8 @@ class RequestGetNoBarrier final : public RequestBase {
&
reply_
);
&
reply_
);
}
}
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestGetNoBarrier::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
@@ -237,6 +245,8 @@ class RequestGetMonomerVariable final : public RequestBase {
...
@@ -237,6 +245,8 @@ class RequestGetMonomerVariable final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestGetMonomerVariable::Process"
,
platform
::
EventRole
::
kInnerOp
);
// proc request.
// proc request.
std
::
string
varname
=
request_
.
varname
();
std
::
string
varname
=
request_
.
varname
();
...
@@ -254,6 +264,8 @@ class RequestGetMonomerVariable final : public RequestBase {
...
@@ -254,6 +264,8 @@ class RequestGetMonomerVariable final : public RequestBase {
SerializeToByteBuffer
(
varname
,
outvar
,
*
h
.
dev_ctx_
,
&
reply_
);
SerializeToByteBuffer
(
varname
,
outvar
,
*
h
.
dev_ctx_
,
&
reply_
);
}
}
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestGetMonomerVariable::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
@@ -284,6 +296,8 @@ class RequestGetMonomerBarrier final : public RequestBase {
...
@@ -284,6 +296,8 @@ class RequestGetMonomerBarrier final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestGetMonomerBarrier::Process"
,
platform
::
EventRole
::
kInnerOp
);
// proc request.
// proc request.
std
::
string
varname
=
request_
.
varname
();
std
::
string
varname
=
request_
.
varname
();
VLOG
(
4
)
<<
"RequestGetMonomerBarrier "
<<
varname
;
VLOG
(
4
)
<<
"RequestGetMonomerBarrier "
<<
varname
;
...
@@ -299,6 +313,8 @@ class RequestGetMonomerBarrier final : public RequestBase {
...
@@ -299,6 +313,8 @@ class RequestGetMonomerBarrier final : public RequestBase {
request_
.
trainer_id
());
request_
.
trainer_id
());
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestGetMonomerBarrier::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
@@ -330,6 +346,8 @@ class RequestPrefetch final : public RequestBase {
...
@@ -330,6 +346,8 @@ class RequestPrefetch final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestPrefetch::Process"
,
platform
::
EventRole
::
kInnerOp
);
// prefetch process...
// prefetch process...
std
::
string
in_var_name
=
request_
->
Varname
();
std
::
string
in_var_name
=
request_
->
Varname
();
std
::
string
out_var_name
=
request_
->
OutVarname
();
std
::
string
out_var_name
=
request_
->
OutVarname
();
...
@@ -350,6 +368,8 @@ class RequestPrefetch final : public RequestBase {
...
@@ -350,6 +368,8 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer
(
out_var_name
,
outvar
,
*
request_handler_
->
dev_ctx
(),
SerializeToByteBuffer
(
out_var_name
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
&
reply_
);
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestPrefetch::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
@@ -379,6 +399,8 @@ class RequestCheckpointNotify final : public RequestBase {
...
@@ -379,6 +399,8 @@ class RequestCheckpointNotify final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestCheckpointNotify::Process"
,
platform
::
EventRole
::
kInnerOp
);
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
scope
=
request_
->
GetMutableLocalScope
();
std
::
string
checkpoint_notify
=
request_
->
Varname
();
std
::
string
checkpoint_notify
=
request_
->
Varname
();
...
@@ -391,6 +413,8 @@ class RequestCheckpointNotify final : public RequestBase {
...
@@ -391,6 +413,8 @@ class RequestCheckpointNotify final : public RequestBase {
request_handler_
->
Handle
(
checkpoint_notify
,
scope
,
nullptr
,
nullptr
,
request_handler_
->
Handle
(
checkpoint_notify
,
scope
,
nullptr
,
nullptr
,
trainer_id
,
checkpoint_dir
);
trainer_id
,
checkpoint_dir
);
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestCheckpointNotify::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
@@ -417,6 +441,8 @@ class RequestNotify final : public RequestBase {
...
@@ -417,6 +441,8 @@ class RequestNotify final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestNotify::Process"
,
platform
::
EventRole
::
kInnerOp
);
std
::
string
varname
=
GetReqName
();
std
::
string
varname
=
GetReqName
();
VLOG
(
4
)
<<
"RequestNotify var_name:"
<<
varname
;
VLOG
(
4
)
<<
"RequestNotify var_name:"
<<
varname
;
...
@@ -426,6 +452,7 @@ class RequestNotify final : public RequestBase {
...
@@ -426,6 +452,7 @@ class RequestNotify final : public RequestBase {
framework
::
Variable
*
outvar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestNotify::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
@@ -456,6 +483,8 @@ class RequestSendAndRecv final : public RequestBase {
...
@@ -456,6 +483,8 @@ class RequestSendAndRecv final : public RequestBase {
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
void
Process
()
override
{
platform
::
PushEvent
(
"RequestSendAndRecv::Process"
,
platform
::
EventRole
::
kInnerOp
);
std
::
string
in_var_name
=
request_
->
Varname
();
std
::
string
in_var_name
=
request_
->
Varname
();
std
::
string
out_var_name
=
request_
->
OutVarname
();
std
::
string
out_var_name
=
request_
->
OutVarname
();
std
::
string
table_name
=
request_
->
TableName
();
std
::
string
table_name
=
request_
->
TableName
();
...
@@ -471,6 +500,8 @@ class RequestSendAndRecv final : public RequestBase {
...
@@ -471,6 +500,8 @@ class RequestSendAndRecv final : public RequestBase {
SerializeToByteBuffer
(
out_var_name
,
outvar
,
*
request_handler_
->
dev_ctx
(),
SerializeToByteBuffer
(
out_var_name
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
&
reply_
);
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
platform
::
PopEvent
(
"RequestSendAndRecv::Process"
,
platform
::
EventRole
::
kInnerOp
);
}
}
protected:
protected:
...
...
paddle/fluid/operators/distributed/grpc/grpc_variable_response.cc
浏览文件 @
299fb8a7
...
@@ -26,6 +26,10 @@ namespace paddle {
...
@@ -26,6 +26,10 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
distributed
{
namespace
distributed
{
static
std
::
mutex
grpc_profile_mutex
;
static
bool
grpc_profile_begin
=
false
;
static
bool
grpc_profile_end
=
false
;
enum
WireType
{
enum
WireType
{
WIRETYPE_VARINT
=
0
,
WIRETYPE_VARINT
=
0
,
WIRETYPE_LENGTH_DELIMITED
=
2
,
WIRETYPE_LENGTH_DELIMITED
=
2
,
...
@@ -283,13 +287,25 @@ int GRPCVariableResponse::Parse(Source* source) {
...
@@ -283,13 +287,25 @@ int GRPCVariableResponse::Parse(Source* source) {
}
}
if
(
profiling
==
platform
::
kEnableProfiler
&&
if
(
profiling
==
platform
::
kEnableProfiler
&&
!
platform
::
IsProfileEnabled
())
{
!
platform
::
IsProfileEnabled
())
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
if
(
grpc_profile_mutex
.
try_lock
())
{
if
(
!
grpc_profile_begin
&&
!
grpc_profile_end
)
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kAll
);
grpc_profile_begin
=
true
;
grpc_profile_mutex
.
unlock
();
}
}
}
else
if
(
profiling
==
platform
::
kDisableProfiler
&&
}
else
if
(
profiling
==
platform
::
kDisableProfiler
&&
platform
::
IsProfileEnabled
())
{
platform
::
IsProfileEnabled
())
{
platform
::
DisableProfiler
(
if
(
grpc_profile_mutex
.
try_lock
())
{
platform
::
EventSortingKey
::
kDefault
,
if
(
grpc_profile_begin
&&
!
grpc_profile_end
)
{
string
::
Sprintf
(
"%s_%lld"
,
FLAGS_rpc_server_profile_path
,
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kTotal
,
listener_id
));
string
::
Sprintf
(
"./%s_%s_profile.log"
,
getenv
(
"TRAINING_ROLE"
),
getenv
(
"PADDLE_PORT"
)));
grpc_profile_end
=
true
;
grpc_profile_mutex
.
unlock
();
}
}
}
}
break
;
break
;
}
}
...
...
paddle/fluid/operators/distributed/parameter_prefetch.cc
浏览文件 @
299fb8a7
...
@@ -26,11 +26,11 @@
...
@@ -26,11 +26,11 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -195,6 +195,8 @@ void prefetchs(const std::vector<std::string> &id_var_names,
...
@@ -195,6 +195,8 @@ void prefetchs(const std::vector<std::string> &id_var_names,
const
std
::
vector
<
std
::
string
>
&
endpoints
,
const
std
::
vector
<
std
::
string
>
&
endpoints
,
const
framework
::
ExecutionContext
&
context
,
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Scope
&
scope
)
{
const
framework
::
Scope
&
scope
)
{
platform
::
RecordEvent
record_event
(
"Distributed_lookup_table::prefetchs"
,
platform
::
EventRole
::
kInnerOp
);
auto
vec_dim_1
=
0
;
auto
vec_dim_1
=
0
;
auto
vec_dim_0
=
0
;
auto
vec_dim_0
=
0
;
framework
::
Variable
*
var
=
scope
.
FindVar
(
persistable_var_name
);
framework
::
Variable
*
var
=
scope
.
FindVar
(
persistable_var_name
);
...
...
paddle/fluid/operators/distributed/parameter_recv.cc
浏览文件 @
299fb8a7
...
@@ -24,12 +24,12 @@
...
@@ -24,12 +24,12 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -43,6 +43,8 @@ using DDim = framework::DDim;
...
@@ -43,6 +43,8 @@ using DDim = framework::DDim;
template
<
typename
T
>
template
<
typename
T
>
void
RecvSelectedRows
(
const
CommContext
&
rpc_ctx
,
void
RecvSelectedRows
(
const
CommContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
const
framework
::
Scope
&
scope
)
{
platform
::
RecordEvent
record_event
(
"ParameterRecv::RecvSelectedRows"
,
platform
::
EventRole
::
kInnerOp
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
cpu_place
=
platform
::
CPUPlace
();
auto
cpu_place
=
platform
::
CPUPlace
();
auto
&
cpu_ctx
=
*
pool
.
Get
(
cpu_place
);
auto
&
cpu_ctx
=
*
pool
.
Get
(
cpu_place
);
...
@@ -112,6 +114,8 @@ void RecvSelectedRows(const CommContext &rpc_ctx,
...
@@ -112,6 +114,8 @@ void RecvSelectedRows(const CommContext &rpc_ctx,
template
<
typename
T
>
template
<
typename
T
>
void
RecvLodTensor
(
const
CommContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
void
RecvLodTensor
(
const
CommContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
platform
::
RecordEvent
record_event
(
"ParameterRecv::RecvLodTensor"
,
platform
::
EventRole
::
kInnerOp
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
cpu_place
=
platform
::
CPUPlace
();
auto
cpu_place
=
platform
::
CPUPlace
();
auto
&
cpu_ctx
=
*
pool
.
Get
(
cpu_place
);
auto
&
cpu_ctx
=
*
pool
.
Get
(
cpu_place
);
...
...
paddle/fluid/operators/distributed/parameter_send.cc
浏览文件 @
299fb8a7
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -97,6 +98,8 @@ template <typename T>
...
@@ -97,6 +98,8 @@ template <typename T>
void
ParameterSend
<
T
>::
operator
()(
const
CommContext
&
rpc_ctx
,
void
ParameterSend
<
T
>::
operator
()(
const
CommContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
,
bool
sync
,
const
framework
::
Scope
&
scope
,
bool
sync
,
int
multi_parts
)
{
int
multi_parts
)
{
platform
::
RecordEvent
record_event
(
"ParameterSend::operator"
,
platform
::
EventRole
::
kInnerOp
);
if
(
rpc_ctx
.
var_name
==
STEP_COUNTER
)
{
if
(
rpc_ctx
.
var_name
==
STEP_COUNTER
)
{
SendByNotifyRPC
(
rpc_ctx
,
scope
);
SendByNotifyRPC
(
rpc_ctx
,
scope
);
return
;
return
;
...
@@ -114,6 +117,8 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
...
@@ -114,6 +117,8 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
auto
*
send_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
auto
*
send_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
if
(
send_var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
send_var
->
IsType
<
framework
::
LoDTensor
>
())
{
platform
::
RecordEvent
record_event
(
"ParameterSend::LoDTensor"
,
platform
::
EventRole
::
kInnerOp
);
size_t
out_num
=
rpc_ctx
.
splited_varnames
.
size
();
size_t
out_num
=
rpc_ctx
.
splited_varnames
.
size
();
if
(
out_num
>
1
)
{
if
(
out_num
>
1
)
{
auto
&
send_tensor
=
send_var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
send_tensor
=
send_var
->
Get
<
framework
::
LoDTensor
>
();
...
@@ -162,6 +167,8 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
...
@@ -162,6 +167,8 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
}
}
}
}
}
else
if
(
send_var
->
IsType
<
framework
::
SelectedRows
>
())
{
}
else
if
(
send_var
->
IsType
<
framework
::
SelectedRows
>
())
{
platform
::
RecordEvent
record_event
(
"ParameterSend::SelectedRows"
,
platform
::
EventRole
::
kInnerOp
);
auto
&
send_slr
=
send_var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
send_slr
=
send_var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
send_rows
=
send_slr
.
rows
();
auto
&
send_rows
=
send_slr
.
rows
();
...
...
paddle/fluid/operators/distributed/variable_response.cc
浏览文件 @
299fb8a7
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include <vector>
#include <vector>
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string
(
rpc_server_profile_path
,
"./profile_ps"
,
DEFINE_string
(
rpc_server_profile_path
,
"./profile_ps"
,
"the profile log file path"
);
"the profile log file path"
);
...
@@ -27,6 +28,8 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
...
@@ -27,6 +28,8 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
const
platform
::
DeviceContext
&
dev_ctx
,
const
platform
::
DeviceContext
&
dev_ctx
,
platform
::
Place
place
,
void
*
dest
,
platform
::
Place
place
,
void
*
dest
,
int64_t
size
)
{
int64_t
size
)
{
platform
::
RecordEvent
record_event
(
"VariableResponse::ReadRaw"
,
platform
::
EventRole
::
kInnerOp
);
const
void
*
data
=
NULL
;
const
void
*
data
=
NULL
;
int
size_to_write
=
0
;
int
size_to_write
=
0
;
int64_t
length
=
size
;
int64_t
length
=
size
;
...
@@ -123,6 +126,8 @@ bool VariableResponse::CopyLodTensorData(
...
@@ -123,6 +126,8 @@ bool VariableResponse::CopyLodTensorData(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
DDim
&
dims
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
DDim
&
dims
,
int
length
)
{
int
length
)
{
platform
::
RecordEvent
record_event
(
"VariableResponse::CopyLodTensorData"
,
platform
::
EventRole
::
kInnerOp
);
auto
server_var
=
GetVar
();
auto
server_var
=
GetVar
();
if
(
!
server_var
)
{
if
(
!
server_var
)
{
LOG
(
ERROR
)
<<
"recved var should not on current server: "
LOG
(
ERROR
)
<<
"recved var should not on current server: "
...
@@ -164,6 +169,9 @@ bool VariableResponse::CopySelectRowsTensorData(
...
@@ -164,6 +169,9 @@ bool VariableResponse::CopySelectRowsTensorData(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
DDim
&
dims
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
DDim
&
dims
,
int
length
)
{
int
length
)
{
platform
::
RecordEvent
record_event
(
"VariableResponse::CopySelectRowsTensorData"
,
platform
::
EventRole
::
kInnerOp
);
auto
*
slr
=
GetVar
()
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
GetVar
()
->
GetMutable
<
framework
::
SelectedRows
>
();
slr
->
set_height
(
meta_
.
slr_height
());
slr
->
set_height
(
meta_
.
slr_height
());
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
tensor
=
slr
->
mutable_value
();
...
@@ -186,6 +194,8 @@ bool VariableResponse::CopySelectRowsTensorData(
...
@@ -186,6 +194,8 @@ bool VariableResponse::CopySelectRowsTensorData(
bool
VariableResponse
::
CopySelectRowsData
(
bool
VariableResponse
::
CopySelectRowsData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
int
length
)
{
const
platform
::
DeviceContext
&
ctx
,
int
length
)
{
platform
::
RecordEvent
record_event
(
"VariableResponse::CopySelectRowsData"
,
platform
::
EventRole
::
kInnerOp
);
auto
*
slr
=
GetVar
()
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
GetVar
()
->
GetMutable
<
framework
::
SelectedRows
>
();
slr
->
mutable_rows
()
->
clear
();
slr
->
mutable_rows
()
->
clear
();
slr
->
mutable_rows
()
->
resize
(
length
/
sizeof
(
int64_t
));
// int64
slr
->
mutable_rows
()
->
resize
(
length
/
sizeof
(
int64_t
));
// int64
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录