提交 b5fa9164 编写于 作者: J JiabinYang

fix bug after merge reyoung optimization, test=develop

上级 656040c7
......@@ -71,7 +71,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
// server
auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
VLOG(3) << "path type is " << path->type().name();
std::vector<int64_t> real_rows = PathToRows(*path);
framework::Scope& local_scope = ctx.scope().NewScope();
auto* ids = local_scope.Var("Ids@Prefetch");
......
......@@ -84,41 +84,6 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
code_table_.apply_visitor(func);
}
template <typename T>
struct MatrixBitCodeFunctorSelectedRowsAddGrad
: public boost::static_visitor<void> {
const framework::Tensor &tmat_;
framework::SelectedRows *vec_;
MatrixBitCodeFunctorSelectedRowsAddGrad(const framework::Tensor &tmat,
framework::SelectedRows *vec)
: tmat_(tmat), vec_(vec) {}
template <typename CodeTable>
void operator()(const CodeTable &code_table) {
size_t batch_size = tmat_.dims()[0];
size_t width = tmat_.dims()[1];
auto *vec_data = vec_->mutable_value()->template data<T>();
auto *tmat_data = tmat_.data<T>();
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table.get_code(i);
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
int64_t row_index = vec_->GetIndexFromId(static_cast<int64_t>(index));
vec_data[row_index] += tmat_data[i * width + j];
}
}
}
};
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
framework::SelectedRows *vec) {
MatrixBitCodeFunctorSelectedRowsAddGrad<T> func(tmat, vec);
code_table_.apply_visitor(func);
}
template <typename T>
struct MatrixBitCodeFunctorSum : public boost::static_visitor<void> {
const framework::Tensor &tmat_;
......
......@@ -124,11 +124,12 @@ class SimpleCode {
template <typename T>
class CustomCode {
public:
CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
const int64_t* ids, int index) {
seq_len_ = ptable.dims()[1];
ptable_data_ = ptable.data<T>() + seq_len_ * index;
pcode_data_ = pcode.data<T>() + seq_len_ * index;
CustomCode(const framework::Tensor& path_table,
const framework::Tensor& path_code, const int64_t* ids,
int index) {
seq_len_ = path_table.dims()[1];
path_table_data_ = path_table.data<T>() + seq_len_ * index;
path_code_data_ = path_code.data<T>() + seq_len_ * index;
}
/**
* Here the id of root should be 1 rather than 0, thus the encoding of class c
......@@ -139,25 +140,25 @@ class CustomCode {
* Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit.
*/
size_t calc_index(int bit) const { return ptable_data_[bit]; }
bool calc_bit(int bit) const { return pcode_data_[bit]; }
size_t calc_index(int bit) const { return path_table_data_[bit]; }
bool calc_bit(int bit) const { return path_code_data_[bit]; }
// NOTE: this function is not thread-safe.
int get_length() const {
if (length_ < 0) {
auto len = seq_len_;
length_ =
static_cast<int>(std::find_if(ptable_data_, ptable_data_ + len,
[](const T& val) { return val < 0; }) -
ptable_data_);
length_ = static_cast<int>(
std::find_if(path_table_data_, path_table_data_ + len,
[](const T& val) { return val < 0; }) -
path_table_data_);
}
return length_;
}
private:
int64_t seq_len_;
const T* ptable_data_;
const T* pcode_data_;
const T* path_table_data_;
const T* path_code_data_;
mutable int length_{-1};
};
......@@ -214,7 +215,7 @@ class MatrixBitCodeFunctor {
const framework::Tensor& path_code, const int64_t* ids)
: num_classes_(static_cast<size_t>(path_table.dims()[1])),
ids_(ids),
code_table_(CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
code_table_(CustomCodeTable<int64_t>(path_table, path_code, ids)) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册