提交 78f4923e，作者：mindspore-ci-bot，提交者：Gitee

!4848 Combine duplicate embedding lookup IDs

Merge pull request !4848 from chengang/unique_id_5
......@@ -314,7 +314,7 @@ template <typename T>
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
const Key &key = req_data.keys[0];
for (size_t i = 0; i < req_data.keys.size(); i++) {
for (size_t i = 1; i < req_data.keys.size(); i++) {
res->keys.push_back(req_data.keys[i]);
}
ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
......
......@@ -259,13 +259,29 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
auto &kvs = lookup_results_[ts];
mutex_.unlock();
::ps::SArray<T> result(kvs[0].vals.size(), 0);
for (auto k : kvs) {
for (size_t i = 0; i < k.vals.size(); i++) {
result[i] += k.vals[i];
std::unordered_map<Key, std::shared_ptr<std::pair<T *, int>>> id_addr_map;
for (const auto &s : kvs) {
int offset = 0;
int len = s.vals.size() / s.keys.size();
for (size_t i = 0; i < s.keys.size(); i++) {
const Key &key = s.keys[i];
T *addr = s.vals.data() + offset;
offset += len;
id_addr_map[key] = std::make_shared<std::pair<T *, int>>(std::make_pair(addr, len));
}
}
*lookup_result = result;
T *result_addr = lookup_result->data();
int offset = 0;
for (size_t i = 0; i < lookup_ids.size(); i++) {
auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
int size = pair->second * sizeof(T);
auto ret = memcpy_s(result_addr + offset, size, pair->first, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
offset += pair->second;
}
mutex_.lock();
lookup_results_.erase(ts);
......@@ -312,12 +328,23 @@ void WorkerProxy<T>::LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send,
sliced->resize(ranges.size());
for (size_t i = 0; i < ranges.size(); i++) {
const ::ps::Range &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
std::unordered_set<int> unique_ids;
auto &kvs = sliced->at(i).second;
kvs.keys.push_back(key);
kvs.vals.push_back(0.0f);
for (size_t j = 0; j < id_size; j++) {
kvs.keys.push_back(lookup_ids[j]);
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
if (lookup_id >= begin && lookup_id <= end) {
unique_ids.insert(lookup_id);
}
}
for (const auto &lookup_id : unique_ids) {
kvs.keys.push_back(lookup_id);
kvs.vals.push_back(0.0f);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册