From 5d7a8b05f82c175b43515e45d76e3fbb7bc3416b Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 2 Mar 2021 19:13:03 +0800 Subject: [PATCH] fix sycn training error (#31357) * fix sycn training error Change-Id: Ie2feebcf0b5b2984fd59cfcdde0c817840e203d2 --- paddle/fluid/distributed/table/common_dense_table.cc | 1 + paddle/fluid/distributed/table/common_table.h | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/distributed/table/common_dense_table.cc b/paddle/fluid/distributed/table/common_dense_table.cc index 87a9f5fb242..8d8b43b3740 100644 --- a/paddle/fluid/distributed/table/common_dense_table.cc +++ b/paddle/fluid/distributed/table/common_dense_table.cc @@ -120,6 +120,7 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { } int32_t CommonDenseTable::pour() { + pull_reservoir_.avg(); _push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size()); pull_reservoir_.reset(); return 0; diff --git a/paddle/fluid/distributed/table/common_table.h b/paddle/fluid/distributed/table/common_table.h index 034769e0212..dc3cfa75ff6 100644 --- a/paddle/fluid/distributed/table/common_table.h +++ b/paddle/fluid/distributed/table/common_table.h @@ -55,12 +55,13 @@ struct ReservoirValue { } void avg() { + if (counter == 0) return; auto scale = 1 / static_cast(counter); GetBlas().SCAL(values.size(), scale, values.data()); } void reset() { - values.resize(dim, 0); + std::fill(values.begin(), values.end(), 0); counter = 0; } }; @@ -134,15 +135,15 @@ class BarrierTable : public Table { return 0; } int32_t shrink(const std::string ¶m) override { return 0; } - virtual void clear(){}; - virtual int32_t flush() { return 0; }; + virtual void clear() {} + virtual int32_t flush() { return 0; } virtual int32_t load(const std::string &path, const std::string ¶m) { return 0; } virtual int32_t save(const std::string &path, const std::string ¶m) { return 0; } - virtual int32_t initialize_shard() { return 0; }; + virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize() override; // only for barrier -- GitLab