magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit e0a2d2f9, authored on Aug 03, 2020 by mindspore-ci-bot; committed by Gitee on Aug 03, 2020.
!3858 Fix worker sgd error in master
Merge pull request !3858 from ZPaC/master-fix-ps-sgd
Parents: f4e4fd2c, 273f3b39
Showing 6 changed files with 136 additions and 19 deletions.
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h    +4   -1
mindspore/ccsrc/frontend/parallel/ps/common.h                   +2   -0
mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc  +7   -4
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h         +84  -11
mindspore/ccsrc/frontend/parallel/ps/worker.h                   +15  -3
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h             +24  -0
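Note: taken together, the hunks below add a per-key token handshake between workers and the parameter server. After each optimizer step on a non-embedding weight, UpdateWeights grants worker_num_ pull tokens for that key; every weight pull consumes one token; ReadyForPull holds while tokens remain, and ReadyForPush only once they are drained (and gradient accumulation has room). The new kCheckReadyForPushCmd/kCheckReadyForPullCmd commands expose these predicates to workers. A minimal standalone sketch of the token half of that protocol — TokenGate and main are illustrative names, not MindSpore code, and the real ReadyForPush additionally checks grad_accum_count_ < weights_.size():

#include <cstdint>
#include <iostream>
#include <unordered_map>

using Key = uint64_t;  // stand-in for MindSpore's parallel::ps key type

class TokenGate {
 public:
  explicit TokenGate(size_t worker_num) : worker_num_(worker_num) {}

  // Mirrors ParameterServer::UpdateWeights(): after an optimizer step on a
  // non-embedding weight, grant one pull token per worker.
  void OnWeightsUpdated(Key key) { tokens_[key] = static_cast<int64_t>(worker_num_); }

  // Mirrors ParameterServer::weight(): every pull consumes one token.
  void OnWeightPulled(Key key) { tokens_[key] -= 1; }

  // Mirrors the token half of ReadyForPush(): a new push round may start
  // only after all workers have pulled the freshly updated weights.
  bool ReadyForPush(Key key) { return tokens_[key] <= 0; }

  // Mirrors ReadyForPull(): pulls are allowed only while tokens remain.
  bool ReadyForPull(Key key) { return tokens_[key] > 0; }

 private:
  size_t worker_num_;
  std::unordered_map<Key, int64_t> tokens_;
};

int main() {
  TokenGate gate(/*worker_num=*/2);
  Key k = 42;
  gate.OnWeightsUpdated(k);                   // server finished an SGD step
  std::cout << gate.ReadyForPull(k) << "\n";  // 1: workers may pull
  gate.OnWeightPulled(k);
  gate.OnWeightPulled(k);                     // both workers pulled
  std::cout << gate.ReadyForPush(k) << "\n";  // 1: next push round may start
}

This appears to be the substance of the "worker sgd error" fix: a worker can no longer push the next round of gradients, or pull weights, before the server has finished the matching update.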
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h
@@ -43,7 +43,10 @@ class PushKernel : public CPUKernel {
     sizes.push_back(SizeToInt(input->size) / sizeof(T));
   }
   parallel::ps::Worker<T>::GetInstance().Push(keys, addrs, sizes);
-  memcpy_s(outputs[0]->addr, sizeof(size_t), &key_, sizeof(size_t));
+  auto ret = memcpy_s(outputs[0]->addr, sizeof(size_t), &key_, sizeof(size_t));
+  if (ret != EOK) {
+    MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
+  }
   return true;
 }
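Note: the pattern in this hunk — replacing a fire-and-forget memcpy_s with auto ret = memcpy_s(...) plus a check that raises MS_LOG(EXCEPTION) — recurs across the commit. A self-contained illustration; memcpy_s_ below is a portable stand-in written for this sketch, since the real memcpy_s comes from the secure-C library MindSpore links against (an assumption here) and is not standard C++:

#include <cstring>
#include <iostream>

constexpr int kEOK = 0;  // the secure-copy "EOK" success code is 0

// Stand-in with the same contract as memcpy_s: copy only if the destination
// buffer is large enough, otherwise return a nonzero error code.
int memcpy_s_(void *dest, size_t dest_max, const void *src, size_t count) {
  if (dest == nullptr || src == nullptr || count > dest_max) {
    return -1;
  }
  std::memcpy(dest, src, count);
  return kEOK;
}

int main() {
  size_t key = 7;
  size_t out = 0;
  // The patched kernel checks the return code instead of discarding it:
  auto ret = memcpy_s_(&out, sizeof(size_t), &key, sizeof(size_t));
  if (ret != kEOK) {
    std::cerr << "Lookup id memcpy failed.\n";  // MS_LOG(EXCEPTION) in the patch
    return 1;
  }
  std::cout << out << "\n";  // prints 7
}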
mindspore/ccsrc/frontend/parallel/ps/common.h
@@ -66,6 +66,8 @@ constexpr int kInitWeightToOptimIdCmd = 11;
 constexpr int kInitOptimInputsShapeCmd = 12;
 constexpr int kInitKeyToPushNodeIdCmd = 13;
 constexpr int kInitEmbeddingsCmd = 20;
+constexpr int kCheckReadyForPushCmd = 25;
+constexpr int kCheckReadyForPullCmd = 26;
 constexpr int kEmbeddingLookupCmd = 30;
 constexpr int kFinalizeCmd = 40;
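Note: these ids are wire-level commands; the server dispatches each incoming request through a map from command id to handler, as registered in the ServerHandler::Init hunk further below. A minimal sketch of that dispatch, with std::function standing in for the member-function pointers (RequestHandler) the real class uses:

#include <functional>
#include <iostream>
#include <unordered_map>

// Command ids copied from the common.h hunk above.
constexpr int kCheckReadyForPushCmd = 25;
constexpr int kCheckReadyForPullCmd = 26;

int main() {
  // handlers_ in ServerHandler maps command id -> member function; here a
  // std::function plays that role so the sketch is self-contained.
  std::unordered_map<int, std::function<void()>> handlers;
  handlers[kCheckReadyForPushCmd] = [] { std::cout << "HandleCheckReadyForPush\n"; };
  handlers[kCheckReadyForPullCmd] = [] { std::cout << "HandleCheckReadyForPull\n"; };

  int cmd = kCheckReadyForPullCmd;  // as carried in an incoming request's meta
  auto it = handlers.find(cmd);
  if (it != handlers.end()) {
    it->second();  // dispatch on the wire command
  }
}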
mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
@@ -158,16 +158,19 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   }
   AddressPtr linear = std::make_shared<kernel::Address>();
   linear->addr = new float[weight->size()];
-  memcpy_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
+  auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
   linear->size = weight->size() * sizeof(float);
   const std::shared_ptr<std::vector<size_t>> &grad_shape = (*inputs_shape)[3];
   size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies<size_t>());
   AddressPtr grad = std::make_shared<kernel::Address>();
   grad->addr = new float[total_grad_size * worker_num];
-  auto ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
-  if (ret != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  auto ret1 = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
+  if (ret1 != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret1 << ")";
   }
   grad->size = lens[0] * sizeof(float);
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
@@ -91,6 +91,8 @@ class ParameterServer {
                              ::ps::KVPairs<T> *res);
   void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
   void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
+  void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
+  void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
   void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
   void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
@@ -98,6 +100,9 @@ class ParameterServer {
   typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
   std::unordered_map<int, RequestHandler> handlers_;
+  std::unordered_map<Key, bool> init_weights_;
+  std::unordered_map<Key, bool> init_weight_to_optim_;
+  std::unordered_map<Key, bool> init_optim_info_;
 };
 bool Init(const FuncGraphPtr &func_graph);
@@ -115,9 +120,11 @@ class ParameterServer {
   void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
   int SumOfShapes(const std::vector<int> &shapes) const;
   bool ReadyForUpdateWeights();
-  bool ReadyForAccumGrads();
+  bool ReadyForPush(const Key &key);
+  bool ReadyForPull(const Key &key);
   void ResetGradAccumCount();
   const CNodePtr GetCNode(const std::string &name) const;
+  std::mutex &mutex();
   size_t pserver_num_;
   size_t worker_num_;
@@ -136,13 +143,14 @@ class ParameterServer {
   std::unordered_map<Key, std::string> weight_key_to_optims_;
   std::unordered_map<Key, std::string> weight_key_to_optim_op_;
   std::unordered_map<Key, WeightPtr> weights_;
+  std::unordered_map<Key, bool> is_embedding_;
   std::unordered_map<Key, WeightPtr> grads_;
   std::unordered_map<Key, size_t> grads_accum_counter_;
   std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
+  std::unordered_map<Key, uint64_t> tokens_;
   std::mutex mutex_;
   std::condition_variable apply_grads_cv_;
-  std::condition_variable accum_grads_cv_;
   std::unique_ptr<std::thread> thread_;
@@ -171,6 +179,8 @@ void ParameterServer<T>::ServerHandler::Init() {
   handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
   handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
   handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
+  handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
+  handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
   handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
   handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
 }
@@ -192,11 +202,17 @@ void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_me
 template <typename T>
 void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
   size_t key_num = req_data.keys.size();
   T *data_ptr = req_data.vals.data();
   size_t pos = 0;
   for (size_t i = 0; i < key_num; i++) {
     Key key = req_data.keys[i];
+    if (init_weights_[key]) {
+      continue;
+    } else {
+      init_weights_[key] = true;
+    }
     size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
     WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
@@ -213,10 +229,16 @@ template <typename T>
 void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
   size_t key_num = req_data.keys.size();
   for (size_t i = 0; i < key_num; i++) {
     Key key = req_data.keys[i];
     T val = req_data.vals[i];
+    if (init_weight_to_optim_[key]) {
+      continue;
+    } else {
+      init_weight_to_optim_[key] = true;
+    }
     ps_->InitWeightKeyToOptims(key, val);
   }
 }
@@ -224,12 +246,26 @@ void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KV
 template <typename T>
 void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
+  const Key &key = req_data.keys[0];
+  if (init_optim_info_[key]) {
+    return;
+  } else {
+    init_optim_info_[key] = true;
+  }
   ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
 }

 template <typename T>
 void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
+  const Key &key = req_data.keys[0];
+  if (init_weights_[key]) {
+    return;
+  } else {
+    init_weights_[key] = true;
+  }
   std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes = std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
   std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
@@ -239,7 +275,6 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
   shapes->push_back(indices_shape);
   shapes->push_back(output_shape);
-  const Key &key = req_data.keys[0];
   const Lengths &lens = req_data.lens;
   size_t index = 0;
   for (int i = 0; i < lens[0]; i++) {
@@ -254,6 +289,26 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
   ps_->InitEmbeddingTable(key, shapes);
 }

+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
+  const Key &key = req_data.keys[0];
+  bool ready = ps_->ReadyForPush(key);
+  res->keys.push_back(key);
+  res->vals.push_back(ready);
+}
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
+  const Key &key = req_data.keys[0];
+  bool ready = ps_->ReadyForPull(key);
+  res->keys.push_back(key);
+  res->vals.push_back(ready);
+}
+
 template <typename T>
 void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
@@ -365,6 +420,8 @@ void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
   MS_LOG(INFO) << "Initializing weight for key " << key;
   if (weights_.count(key) == 0) {
     weights_[key] = weight;
+    tokens_[key] = 0;
+    is_embedding_[key] = false;
   }
 }
@@ -399,6 +456,8 @@ void ParameterServer<T>::InitEmbeddingTable(
     embedding_data[i] = random(engine);
   }
   weights_[key] = embedding;
+  tokens_[key] = 0;
+  is_embedding_[key] = true;
   grads_accum_counter_[key] = 0;
 }
@@ -439,17 +498,17 @@ void ParameterServer<T>::UpdateWeights() {
       optim_info->ComputeMean(worker_num_);
       optimizer->Execute(inputs, workspaces, outputs);
       optim_info->Reset();
+      if (!is_embedding_[key]) {
+        tokens_[key] = worker_num_;
+      }
     }
     ResetGradAccumCount();
-    accum_grads_cv_.notify_all();
   }
 }

 template <typename T>
 void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
   std::unique_lock<std::mutex> lock(mutex_);
-  accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); });
   const Key &key = keys[0];
   std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
@@ -482,14 +541,13 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
 template <typename T>
 WeightPtr ParameterServer<T>::weight(const Key &key) {
   std::unique_lock<std::mutex> lock(mutex_);
   if (weights_.count(key) == 0) {
-    MS_LOG(ERROR) << "Invalid weight key " << key;
-    return nullptr;
+    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
   }
   WeightPtr weight_ptr = weights_[key];
   WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
   copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
+  tokens_[key] -= 1;
   return copy_weight_ptr;
 }
@@ -560,12 +618,22 @@ inline bool ParameterServer<T>::ReadyForUpdateWeights() {
 }

 template <typename T>
-inline bool ParameterServer<T>::ReadyForAccumGrads() {
+inline bool ParameterServer<T>::ReadyForPush(const Key &key) {
   std::unique_lock<std::mutex> lock(mutex_);
   if (weights_.empty()) {
     MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
                          "kInitWeightsCmd command. 2.The Server failed to initialize weights.";
   }
-  return grad_accum_count_ < weights_.size();
+  return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
 }
+
+template <typename T>
+inline bool ParameterServer<T>::ReadyForPull(const Key &key) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if (tokens_.count(key) == 0 || weights_[key] == 0) {
+    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
+  }
+  return tokens_[key] > 0;
+}

 template <typename T>
@@ -576,6 +644,11 @@ inline void ParameterServer<T>::ResetGradAccumCount() {
   }
 }

+template <typename T>
+inline std::mutex &ParameterServer<T>::mutex() {
+  return mutex_;
+}
+
 template <typename T>
 void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
   ::ps::Start(0);
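Note: besides the token gate, the handler hunks above add an idempotent-init guard: each Handle* entry point records the keys it has already processed in an unordered_map<Key, bool> (init_weights_, init_weight_to_optim_, init_optim_info_), so a duplicate init request is a no-op. A stripped-down sketch of the guard; HandleInitWeights here is a free function invented for illustration:

#include <cstdint>
#include <iostream>
#include <unordered_map>

using Key = uint64_t;

// operator[] default-inserts false, so the first request for a key flips the
// flag and runs; any retry (e.g. from another worker) is skipped.
std::unordered_map<Key, bool> init_weights_;

void HandleInitWeights(Key key) {
  if (init_weights_[key]) {
    return;  // already initialized for this key
  } else {
    init_weights_[key] = true;
  }
  std::cout << "initializing weights for key " << key << "\n";
}

int main() {
  HandleInitWeights(1);  // runs
  HandleInitWeights(1);  // skipped
}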
mindspore/ccsrc/frontend/parallel/ps/worker.h
@@ -99,18 +99,30 @@ void Worker<T>::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> add
   ::ps::SArray<T> total_buffer(total_size, 0);
   size_t offset = 0;
   for (size_t i = 0; i < sizes.size(); i++) {
-    memcpy_s(total_buffer.data() + offset / sizeof(T), sizes[i] * sizeof(T), reinterpret_cast<void *>(addrs[i]), sizes[i] * sizeof(T));
+    auto ret = memcpy_s(total_buffer.data() + offset / sizeof(T), sizes[i] * sizeof(T), reinterpret_cast<void *>(addrs[i]), sizes[i] * sizeof(T));
+    if (ret != 0) {
+      MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+    }
     offset += sizes[i] * sizeof(T);
   }
+  while (!kv_worker_->IsReadyForPush(keys[0])) {
+    continue;
+  }
   kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes));
 }

 template <typename T>
 void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) {
   ::ps::SArray<T> variables(size / sizeof(T), 0);
+  while (!kv_worker_->IsReadyForPull(key)) {
+    continue;
+  }
   kv_worker_->Wait(kv_worker_->ZPull({key}, &variables));
-  memcpy_s(dev_addr, size, variables.data(), size);
+  auto ret = memcpy_s(dev_addr, size, variables.data(), size);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
 }

 template <typename T>
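Note: on the worker side the gate shows up as a plain spin loop before every push and pull. A toy model of the patched Pull control flow; IsReadyForPull and ZPullWeights are invented stand-ins for the kv_worker_ calls:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

bool IsReadyForPull(uint64_t) { return true; }  // the real call queries the server's token count
std::vector<float> ZPullWeights(uint64_t) { return {1.f, 2.f, 3.f}; }

bool Pull(uint64_t key, void *dev_addr, size_t size) {
  while (!IsReadyForPull(key)) {
    continue;  // spin until the server has freshly updated weights
  }
  std::vector<float> variables = ZPullWeights(key);
  if (size > variables.size() * sizeof(float)) {
    return false;  // bounds check in place of memcpy_s's dest_max argument
  }
  std::memcpy(dev_addr, variables.data(), size);
  return true;
}

int main() {
  float buf[3] = {0.f, 0.f, 0.f};
  if (Pull(42, buf, sizeof(buf))) {
    std::cout << buf[2] << "\n";  // prints 3
  }
}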
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
@@ -56,6 +56,8 @@ class WorkerProxy : public ::ps::KVWorker<T> {
               int priority = 0);
   int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int priority = 0);
+  bool IsReadyForPush(const Key &key);
+  bool IsReadyForPull(const Key &key);
   void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {}, int cmd = 0, int priority = 0);
   void Finalize();
@@ -134,6 +136,28 @@ int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, cons
   return ts;
 }

+template <typename T>
+bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
+  ::ps::SArray<T> result(1, 0);
+  this->Wait(this->ZPull({key}, &result, nullptr, kCheckReadyForPushCmd));
+  if (result[0] > 0) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+template <typename T>
+bool WorkerProxy<T>::IsReadyForPull(const Key &key) {
+  ::ps::SArray<T> result(1, 0);
+  this->Wait(this->ZPull({key}, &result, nullptr, kCheckReadyForPullCmd));
+  if (result[0] > 0) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
 template <typename T>
 void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens, int cmd, int priority) {
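Note: the readiness query introduces no new RPC type; WorkerProxy reuses ZPull with a one-element result and a custom command id, and the server's handler answers with ready ? 1 : 0 in vals[0]. A toy model of that round trip; FakeZPull is invented for illustration, the real call being ::ps::KVWorker's ZPull:

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int kCheckReadyForPushCmd = 25;  // from common.h above

// Pretend transport: a pull with a command id returns one value per key.
std::vector<float> FakeZPull(uint64_t /*key*/, int cmd) {
  bool ready = true;  // the real server computes ReadyForPush(key) here
  if (cmd == kCheckReadyForPushCmd) {
    return {ready ? 1.0f : 0.0f};
  }
  return {};
}

bool IsReadyForPush(uint64_t key) {
  std::vector<float> result = FakeZPull(key, kCheckReadyForPushCmd);
  return result[0] > 0;  // same test as WorkerProxy<T>::IsReadyForPush
}

int main() {
  // Worker<T>::Push spins on this before pushing gradients:
  while (!IsReadyForPush(42)) {
    continue;
  }
  std::cout << "ready; pushing gradients\n";
}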