提交 be113756 编写于 作者: Y Yu Yang

Refine code

上级 8d940115
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream> #include <iostream>
#include <map>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -133,8 +134,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, ...@@ -133,8 +134,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
auto weight_value = weight->data<T>(); auto weight_value = weight->data<T>();
auto input_value = input.data<T>(); auto input_value = input.data<T>();
std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops; std::map<int, std::vector<std::pair<T, const T*>>> ops;
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
int code_length = code->get_length(); int code_length = code->get_length();
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <map>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -109,7 +110,7 @@ class Code { ...@@ -109,7 +110,7 @@ class Code {
// set a CodeTable interface to create multiple code table // set a CodeTable interface to create multiple code table
class CodeTable { class CodeTable {
public: public:
virtual std::unique_ptr<Code> get_code(int64_t code) const = 0; virtual Code* get_code(int64_t code) const = 0;
virtual size_t size() const = 0; virtual size_t size() const = 0;
virtual int get_max_code_length() const = 0; virtual int get_max_code_length() const = 0;
virtual ~CodeTable() {} virtual ~CodeTable() {}
...@@ -180,14 +181,23 @@ class SimpleCodeTable : public CodeTable { ...@@ -180,14 +181,23 @@ class SimpleCodeTable : public CodeTable {
public: public:
SimpleCodeTable(size_t num_classes, const int64_t* ids) SimpleCodeTable(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {} : num_classes_(num_classes), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const {
std::unique_ptr<Code> coder(new SimpleCode(code, num_classes_, ids_)); Code* get_code(int64_t code) const {
return coder; auto it = codes_.find(code);
if (it != codes_.end()) {
return it->second.get();
}
auto* result = new SimpleCode(code, num_classes_, ids_);
codes_.emplace(code, std::unique_ptr<Code>(result));
return result;
} }
size_t size() const { return num_classes_; } size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private: private:
mutable std::map<int64_t, std::unique_ptr<Code>> codes_;
size_t num_classes_; size_t num_classes_;
const int64_t* ids_; const int64_t* ids_;
}; };
...@@ -199,9 +209,14 @@ class CustomCodeTable : public CodeTable { ...@@ -199,9 +209,14 @@ class CustomCodeTable : public CodeTable {
const framework::Tensor& pcode, const int64_t* ids) const framework::Tensor& pcode, const int64_t* ids)
: ptable_(ptable), pcode_(pcode), ids_(ids) {} : ptable_(ptable), pcode_(pcode), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const { Code* get_code(int64_t code) const {
std::unique_ptr<Code> coder(new CustomCode<T>(ptable_, pcode_, ids_, code)); auto it = codes_.find(code);
return coder; if (it != codes_.end()) {
return it->second.get();
}
auto* result = new CustomCode<T>(ptable_, pcode_, ids_, code);
codes_.emplace(code, std::unique_ptr<Code>(result));
return result;
} }
size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); } size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
...@@ -210,6 +225,7 @@ class CustomCodeTable : public CodeTable { ...@@ -210,6 +225,7 @@ class CustomCodeTable : public CodeTable {
} }
private: private:
mutable std::unordered_map<int64_t, std::unique_ptr<Code>> codes_;
const framework::Tensor& ptable_; const framework::Tensor& ptable_;
const framework::Tensor& pcode_; const framework::Tensor& pcode_;
const int64_t* ids_; const int64_t* ids_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册