diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
index 7565823fc6b377a13e800ab46923c81736ec460e..e624b2ffdb54d06d5c0b6a915b90129865fae9e0 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -225,7 +225,56 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
         vbroadcast(src, dst, h, out_width);
       }
     } else {
+#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
+    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+      // Dense backward: compute d_table = A^T * d_Out with MKL's CSRMM,
+      // where A is a zero-based ('C') CSR matrix built from Ids below.
+      auto *ids = context.Input<LoDTensor>("Ids");
+      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
+      auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
+
+      d_table->Resize(table_dim);
+      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
+      memset(d_table_data, 0, d_table->numel() * sizeof(T));
+
+      const auto &ids_lod = ids->lod();
+      // PADDLE_ENFORCE(cond, ...) would only test size() != 0; the intent is
+      // an equality check, so PADDLE_ENFORCE_EQ must be used here.
+      PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
+                        "The LoD level of Input(Ids) must be 1");
+      const std::vector<uint64_t> offset = ids_lod[0];
+      auto len = ids->numel();
+      int idx_width = len / offset.back();
+
+      Tensor csr_vals_t, csr_colmuns_t, csr_row_idx_t;
+      csr_vals_t.Resize({len});
+      csr_colmuns_t.Resize({len});
+      int64_t batch_size = ids_lod[0].size() - 1;
+      csr_row_idx_t.Resize({(batch_size + 1) * idx_width});
+      auto csr_vals = csr_vals_t.mutable_data<T>(context.GetPlace());
+      auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
+      auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
+      prepare_csr_data<T>(offset, ids->data<int64_t>(), idx_width, csr_vals,
+                          csr_colmuns, csr_row_idx);
+
+      auto *d_output_data = d_output->data<T>();
+      const char transa = 'T';
+      const T alpha = 1.0;
+      const T beta = 0.0;
+      const char matdescra[] = {'G', 'L', 'N', 'C'};
+
+      const int m = batch_size * idx_width;
+      const int n = table_dim[1];
+      const int k = table_dim[1];
+
+      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals,
+                 (const int *)csr_colmuns, (const int *)csr_row_idx,
+                 (const int *)csr_row_idx + 1, d_output_data, &n, &beta,
+                 d_table_data, &n);
+#else
       LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
+#endif
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py b/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
index 584e309befcee18ad913d935c803fdd387a92745..69c550a4ea13b1e4ff3088e3002e731f467d9e7e 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import unittest
+import platform
 import numpy as np
 from op_test import OpTest
 import paddle.fluid.core as core
@@ -46,6 +47,16 @@ class TestFusedEmbeddingSeqPoolOp(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        # NOTE(review): the C++ dense-grad path is compiled under
+        # PADDLE_WITH_MKLML; is_compiled_with_mkldnn() is the nearest flag
+        # visible from Python -- confirm the two stay in sync.
+        if core.is_compiled_with_mkldnn(
+        ) and not core.is_compiled_with_cuda(
+        ) and 'Linux' in platform.platform():
+            self.attrs = {'is_sparse': False}
+            self.check_grad(['W'], 'Out', no_grad_set=('Ids', ))
+
 
 if __name__ == "__main__":
     unittest.main()