diff --git a/paddle/fluid/distributed/table/common_dense_table.cc b/paddle/fluid/distributed/table/common_dense_table.cc index 87a9f5fb2426aa088701ebd35818e32aea5c8e2d..8d8b43b37403a4aa49b208f255000f8f320965ad 100644 --- a/paddle/fluid/distributed/table/common_dense_table.cc +++ b/paddle/fluid/distributed/table/common_dense_table.cc @@ -120,6 +120,7 @@ int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { } int32_t CommonDenseTable::pour() { + pull_reservoir_.avg(); _push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size()); pull_reservoir_.reset(); return 0; diff --git a/paddle/fluid/distributed/table/common_table.h b/paddle/fluid/distributed/table/common_table.h index 034769e021207cce0277fad6739566bf8da1fa67..dc3cfa75ff689863773e88ef2d077b80c1f0a5d5 100644 --- a/paddle/fluid/distributed/table/common_table.h +++ b/paddle/fluid/distributed/table/common_table.h @@ -55,12 +55,13 @@ struct ReservoirValue { } void avg() { + if (counter == 0) return; auto scale = 1 / static_cast(counter); GetBlas().SCAL(values.size(), scale, values.data()); } void reset() { - values.resize(dim, 0); + std::fill(values.begin(), values.end(), 0); counter = 0; } }; @@ -134,15 +135,15 @@ class BarrierTable : public Table { return 0; } int32_t shrink(const std::string ¶m) override { return 0; } - virtual void clear(){}; - virtual int32_t flush() { return 0; }; + virtual void clear() {} + virtual int32_t flush() { return 0; } virtual int32_t load(const std::string &path, const std::string ¶m) { return 0; } virtual int32_t save(const std::string &path, const std::string ¶m) { return 0; } - virtual int32_t initialize_shard() { return 0; }; + virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize() override; // only for barrier