Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6b3e1a68
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6b3e1a68
编写于
7月 10, 2020
作者:
Z
ZPaC
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add worker proxy.
上级
d1e7b977
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
311 addition
and
0 deletion
+311
-0
mindspore/ccsrc/parallel/ps/worker_proxy.h
mindspore/ccsrc/parallel/ps/worker_proxy.h
+311
-0
未找到文件。
mindspore/ccsrc/parallel/ps/worker_proxy.h
0 → 100644
浏览文件 @
6b3e1a68
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_
#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_
#include <unordered_map>
#include <algorithm>
#include <utility>
#include <memory>
#include <vector>
#include "ps/ps.h"
#include "parallel/ps/util.h"
namespace
mindspore
{
namespace
parallel
{
namespace
ps
{
template
<
typename
T
>
class
WorkerProxy
:
public
::
ps
::
KVWorker
<
T
>
{
public:
using
Worker
=
::
ps
::
KVWorker
<
T
>
;
using
Callback
=
std
::
function
<
void
()
>
;
using
SlicedKVs
=
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
;
using
Slicer
=
std
::
function
<
void
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
ranges
,
SlicedKVs
*
sliced
)
>
;
using
::
ps
::
SimpleApp
::
obj_
;
explicit
WorkerProxy
(
int
app_id
,
int
customer_id
,
int
lookup_customer_id
)
:
Worker
(
app_id
,
customer_id
)
{
using
_1
=
std
::
placeholders
::
_1
;
using
_2
=
std
::
placeholders
::
_2
;
using
_3
=
std
::
placeholders
::
_3
;
lookup_customer_
=
std
::
unique_ptr
<::
ps
::
Customer
>
(
new
::
ps
::
Customer
(
app_id
,
lookup_customer_id
,
std
::
bind
(
&
WorkerProxy
<
T
>::
ProcessLookupResult
,
this
,
_1
)));
lookup_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
LookupIdSlicer
,
this
,
_1
,
_2
,
_3
);
init_embedding_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
EmbeddingTableInitSlicer
,
this
,
_1
,
_2
,
_3
);
push_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
PushSlicer
,
this
,
_1
,
_2
,
_3
);
broadcast_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
BroadcastSlicer
,
this
,
_1
,
_2
,
_3
);
}
~
WorkerProxy
()
override
=
default
;
void
AddEmbeddingTable
(
const
::
ps
::
Key
&
key
,
const
size_t
&
row_count
);
void
EmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
const
::
ps
::
SArray
<
int
>
&
lens
,
::
ps
::
SArray
<
T
>
*
outs
,
int
cmd
=
0
,
const
Callback
&
cb
=
nullptr
,
int
priority
=
0
);
int
InitEmbeddingTable
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
vals
,
const
::
ps
::
SArray
<
int
>
&
lens
=
{},
const
Callback
&
cb
=
nullptr
,
int
priority
=
0
);
void
PushData
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
vals
,
const
::
ps
::
SArray
<
int
>
&
lens
=
{},
int
cmd
=
0
,
int
priority
=
0
);
private:
template
<
typename
C
>
int
AddLookupCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
C
*
vals
,
int
cmd
,
const
Callback
&
cb
);
void
LookupIdSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
EmbeddingTableInitSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
PushSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
BroadcastSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
ProcessLookupResult
(
const
::
ps
::
Message
&
msg
);
void
Send
(
::
ps
::
Customer
*
customer
,
int
timestamp
,
bool
push
,
bool
pull
,
int
cmd
,
const
::
ps
::
KVPairs
<
T
>
&
kvs
,
const
Slicer
&
slicer
);
std
::
unique_ptr
<::
ps
::
Customer
>
lookup_customer_
;
std
::
unordered_map
<::
ps
::
Key
,
std
::
shared_ptr
<
std
::
vector
<::
ps
::
Range
>>>
embedding_table_ranges_
;
std
::
unordered_map
<
int
,
std
::
vector
<::
ps
::
KVPairs
<
T
>>>
lookup_results_
;
std
::
mutex
mutex_
;
Slicer
lookup_slicer_
;
Slicer
init_embedding_slicer_
;
Slicer
push_slicer_
;
Slicer
broadcast_slicer_
;
std
::
unordered_map
<
int
,
Callback
>
lookup_callbacks_
;
};
template
<
typename
T
>
void
WorkerProxy
<
T
>::
AddEmbeddingTable
(
const
::
ps
::
Key
&
key
,
const
size_t
&
row_count
)
{
uint64_t
begin
=
0
;
uint64_t
end
=
0
;
int
server_num
=
::
ps
::
NumServers
();
for
(
int
i
=
0
;
i
<
server_num
;
i
++
)
{
int
local_row_cnt
=
Util
::
LocalShard
(
row_count
,
i
,
server_num
);
if
(
i
==
0
)
{
end
=
local_row_cnt
-
1
;
}
else
{
begin
=
end
+
1
;
end
+=
local_row_cnt
;
}
::
ps
::
Range
range
(
begin
,
end
);
if
(
embedding_table_ranges_
.
count
(
key
)
==
0
)
{
embedding_table_ranges_
[
key
]
=
std
::
make_shared
<
std
::
vector
<::
ps
::
Range
>>
();
}
embedding_table_ranges_
[
key
]
->
push_back
(
range
);
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
EmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
const
::
ps
::
SArray
<
int
>
&
lens
,
::
ps
::
SArray
<
T
>
*
outs
,
int
cmd
,
const
Callback
&
cb
,
int
priority
)
{
int
ts
=
AddLookupCB
(
keys
,
lookup_ids
,
outs
,
cmd
,
cb
);
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
keys
;
kvs
.
vals
=
lookup_ids
;
kvs
.
lens
=
lens
;
kvs
.
priority
=
priority
;
Send
(
lookup_customer_
.
get
(),
ts
,
true
,
true
,
cmd
,
kvs
,
broadcast_slicer_
);
lookup_customer_
->
WaitRequest
(
ts
);
}
template
<
typename
T
>
int
WorkerProxy
<
T
>::
InitEmbeddingTable
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
vals
,
const
::
ps
::
SArray
<
int
>
&
lens
,
const
Callback
&
cb
,
int
priority
)
{
int
ts
=
obj_
->
NewRequest
(
::
ps
::
kServerGroup
);
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
keys
;
kvs
.
vals
=
vals
;
kvs
.
lens
=
lens
;
kvs
.
priority
=
priority
;
Send
(
obj_
,
ts
,
true
,
false
,
kInitEmbeddingsCmd
,
kvs
,
init_embedding_slicer_
);
return
ts
;
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
PushData
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
vals
,
const
::
ps
::
SArray
<
int
>
&
lens
,
int
cmd
,
int
priority
)
{
int
ts
=
obj_
->
NewRequest
(
::
ps
::
kServerGroup
);
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
keys
;
kvs
.
vals
=
vals
;
kvs
.
lens
=
lens
;
kvs
.
priority
=
priority
;
Send
(
obj_
,
ts
,
true
,
false
,
cmd
,
kvs
,
push_slicer_
);
obj_
->
WaitRequest
(
ts
);
}
template
<
typename
T
>
template
<
typename
C
>
int
WorkerProxy
<
T
>::
AddLookupCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
C
*
lookup_result
,
int
cmd
,
const
Callback
&
cb
)
{
int
ts
=
lookup_customer_
->
NewRequest
(
::
ps
::
kServerGroup
);
const
auto
&
callback
=
[
this
,
ts
,
keys
,
lookup_ids
,
lookup_result
,
cb
]()
mutable
{
mutex_
.
lock
();
auto
&
kvs
=
lookup_results_
[
ts
];
mutex_
.
unlock
();
size_t
total_len
=
0
;
const
auto
&
s
=
kvs
[
0
];
for
(
size_t
i
=
0
;
i
<
s
.
lens
.
size
();
i
++
)
{
total_len
+=
s
.
lens
[
i
];
}
lookup_result
->
resize
(
total_len
,
0
);
T
*
result_addr
=
lookup_result
->
data
();
for
(
const
auto
&
s
:
kvs
)
{
size_t
offset
=
0
;
for
(
size_t
i
=
0
;
i
<
s
.
vals
.
size
();
i
++
)
{
result_addr
[
offset
++
]
+=
s
.
vals
[
i
];
}
}
mutex_
.
lock
();
lookup_results_
.
erase
(
ts
);
mutex_
.
unlock
();
if
(
cb
)
cb
();
};
lookup_callbacks_
[
ts
]
=
callback
;
return
ts
;
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
LookupIdSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
int
*
data
=
send
.
lens
.
data
();
size_t
size
=
send
.
lens
.
size
();
std
::
vector
<
int
>
lookup_ids
(
data
,
data
+
size
);
std
::
sort
(
lookup_ids
.
begin
(),
lookup_ids
.
end
());
const
Key
&
key
=
send
.
keys
[
0
];
const
std
::
vector
<::
ps
::
Range
>
&
ranges
=
*
(
embedding_table_ranges_
[
key
]);
sliced
->
resize
(
ranges
.
size
());
size_t
index
=
0
;
for
(
size_t
i
=
0
;
i
<
ranges
.
size
();
i
++
)
{
const
::
ps
::
Range
&
range
=
ranges
[
i
];
const
auto
&
begin
=
range
.
begin
();
const
auto
&
end
=
range
.
end
();
auto
&
kvs
=
sliced
->
at
(
i
).
second
;
auto
lookup_id
=
static_cast
<
uint64_t
>
(
lookup_ids
[
index
]);
while
(
lookup_id
>=
begin
&&
lookup_id
<=
end
)
{
kvs
.
vals
.
push_back
(
lookup_id
);
if
(
++
index
>=
lookup_ids
.
size
())
{
break
;
}
lookup_id
=
static_cast
<
uint64_t
>
(
lookup_ids
[
index
]);
}
kvs
.
keys
.
push_back
(
key
);
kvs
.
lens
.
push_back
(
kvs
.
vals
.
size
());
if
(
kvs
.
vals
.
size
()
==
0
)
{
sliced
->
at
(
i
).
first
=
false
;
}
else
{
sliced
->
at
(
i
).
first
=
true
;
}
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
EmbeddingTableInitSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
const
Key
&
key
=
send
.
keys
[
0
];
const
std
::
vector
<::
ps
::
Range
>
&
ranges
=
*
(
embedding_table_ranges_
[
key
]);
sliced
->
resize
(
ranges
.
size
());
for
(
size_t
i
=
0
;
i
<
ranges
.
size
();
i
++
)
{
sliced
->
at
(
i
).
first
=
true
;
sliced
->
at
(
i
).
second
=
send
;
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
PushSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
auto
server_num
=
::
ps
::
Postoffice
::
Get
()
->
num_servers
();
sliced
->
resize
(
server_num
);
for
(
int
i
=
0
;
i
<
server_num
;
i
++
)
{
sliced
->
at
(
i
).
first
=
true
;
sliced
->
at
(
i
).
second
=
send
;
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
BroadcastSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
auto
server_num
=
::
ps
::
Postoffice
::
Get
()
->
num_servers
();
sliced
->
resize
(
server_num
);
for
(
int
i
=
0
;
i
<
server_num
;
i
++
)
{
sliced
->
at
(
i
).
first
=
true
;
sliced
->
at
(
i
).
second
=
send
;
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
ProcessLookupResult
(
const
::
ps
::
Message
&
msg
)
{
int
ts
=
msg
.
meta
.
timestamp
;
if
(
msg
.
meta
.
pull
)
{
CHECK_GE
(
msg
.
data
.
size
(),
(
size_t
)
2
);
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
msg
.
data
[
0
];
kvs
.
vals
=
msg
.
data
[
1
];
if
(
msg
.
data
.
size
()
>
(
size_t
)
2
)
{
kvs
.
lens
=
msg
.
data
[
2
];
}
mutex_
.
lock
();
lookup_results_
[
ts
].
push_back
(
kvs
);
mutex_
.
unlock
();
}
if
(
lookup_customer_
->
NumResponse
(
ts
)
==
::
ps
::
Postoffice
::
Get
()
->
num_servers
()
-
1
)
{
const
auto
&
cb
=
lookup_callbacks_
[
ts
];
cb
();
lookup_callbacks_
.
erase
(
ts
);
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
Send
(
::
ps
::
Customer
*
customer
,
int
timestamp
,
bool
push
,
bool
pull
,
int
cmd
,
const
::
ps
::
KVPairs
<
T
>
&
kvs
,
const
Slicer
&
slicer
)
{
SlicedKVs
sliced
;
slicer
(
kvs
,
::
ps
::
Postoffice
::
Get
()
->
GetServerKeyRanges
(),
&
sliced
);
for
(
size_t
i
=
0
;
i
<
sliced
.
size
();
i
++
)
{
const
auto
&
s
=
sliced
[
i
];
if
(
!
s
.
first
)
continue
;
::
ps
::
Message
msg
;
msg
.
meta
.
app_id
=
customer
->
app_id
();
msg
.
meta
.
customer_id
=
customer
->
customer_id
();
msg
.
meta
.
request
=
true
;
msg
.
meta
.
push
=
push
;
msg
.
meta
.
pull
=
pull
;
msg
.
meta
.
head
=
cmd
;
msg
.
meta
.
timestamp
=
timestamp
;
msg
.
meta
.
recver
=
::
ps
::
Postoffice
::
Get
()
->
ServerRankToID
(
i
);
msg
.
meta
.
priority
=
kvs
.
priority
;
const
auto
&
kvs
=
s
.
second
;
if
(
kvs
.
keys
.
size
())
{
msg
.
AddData
(
kvs
.
keys
);
msg
.
AddData
(
kvs
.
vals
);
if
(
kvs
.
lens
.
size
())
{
msg
.
AddData
(
kvs
.
lens
);
}
}
::
ps
::
Postoffice
::
Get
()
->
van
()
->
Send
(
msg
);
}
}
}
// namespace ps
}
// namespace parallel
}
// namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录