Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
82c61dbd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
82c61dbd
编写于
5月 07, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix testing
上级
0598a4b3
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
117 addition
and
101 deletion
+117
-101
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+1
-1
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+4
-2
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+1
-1
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+2
-0
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+83
-86
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+15
-8
paddle/fluid/operators/gen_nccl_id_op.cc
paddle/fluid/operators/gen_nccl_id_op.cc
+11
-3
未找到文件。
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
82c61dbd
...
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
...
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
// stub context
// stub context
SendProcessor
*
s
=
new
SendProcessor
(
ch
);
SendProcessor
*
s
=
new
SendProcessor
(
ch
);
s
->
Prepare
(
var_h
,
time_out
);
s
->
Prepare
(
var_h
,
time_out
);
s
->
response_call_back_
=
NULL
;
s
->
response_call_back_
=
nullptr
;
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/SendVariable"
,
req
,
&
cq_
);
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/SendVariable"
,
req
,
&
cq_
);
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
82c61dbd
...
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
...
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class
BaseProcessor
{
class
BaseProcessor
{
public:
public:
explicit
BaseProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
{
context_
=
NULL
;
}
explicit
BaseProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
{
context_
=
nullptr
;
}
virtual
~
BaseProcessor
()
{}
virtual
~
BaseProcessor
()
{}
...
@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
...
@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
::
grpc
::
GenericStub
stub_g_
;
::
grpc
::
GenericStub
stub_g_
;
::
grpc
::
ByteBuffer
reply_
;
::
grpc
::
ByteBuffer
reply_
;
RequestSendCallBack
response_call_back_
=
NULL
;
RequestSendCallBack
response_call_back_
=
nullptr
;
};
};
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
::
grpc
::
ByteBuffer
&
)
>
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
::
grpc
::
ByteBuffer
&
)
>
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
82c61dbd
...
@@ -261,8 +261,8 @@ void AsyncGRPCServer::ShutdownQueue() {
...
@@ -261,8 +261,8 @@ void AsyncGRPCServer::ShutdownQueue() {
// This URL explains why shutdown is complicate:
// This URL explains why shutdown is complicate:
void
AsyncGRPCServer
::
ShutDown
()
{
void
AsyncGRPCServer
::
ShutDown
()
{
is_shut_down_
=
true
;
is_shut_down_
=
true
;
ShutdownQueue
();
server_
->
Shutdown
();
server_
->
Shutdown
();
ShutdownQueue
();
}
}
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
()
{
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
()
{
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
82c61dbd
...
@@ -47,6 +47,8 @@ class AsyncGRPCServer final {
...
@@ -47,6 +47,8 @@ class AsyncGRPCServer final {
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
:
address_
(
address
),
sync_mode_
(
sync_mode
)
{}
:
address_
(
address
),
sync_mode_
(
sync_mode
)
{}
~
AsyncGRPCServer
()
{}
void
RunSyncUpdate
();
void
RunSyncUpdate
();
// functions to sync server barrier status.
// functions to sync server barrier status.
...
...
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
82c61dbd
...
@@ -53,14 +53,15 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -53,14 +53,15 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
1
);
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
1
);
}
else
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
}
else
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
// NOTE: sendrecv only support RAW type for NCCL_ID
// NOTE: sendrecv only support RAW type for NCCL_ID
VLOG
(
3
)
<<
"serilizing: setting var type nccl id"
;
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
2
);
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
2
);
}
}
if
(
!
out_name
.
empty
())
{
if
(
!
out_name
.
empty
())
{
e
.
WriteString
(
VarMsg
::
kOutVarnameFieldNumber
,
out_name
);
e
.
WriteString
(
VarMsg
::
kOutVarnameFieldNumber
,
out_name
);
}
}
switch
(
framework
::
ToVarType
(
var
->
Type
()
))
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
(
))
{
case
framework
::
proto
::
VarType_Type_LOD_TENSOR
:
{
// ===========================Tensor==================================
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
framework
::
ToDataType
(
tensor
.
type
()));
framework
::
ToDataType
(
tensor
.
type
()));
...
@@ -86,8 +87,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -86,8 +87,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
tensor
.
place
()));
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
auto
copy_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
...
@@ -107,8 +107,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -107,8 +107,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
}
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
payload_size
=
tensor
.
numel
()
*
framework
::
SizeOfType
(
tensor
.
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
break
;
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
case
framework
::
proto
::
VarType_Type_SELECTED_ROWS
:
{
// ===========================SELECTED
// ROWS==================================
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
e
.
WriteUint64
(
VarMsg
::
kDataTypeFieldNumber
,
...
@@ -122,10 +123,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -122,10 +123,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
auto
&
gpu_dev_ctx
=
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
);
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
auto
copy_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
payload
=
memory
::
Alloc
(
cpu
,
copy_size
);
memory
::
Copy
(
cpu
,
payload
,
memory
::
Copy
(
cpu
,
payload
,
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
->
place
()),
...
@@ -142,20 +141,18 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -142,20 +141,18 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
}
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
payload_size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
tensor
->
type
());
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload_size
);
}
break
;
}
else
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
case
framework
::
proto
::
VarType_Type_RAW
:
{
// ===========================NCCL ID==================================
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
NCCL_UNIQUE_ID_BYTES
);
NCCL_UNIQUE_ID_BYTES
);
ncclUniqueId
*
uid
=
var
->
GetMutable
<
ncclUniqueId
>
();
ncclUniqueId
*
uid
=
var
->
GetMutable
<
ncclUniqueId
>
();
e
.
WriteRawBytes
(
std
::
string
(
uid
->
internal
,
NCCL_UNIQUE_ID_BYTES
));
e
.
WriteRawBytes
(
std
::
string
(
uid
->
internal
,
NCCL_UNIQUE_ID_BYTES
));
}
break
;
}
else
{
default:
PADDLE_THROW
(
"Serialize does not support type: %s"
,
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
typeid
(
var
->
Type
()).
name
());
break
;
}
}
if
(
framework
::
ToVarType
(
var
->
Type
())
==
framework
::
proto
::
VarType_Type_RAW
)
{
if
(
var
->
IsType
<
ncclUniqueId
>
()
)
{
// for serialize NCCL_ID
// for serialize NCCL_ID
::
grpc
::
Slice
slices
(
e
.
size
());
::
grpc
::
Slice
slices
(
e
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
.
begin
()),
e
.
data
(),
e
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
.
begin
()),
e
.
data
(),
e
.
size
());
...
...
paddle/fluid/operators/detail/variable_response.cc
浏览文件 @
82c61dbd
...
@@ -371,19 +371,26 @@ int VariableResponse::Parse(Source* source) {
...
@@ -371,19 +371,26 @@ int VariableResponse::Parse(Source* source) {
meta_
.
type
()
==
sendrecv
::
NCCL_ID
)
&&
meta_
.
type
()
==
sendrecv
::
NCCL_ID
)
&&
meta_
.
varname
()
!=
""
,
meta_
.
varname
()
!=
""
,
"meta info should be got first!"
);
"meta info should be got first!"
);
int
length
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
return
tag
;
}
if
(
meta_
.
type
()
==
sendrecv
::
NCCL_ID
)
{
if
(
meta_
.
type
()
==
sendrecv
::
NCCL_ID
)
{
VLOG
(
3
)
<<
"parse nccl id request"
;
auto
*
var
=
scope_
->
FindVar
(
meta_
.
varname
());
auto
*
var
=
scope_
->
FindVar
(
meta_
.
varname
());
if
(
var
!=
nullptr
)
{
if
(
var
!=
nullptr
)
{
VLOG
(
3
)
<<
"parse nccl id: length "
<<
length
;
ncclUniqueId
*
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
ncclUniqueId
*
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
memcpy
(
id
->
internal
,
meta_
.
serialized
().
c_str
(),
if
(
!
ReadRaw
(
&
input
,
*
dev_ctx_
,
platform
::
CPUPlace
(),
id
->
internal
,
meta_
.
serialized
().
size
());
length
))
{
return
tag
;
}
}
// memcpy(id->internal, meta_.serialized().c_str(),
// meta_.serialized().size());
}
}
break
;
int
length
=
0
;
if
(
wt
!=
WIRETYPE_LENGTH_DELIMITED
||
!
ReadVarintSizeAsInt
(
&
input
,
&
length
))
{
return
tag
;
}
}
framework
::
DDim
dims
=
GetDims
(
meta_
.
dims
());
framework
::
DDim
dims
=
GetDims
(
meta_
.
dims
());
...
...
paddle/fluid/operators/gen_nccl_id_op.cc
浏览文件 @
82c61dbd
...
@@ -37,7 +37,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -37,7 +37,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
// put nccl id in CPUPlace
auto
&
dev_ctx
=
*
pool
.
Get
(
platform
::
CPUPlace
());
int
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
int
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
framework
::
Scope
&
local_scope
=
scope
.
NewScope
();
framework
::
Scope
&
local_scope
=
scope
.
NewScope
();
...
@@ -60,9 +61,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -60,9 +61,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoint_list"
);
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoint_list"
);
detail
::
RPCClient
client
;
detail
::
RPCClient
client
;
for
(
auto
&
ep
:
endpoint_list
)
{
for
(
auto
&
ep
:
endpoint_list
)
{
VLOG
(
3
)
<<
"sending nccl id to "
<<
ep
;
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
*
scope
,
"NCCLID"
);
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
*
scope
,
"NCCLID"
);
}
}
client
.
Wait
();
client
.
Wait
();
VLOG
(
3
)
<<
"sending completed..."
;
}
}
void
GetIdByServer
(
framework
::
Scope
*
scope
,
void
GetIdByServer
(
framework
::
Scope
*
scope
,
...
@@ -78,9 +81,14 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -78,9 +81,14 @@ class GenNCCLIdOp : public framework::OperatorBase {
server_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
server_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
RunSyncUpdate
,
rpc_service_
.
get
())));
&
detail
::
AsyncGRPCServer
::
RunSyncUpdate
,
rpc_service_
.
get
())));
rpc_service_
->
SetCond
(
0
);
VLOG
(
3
)
<<
"start getting nccl id from trainer 0..."
;
auto
recv
=
rpc_service_
->
Get
();
auto
recv
=
rpc_service_
->
Get
();
rpc_service_
->
ShutDown
();
VLOG
(
3
)
<<
"got nccl id and stop server..."
;
// rpc_service_->SetCond(1);
// rpc_service_->ShutDown();
rpc_service
->
Push
(
LISTEN_TERMINATE_MESSAGE
);
VLOG
(
3
)
<<
"rpc server stopped"
;
// TODO(wuyi): reinit nccl communicators
// TODO(wuyi): reinit nccl communicators
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录