提交 8d940115 编写于 作者: Y Yu Yang

Refine w2v

上级 c00e07cd
...@@ -23,12 +23,14 @@ void MatrixBitCodeFunctor<T>::Add(const framework::Tensor& vec, ...@@ -23,12 +23,14 @@ void MatrixBitCodeFunctor<T>::Add(const framework::Tensor& vec,
framework::Tensor* tmat) { framework::Tensor* tmat) {
size_t batch_size = tmat->dims()[0]; size_t batch_size = tmat->dims()[0];
size_t width = tmat->dims()[1]; size_t width = tmat->dims()[1];
auto* tmat_data = tmat->data<T>();
auto* vec_data = vec.data<T>();
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
int code_length = code->get_length(); int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code->calc_index(j);
tmat->data<T>()[i * width + j] += vec.data<T>()[index]; tmat_data[i * width + j] += vec_data[index];
} }
} }
} }
...@@ -38,12 +40,14 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat, ...@@ -38,12 +40,14 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::Tensor* vec) { framework::Tensor* vec) {
size_t batch_size = tmat.dims()[0]; size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1]; size_t width = tmat.dims()[1];
auto* vec_data = vec->data<T>();
auto* tmat_data = tmat.data<T>();
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
int code_length = code->get_length(); int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code->calc_index(j);
vec->data<T>()[index] += tmat.data<T>()[i * width + j]; vec_data[index] += tmat_data[i * width + j];
} }
} }
} }
...@@ -53,14 +57,15 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat, ...@@ -53,14 +57,15 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::SelectedRows* vec) { framework::SelectedRows* vec) {
size_t batch_size = tmat.dims()[0]; size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1]; size_t width = tmat.dims()[1];
auto* vec_data = vec->mutable_value()->data<T>();
auto* tmat_data = tmat.data<T>();
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
int code_length = code->get_length(); int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
size_t index = code->calc_index(j); size_t index = code->calc_index(j);
int64_t row_index = vec->GetIndexFromId(static_cast<int64_t>(index)); int64_t row_index = vec->GetIndexFromId(static_cast<int64_t>(index));
vec->mutable_value()->data<T>()[row_index] += vec_data[row_index] += tmat_data[i * width + j];
tmat.data<T>()[i * width + j];
} }
} }
} }
...@@ -70,6 +75,8 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat, ...@@ -70,6 +75,8 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
framework::Tensor* sum, T scale_sum) { framework::Tensor* sum, T scale_sum) {
size_t num_samples = tmat.dims()[0]; size_t num_samples = tmat.dims()[0];
size_t o_width = tmat.dims()[1]; size_t o_width = tmat.dims()[1];
auto* tmat_data = tmat.data<T>();
auto* sum_data = sum->data<T>();
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
T sm = static_cast<T>(0.0); T sm = static_cast<T>(0.0);
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
...@@ -78,10 +85,10 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat, ...@@ -78,10 +85,10 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
if (code->calc_bit(j)) { if (code->calc_bit(j)) {
// calc_bit starts from right most bit, while data in tmat[i] is in the // calc_bit starts from right most bit, while data in tmat[i] is in the
// reverse order. // reverse order.
sm += tmat.data<T>()[i * o_width + j]; sm += tmat_data[i * o_width + j];
} }
} }
sum->data<T>()[i] = scale_sum * sm; sum_data[i] = scale_sum * sm;
} }
} }
...@@ -217,12 +224,13 @@ template <typename T> ...@@ -217,12 +224,13 @@ template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) { void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
size_t num_samples = tmat->dims()[0]; size_t num_samples = tmat->dims()[0];
size_t o_width = tmat->dims()[1]; size_t o_width = tmat->dims()[1];
auto* tmat_data = tmat->data<T>();
for (size_t i = 0; i < num_samples; ++i) { for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table_->get_code(i); auto code = code_table_->get_code(i);
int code_length = code->get_length(); int code_length = code->get_length();
for (int j = 0; j < code_length; ++j) { for (int j = 0; j < code_length; ++j) {
if (code->calc_bit(j)) { if (code->calc_bit(j)) {
tmat->data<T>()[i * o_width + j] -= 1; tmat_data[i * o_width + j] -= 1;
} }
} }
} }
......
...@@ -140,13 +140,13 @@ template <typename T> ...@@ -140,13 +140,13 @@ template <typename T>
class CustomCode : public Code { class CustomCode : public Code {
public: public:
CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode, CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
const int64_t* ids, int index) const int64_t* ids, int index) {
: ids_(ids), index_(index) { seq_len_ = ptable.dims()[1];
ptable_ = ptable.Slice(index, index + 1); ptable_data_ = ptable.data<T>() + seq_len_ * index;
pcode_ = pcode.Slice(index, index + 1); pcode_data_ = pcode.data<T>() + seq_len_ * index;
} }
/** /**
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c * Here the id of root should be 1 rather than 0, thus the encoding of class c
* is `c + num_classes` and all siblings can get the same weight indice using * is `c + num_classes` and all siblings can get the same weight indice using
* prefixes. * prefixes.
* Weight index is the prefixes of encoding, thus leave out the right most * Weight index is the prefixes of encoding, thus leave out the right most
...@@ -154,26 +154,26 @@ class CustomCode : public Code { ...@@ -154,26 +154,26 @@ class CustomCode : public Code {
* Binary classification path is the suffixes of encoding, thus leave out the * Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit. * left most bit in calc_bit.
*/ */
size_t calc_index(int bit) const { return ptable_.data<T>()[bit]; } size_t calc_index(int bit) const override { return ptable_data_[bit]; }
bool calc_bit(int bit) const { return pcode_.data<T>()[bit]; } bool calc_bit(int bit) const override { return pcode_data_[bit]; }
int get_length() const {
int length = 0;
for (int i = 0; i < static_cast<int>(ptable_.dims()[1]); i++) { // NOTE: this function is not thread-safe.
if (ptable_.data<T>()[i] >= 0) { int get_length() const override {
length++; if (length_ < 0) {
} else { auto len = seq_len_;
return length; length_ =
} static_cast<int>(std::find_if(ptable_data_, ptable_data_ + len,
[](const T& val) { return val < 0; }) -
ptable_data_);
} }
return length; return length_;
} }
private: private:
framework::Tensor ptable_; int64_t seq_len_;
framework::Tensor pcode_; const T* ptable_data_;
const int64_t* ids_; const T* pcode_data_;
const int index_; mutable int length_{-1};
}; };
class SimpleCodeTable : public CodeTable { class SimpleCodeTable : public CodeTable {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册