提交 2e3ee579 编写于 作者: 翟飞跃 提交者: Tao Luo

Use sparse matrix to implement FusedEmbeddingSeqPoolGradKernel (#19153)

* Implement the operator with sprase matrix multiply

* Update the URL of mklml library.

test=develop

* Disable MKLML implematation when using no-linux.

test=develop

* optimize bp with mkl sparse matrix
test=develop
上级 a9d5fc51
...@@ -225,7 +225,52 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -225,7 +225,52 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
vbroadcast(src, dst, h, out_width); vbroadcast(src, dst, h, out_width);
} }
} else { } else {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
d_table->Resize(table_dim);
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
memset(d_table_data, 0, d_table->numel() * sizeof(T));
const auto &ids_lod = ids->lod();
PADDLE_ENFORCE(ids_lod.size(), 1UL,
"The LoD level of Input(Ids) must be 1");
const std::vector<uint64_t> offset = ids_lod[0];
auto len = ids->numel();
int idx_width = len / offset.back();
Tensor csr_vals_t, csr_colmuns_t, csr_row_idx_t;
csr_vals_t.Resize({len});
csr_colmuns_t.Resize({len});
int64_t batch_size = ids_lod[0].size() - 1;
csr_row_idx_t.Resize({(batch_size + 1) * idx_width});
auto csr_vals = csr_vals_t.mutable_data<T>(context.GetPlace());
auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
prepare_csr_data<T>(offset, ids->data<int64_t>(), idx_width, csr_vals,
csr_colmuns, csr_row_idx);
auto *d_output_data = d_output->data<T>();
const char transa = 'T';
const T alpha = 1.0;
const T beta = 0.0;
const char matdescra[] = {'G', 'L', 'N', 'C'};
const int m = batch_size * idx_width;
const int n = table_dim[1];
const int k = table_dim[1];
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals,
(const int *)csr_colmuns, (const int *)csr_row_idx,
(const int *)csr_row_idx + 1, d_output_data, &n, &beta,
d_table_data, &n);
#else
LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now"; LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
#endif
} }
} }
}; };
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import platform
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -46,6 +47,13 @@ class TestFusedEmbeddingSeqPoolOp(OpTest): ...@@ -46,6 +47,13 @@ class TestFusedEmbeddingSeqPoolOp(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self):
if fluid.core.is_compiled_with_mkldnn(
) and not fluid.core.is_compiled_with_cuda(
) and 'Linux' in platform.platform():
self.attrs = {'is_sparse': False}
self.check_grad(['W'], 'Out', no_grad_set=('Ids'))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册