未验证 提交 5d7a8b05 编写于 作者: T tangwei12 提交者: GitHub

fix sycn training error (#31357)

* fix sycn training error

Change-Id: Ie2feebcf0b5b2984fd59cfcdde0c817840e203d2
上级 ec72f5b2
...@@ -120,6 +120,7 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { ...@@ -120,6 +120,7 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
} }
int32_t CommonDenseTable::pour() { int32_t CommonDenseTable::pour() {
pull_reservoir_.avg();
_push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size()); _push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
pull_reservoir_.reset(); pull_reservoir_.reset();
return 0; return 0;
......
...@@ -55,12 +55,13 @@ struct ReservoirValue { ...@@ -55,12 +55,13 @@ struct ReservoirValue {
} }
void avg() { void avg() {
if (counter == 0) return;
auto scale = 1 / static_cast<T>(counter); auto scale = 1 / static_cast<T>(counter);
GetBlas<T>().SCAL(values.size(), scale, values.data()); GetBlas<T>().SCAL(values.size(), scale, values.data());
} }
void reset() { void reset() {
values.resize(dim, 0); std::fill(values.begin(), values.end(), 0);
counter = 0; counter = 0;
} }
}; };
...@@ -134,15 +135,15 @@ class BarrierTable : public Table { ...@@ -134,15 +135,15 @@ class BarrierTable : public Table {
return 0; return 0;
} }
int32_t shrink(const std::string &param) override { return 0; } int32_t shrink(const std::string &param) override { return 0; }
virtual void clear(){}; virtual void clear() {}
virtual int32_t flush() { return 0; }; virtual int32_t flush() { return 0; }
virtual int32_t load(const std::string &path, const std::string &param) { virtual int32_t load(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t save(const std::string &path, const std::string &param) { virtual int32_t save(const std::string &path, const std::string &param) {
return 0; return 0;
} }
virtual int32_t initialize_shard() { return 0; }; virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize() override; virtual int32_t initialize() override;
// only for barrier // only for barrier
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册