“2eae3616006ef1e1ff440211c5bb8c4399089318”上不存在“develop/doc/design/kernel_selection.html”
提交 6eab5638 编写于 作者: 武毅 提交者: GitHub

Fix remote large update core (#3518)

* fix remote large update core

* wip

* working version

* fix style check

* fix style check

* update style check
上级 9bc4cf65
...@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) { ...@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) {
auto mat = dynamic_cast<SparsePrefetchRowCpuMatrix*>( auto mat = dynamic_cast<SparsePrefetchRowCpuMatrix*>(
para->getMat(PARAMETER_VALUE).get()); para->getMat(PARAMETER_VALUE).get());
para->clearGradient(); para->clearGradient();
mat->clearIndices(); if (mat) mat->clearIndices();
} }
} }
} }
......
...@@ -65,8 +65,11 @@ public: ...@@ -65,8 +65,11 @@ public:
size_t getSize() const { return config_.size(); } size_t getSize() const { return config_.size(); }
bool isFullSize() const { bool isFullSize() const {
if (bufs_[PARAMETER_VALUE]) {
return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); return this->getSize() == bufs_[PARAMETER_VALUE]->getSize();
} }
return false;
}
inline bool useGpu() const { return useGpu_; } inline bool useGpu() const { return useGpu_; }
......
...@@ -65,7 +65,6 @@ void ParameterClient2::initThreads() { ...@@ -65,7 +65,6 @@ void ParameterClient2::initThreads() {
LOG(INFO) << "parallel_thread_num dosent need to set"; LOG(INFO) << "parallel_thread_num dosent need to set";
} }
syncThreadPool_.reset(new SyncThreadPool(threadNum_)); syncThreadPool_.reset(new SyncThreadPool(threadNum_));
startThreads(); startThreads();
} }
...@@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData( ...@@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData(
request.set_cost(cost); request.set_cost(cost);
request.set_batch_status(batchStatus); request.set_batch_status(batchStatus);
CHECK_EQ(request.blocks_size(), 0); CHECK_EQ(request.blocks_size(), 0);
VLOG(10) << "request: trainer_id: " << request.trainer_id()
<< " update_mode" << request.update_mode()
<< " send_back_parameter: " << request.send_back_parameter()
<< " send_back_parameter_type: "
<< request.send_back_parameter_type()
<< " num_samples: " << request.num_samples()
<< " cost: " << request.cost()
<< " batch_status: " << request.batch_status();
} }
for (const auto& segments : parameterSegments) { for (const auto& segments : parameterSegments) {
const auto it = parameterMap_.find(segments.id); const auto it = parameterMap_.find(segments.id);
...@@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData( ...@@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData(
CHECK(sendMat != nullptr) << "sendMat is nullptr"; CHECK(sendMat != nullptr) << "sendMat is nullptr";
syncThreadPool_->exec([&](int tid, size_t numThreads) { syncThreadPool_->exec([&](int tid, size_t numThreads) {
std::lock_guard<std::mutex> guard(sparseAutoGrowthMutex_);
const auto& localIndices = prefetchMat->getLocalIndices(); const auto& localIndices = prefetchMat->getLocalIndices();
/// num of sparse rows /// num of sparse rows
size_t nLocalBlocks = localIndices.size(); size_t nLocalBlocks = localIndices.size();
uint64_t beginDim = 0; uint64_t beginDim = 0;
uint64_t endDim = 0; uint64_t endDim = 0;
// FIXME(typhoonzero): let it resize first
prefetchMat->getLocalRow(nLocalBlocks + 1);
sendMat->getLocalRow(nLocalBlocks + 1);
for (size_t row = 0; row < nLocalBlocks; ++row) { for (size_t row = 0; row < nLocalBlocks; ++row) {
int64_t blockId = localIndices[row]; // local row -> sparse row int64_t blockId = localIndices[row]; // local row -> sparse row
int serverId = std::abs((blockId + nameHash) % serviceNum_); int serverId = std::abs((blockId + nameHash) % serviceNum_);
...@@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData( ...@@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData(
block->set_begin_pos(row * blockSize); block->set_begin_pos(row * blockSize);
/// block len /// block len
block->set_block_size(endDim - beginDim); block->set_block_size(endDim - beginDim);
if (sendingPara) { if (sendingPara) {
sendJob->parallelInputIovs[serverId].push_back( sendJob->parallelInputIovs[serverId].push_back(
{sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize}); {sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize});
......
...@@ -583,6 +583,7 @@ protected: ...@@ -583,6 +583,7 @@ protected:
#ifndef PADDLE_DISABLE_TIMER #ifndef PADDLE_DISABLE_TIMER
uint64_t forwardbackwordTime_; uint64_t forwardbackwordTime_;
#endif #endif
std::mutex sparseAutoGrowthMutex_;
/// map id to parameter used for decoding protobuf data /// map id to parameter used for decoding protobuf data
std::unordered_map<size_t, ParameterPtr> parameterMap_; std::unordered_map<size_t, ParameterPtr> parameterMap_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册