提交 81e14576 编写于 作者: J JiabinYang

refine code and comments, test=develop

上级 2f6b529a
......@@ -30,14 +30,14 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
std::vector<int64_t> cal_rows(const framework::LoDTensor* path) {
std::vector<int64_t> cal_rows(const framework::LoDTensor& path) {
std::set<int64_t> tmp;
std::vector<int64_t> rows;
rows.clear();
for (size_t i = 0; i < static_cast<size_t>(path->dims()[0]); i++) {
for (size_t j = 0; j < static_cast<size_t>(path->dims()[1]); j++) {
for (size_t i = 0; i < static_cast<size_t>(path.dims()[0]); i++) {
for (size_t j = 0; j < static_cast<size_t>(path.dims()[1]); j++) {
int64_t temp =
path->data<int64_t>()[i * static_cast<size_t>(path->dims()[1]) + j];
path.data<int64_t>()[i * static_cast<size_t>(path.dims()[1]) + j];
if (temp >= 0) {
tmp.insert(temp);
}
......@@ -188,7 +188,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
} else {
framework::Vector<int64_t> real_rows = cal_rows(path);
framework::Vector<int64_t> real_rows = cal_rows(*path);
auto* w_grad =
ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
w_grad->set_rows(real_rows);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册