提交 15550a27 编写于 作者: Y Yu Yang

Polish code

上级 9e0b33d7
......@@ -18,8 +18,8 @@ ENDIF()
INCLUDE(python_module)
FIND_PACKAGE(PythonInterp ${PY_VERSION})
FIND_PACKAGE(PythonLibs ${PY_VERSION})
FIND_PACKAGE(PythonInterp ${PY_VERSION} REQUIRED)
FIND_PACKAGE(PythonLibs ${PY_VERSION} REQUIRED)
if(WIN32)
execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
......@@ -79,6 +79,6 @@ IF(PYTHONINTERP_FOUND)
"please use pip to upgrade protobuf. pip install -U protobuf")
ENDIF()
ENDIF(PYTHONINTERP_FOUND)
message(STATUS ${PYTHON_INCLUDE_DIR})
INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR})
INCLUDE_DIRECTORIES(${PYTHON_NUMPY_INCLUDE_DIR})
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
#if defined(_WIN32)
#include <intrin.h>
......@@ -99,24 +100,7 @@ inline int clz(const T& value) {
inline size_t FindLastSet(size_t x) { return sizeof(size_t) * 8 - clz(x); }
#endif // !_WIN32
// set a code interface to create multiple code
class Code {
public:
virtual ~Code() {}
virtual size_t calc_index(int bit) const = 0;
virtual bool calc_bit(int bit) const = 0;
virtual int get_length() const = 0;
};
// set a CodeTable interface to create multiple code table
class CodeTable {
public:
virtual Code* get_code(int64_t code) const = 0;
virtual size_t size() const = 0;
virtual int get_max_code_length() const = 0;
virtual ~CodeTable() {}
};
class SimpleCode : public Code {
class SimpleCode {
public:
SimpleCode(size_t code, size_t num_classes, const int64_t* ids)
: c_(static_cast<size_t>(ids[code]) + num_classes) {}
......@@ -138,7 +122,7 @@ class SimpleCode : public Code {
};
template <typename T>
class CustomCode : public Code {
class CustomCode {
public:
CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
const int64_t* ids, int index) {
......@@ -155,11 +139,11 @@ class CustomCode : public Code {
* Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit.
*/
size_t calc_index(int bit) const override { return ptable_data_[bit]; }
bool calc_bit(int bit) const override { return pcode_data_[bit]; }
size_t calc_index(int bit) const { return ptable_data_[bit]; }
bool calc_bit(int bit) const { return pcode_data_[bit]; }
// NOTE: this function is not thread-safe.
int get_length() const override {
int get_length() const {
if (length_ < 0) {
auto len = seq_len_;
length_ =
......@@ -177,46 +161,32 @@ class CustomCode : public Code {
mutable int length_{-1};
};
class SimpleCodeTable : public CodeTable {
class SimpleCodeTable {
public:
SimpleCodeTable(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
Code* get_code(int64_t code) const {
auto it = codes_.find(code);
if (it != codes_.end()) {
return it->second.get();
}
auto* result = new SimpleCode(code, num_classes_, ids_);
codes_.emplace(code, std::unique_ptr<Code>(result));
return result;
SimpleCode get_code(int64_t code) const {
return SimpleCode(code, num_classes_, ids_);
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
mutable std::map<int64_t, std::unique_ptr<Code>> codes_;
size_t num_classes_;
const int64_t* ids_;
};
template <typename T>
class CustomCodeTable : public CodeTable {
class CustomCodeTable {
public:
CustomCodeTable(const framework::Tensor& ptable,
const framework::Tensor& pcode, const int64_t* ids)
: ptable_(ptable), pcode_(pcode), ids_(ids) {}
Code* get_code(int64_t code) const {
auto it = codes_.find(code);
if (it != codes_.end()) {
return it->second.get();
}
auto* result = new CustomCode<T>(ptable_, pcode_, ids_, code);
codes_.emplace(code, std::unique_ptr<Code>(result));
return result;
CustomCode<T> get_code(int64_t code) const {
return CustomCode<T>(ptable_, pcode_, ids_, code);
}
size_t size() const { return static_cast<size_t>(ptable_.dims()[1]); }
......@@ -225,25 +195,26 @@ class CustomCodeTable : public CodeTable {
}
private:
mutable std::unordered_map<int64_t, std::unique_ptr<Code>> codes_;
const framework::Tensor& ptable_;
const framework::Tensor& pcode_;
const int64_t* ids_;
};
using CodeTable = boost::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;
template <typename T>
class MatrixBitCodeFunctor {
public:
MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes),
ids_(ids),
code_table_(new SimpleCodeTable(num_classes, ids)) {}
code_table_(SimpleCodeTable(num_classes, ids)) {}
MatrixBitCodeFunctor(const framework::Tensor& ptable,
const framework::Tensor& pcode, const int64_t* ids)
: num_classes_(static_cast<size_t>(ptable.dims()[1])),
ids_(ids),
code_table_(new CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
code_table_(CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
......@@ -293,7 +264,7 @@ class MatrixBitCodeFunctor {
size_t num_classes_;
const int64_t* ids_;
std::unique_ptr<CodeTable> code_table_;
CodeTable code_table_;
};
} // namespace math
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册