未验证 提交 5d7a8b05 编写于 作者: T tangwei12 提交者: GitHub

fix sycn training error (#31357)

* fix sycn training error

Change-Id: Ie2feebcf0b5b2984fd59cfcdde0c817840e203d2
上级 ec72f5b2
......@@ -120,6 +120,7 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
}
int32_t CommonDenseTable::pour() {
pull_reservoir_.avg();
_push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
pull_reservoir_.reset();
return 0;
......
......@@ -55,12 +55,13 @@ struct ReservoirValue {
}
void avg() {
if (counter == 0) return;
auto scale = 1 / static_cast<T>(counter);
GetBlas<T>().SCAL(values.size(), scale, values.data());
}
void reset() {
values.resize(dim, 0);
std::fill(values.begin(), values.end(), 0);
counter = 0;
}
};
......@@ -134,15 +135,15 @@ class BarrierTable : public Table {
return 0;
}
int32_t shrink(const std::string &param) override { return 0; }
virtual void clear(){};
virtual int32_t flush() { return 0; };
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t load(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t save(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t initialize_shard() { return 0; };
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize() override;
// only for barrier
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册