提交 be113756 编写于 作者: Y Yu Yang

Refine code

上级 8d940115
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
#include <map>
namespace paddle {
namespace operators {
namespace math {
......@@ -133,8 +134,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
std::unordered_map<int, std::vector<std::pair<T, const T*>>> ops;
std::map<int, std::vector<std::pair<T, const T*>>> ops;
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i);
int code_length = code->get_length();
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>
......@@ -109,7 +110,7 @@ class Code {
// set a CodeTable interface to create multiple code table
class CodeTable {
public:
virtual std::unique_ptr<Code> get_code(int64_t code) const = 0;
virtual Code* get_code(int64_t code) const = 0;
virtual size_t size() const = 0;
virtual int get_max_code_length() const = 0;
virtual ~CodeTable() {}
......@@ -180,14 +181,23 @@ class SimpleCodeTable : public CodeTable {
public:
SimpleCodeTable(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const {
std::unique_ptr<Code> coder(new SimpleCode(code, num_classes_, ids_));
return coder;
Code* get_code(int64_t code) const {
auto it = codes_.find(code);
if (it != codes_.end()) {
return it->second.get();
}
auto* result = new SimpleCode(code, num_classes_, ids_);
codes_.emplace(code, std::unique_ptr<Code>(result));
return result;
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
mutable std::map<int64_t, std::unique_ptr<Code>> codes_;
size_t num_classes_;
const int64_t* ids_;
};
......@@ -199,9 +209,14 @@ class CustomCodeTable : public CodeTable {
const framework::Tensor& pcode, const int64_t* ids)
: ptable_(ptable), pcode_(pcode), ids_(ids) {}
std::unique_ptr<Code> get_code(int64_t code) const {
std::unique_ptr<Code> coder(new CustomCode<T>(ptable_, pcode_, ids_, code));
return coder;
Code* get_code(int64_t code) const {
auto it = codes_.find(code);
if (it != codes_.end()) {
return it->second.get();
}
auto* result = new CustomCode<T>(ptable_, pcode_, ids_, code);
codes_.emplace(code, std::unique_ptr<Code>(result));
return result;
}
size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
......@@ -210,6 +225,7 @@ class CustomCodeTable : public CodeTable {
}
private:
mutable std::unordered_map<int64_t, std::unique_ptr<Code>> codes_;
const framework::Tensor& ptable_;
const framework::Tensor& pcode_;
const int64_t* ids_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册