提交 b5fa9164 编写于 作者: J JiabinYang

fix bug after merge reyoung optimization, test=develop

上级 656040c7
...@@ -71,7 +71,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -71,7 +71,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
// server // server
auto height_sections = ctx.Attr<std::vector<int>>("height_sections"); auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
auto table_names = ctx.Attr<std::vector<std::string>>("table_names"); auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
VLOG(3) << "path type is " << path->type().name();
std::vector<int64_t> real_rows = PathToRows(*path); std::vector<int64_t> real_rows = PathToRows(*path);
framework::Scope& local_scope = ctx.scope().NewScope(); framework::Scope& local_scope = ctx.scope().NewScope();
auto* ids = local_scope.Var("Ids@Prefetch"); auto* ids = local_scope.Var("Ids@Prefetch");
......
...@@ -84,41 +84,6 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat, ...@@ -84,41 +84,6 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
code_table_.apply_visitor(func); code_table_.apply_visitor(func);
} }
template <typename T>
struct MatrixBitCodeFunctorSelectedRowsAddGrad
: public boost::static_visitor<void> {
const framework::Tensor &tmat_;
framework::SelectedRows *vec_;
MatrixBitCodeFunctorSelectedRowsAddGrad(const framework::Tensor &tmat,
framework::SelectedRows *vec)
: tmat_(tmat), vec_(vec) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
size_t batch_size = tmat_.dims()[0];
size_t width = tmat_.dims()[1];
auto *vec_data = vec_->mutable_value()->template data<T>();
auto *tmat_data = tmat_.data<T>();
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table.get_code(i);
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
int64_t row_index = vec_->GetIndexFromId(static_cast<int64_t>(index));
vec_data[row_index] += tmat_data[i * width + j];
}
}
}
};
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
framework::SelectedRows *vec) {
MatrixBitCodeFunctorSelectedRowsAddGrad<T> func(tmat, vec);
code_table_.apply_visitor(func);
}
template <typename T> template <typename T>
struct MatrixBitCodeFunctorSum : public boost::static_visitor<void> { struct MatrixBitCodeFunctorSum : public boost::static_visitor<void> {
const framework::Tensor &tmat_; const framework::Tensor &tmat_;
......
...@@ -124,11 +124,12 @@ class SimpleCode { ...@@ -124,11 +124,12 @@ class SimpleCode {
template <typename T> template <typename T>
class CustomCode { class CustomCode {
public: public:
CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode, CustomCode(const framework::Tensor& path_table,
const int64_t* ids, int index) { const framework::Tensor& path_code, const int64_t* ids,
seq_len_ = ptable.dims()[1]; int index) {
ptable_data_ = ptable.data<T>() + seq_len_ * index; seq_len_ = path_table.dims()[1];
pcode_data_ = pcode.data<T>() + seq_len_ * index; path_table_data_ = path_table.data<T>() + seq_len_ * index;
path_code_data_ = path_code.data<T>() + seq_len_ * index;
} }
/** /**
* Here the id of root should be 1 rather than 0, thus the encoding of class c * Here the id of root should be 1 rather than 0, thus the encoding of class c
...@@ -139,25 +140,25 @@ class CustomCode { ...@@ -139,25 +140,25 @@ class CustomCode {
* Binary classification path is the suffixes of encoding, thus leave out the * Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit. * left most bit in calc_bit.
*/ */
size_t calc_index(int bit) const { return ptable_data_[bit]; } size_t calc_index(int bit) const { return path_table_data_[bit]; }
bool calc_bit(int bit) const { return pcode_data_[bit]; } bool calc_bit(int bit) const { return path_code_data_[bit]; }
// NOTE: this function is not thread-safe. // NOTE: this function is not thread-safe.
int get_length() const { int get_length() const {
if (length_ < 0) { if (length_ < 0) {
auto len = seq_len_; auto len = seq_len_;
length_ = length_ = static_cast<int>(
static_cast<int>(std::find_if(ptable_data_, ptable_data_ + len, std::find_if(path_table_data_, path_table_data_ + len,
[](const T& val) { return val < 0; }) - [](const T& val) { return val < 0; }) -
ptable_data_); path_table_data_);
} }
return length_; return length_;
} }
private: private:
int64_t seq_len_; int64_t seq_len_;
const T* ptable_data_; const T* path_table_data_;
const T* pcode_data_; const T* path_code_data_;
mutable int length_{-1}; mutable int length_{-1};
}; };
...@@ -214,7 +215,7 @@ class MatrixBitCodeFunctor { ...@@ -214,7 +215,7 @@ class MatrixBitCodeFunctor {
const framework::Tensor& path_code, const int64_t* ids) const framework::Tensor& path_code, const int64_t* ids)
: num_classes_(static_cast<size_t>(path_table.dims()[1])), : num_classes_(static_cast<size_t>(path_table.dims()[1])),
ids_(ids), ids_(ids),
code_table_(CustomCodeTable<int64_t>(ptable, pcode, ids)) {} code_table_(CustomCodeTable<int64_t>(path_table, path_code, ids)) {}
/* For j < code_length /* For j < code_length
tmat(i, j) += vec(0, index(i, j)) tmat(i, j) += vec(0, index(i, j))
*/ */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册