diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index dbf4f5e3325c324b0676393e346cae486859b993..92affa0e4ed762cc660baf0f84fd62f13ee3de29 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/matrix_bit_code.h" #include +#include namespace paddle { namespace operators { namespace math { @@ -133,8 +134,7 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, auto weight_value = weight->data(); auto input_value = input.data(); - std::unordered_map>> ops; - + std::map>> ops; for (size_t i = 0; i < num_samples; ++i) { auto code = code_table_->get_code(i); int code_length = code->get_length(); diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index ba1745b86dabf2505a0874bb318d3d94bdb5af18..cf43ad9d449430749086cc9b13e246d550c3661e 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include #include @@ -109,7 +110,7 @@ class Code { // set a CodeTable interface to create multiple code table class CodeTable { public: - virtual std::unique_ptr get_code(int64_t code) const = 0; + virtual Code* get_code(int64_t code) const = 0; virtual size_t size() const = 0; virtual int get_max_code_length() const = 0; virtual ~CodeTable() {} @@ -180,14 +181,23 @@ class SimpleCodeTable : public CodeTable { public: SimpleCodeTable(size_t num_classes, const int64_t* ids) : num_classes_(num_classes), ids_(ids) {} - std::unique_ptr get_code(int64_t code) const { - std::unique_ptr coder(new SimpleCode(code, num_classes_, ids_)); - return coder; + + Code* get_code(int64_t code) const { + auto it = codes_.find(code); + if (it != codes_.end()) { + return it->second.get(); + } + auto* result = new SimpleCode(code, num_classes_, ids_); + codes_.emplace(code, std::unique_ptr(result)); + return result; } + size_t size() const { return num_classes_; } int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } private: + mutable std::map> codes_; + size_t num_classes_; const int64_t* ids_; }; @@ -199,9 +209,14 @@ class CustomCodeTable : public CodeTable { const framework::Tensor& pcode, const int64_t* ids) : ptable_(ptable), pcode_(pcode), ids_(ids) {} - std::unique_ptr get_code(int64_t code) const { - std::unique_ptr coder(new CustomCode(ptable_, pcode_, ids_, code)); - return coder; + Code* get_code(int64_t code) const { + auto it = codes_.find(code); + if (it != codes_.end()) { + return it->second.get(); + } + auto* result = new CustomCode(ptable_, pcode_, ids_, code); + codes_.emplace(code, std::unique_ptr(result)); + return result; } size_t size() const { return static_cast(ptable_.dims()[1]); } @@ -210,6 +225,7 @@ class CustomCodeTable : public CodeTable { } private: + mutable std::unordered_map> codes_; const framework::Tensor& ptable_; const framework::Tensor& pcode_; const int64_t* ids_;