From 93c85c930a9e23229a73c5d8300d1587a4e5e2e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BF=9F=E9=A3=9E=E8=B7=83?= <34468585+zhaify@users.noreply.github.com> Date: Tue, 17 Sep 2019 09:48:01 +0800 Subject: [PATCH] Implement FusedEmbeddingSeqPoolGradKernel with cblas_saxpy (#19770) * Implement the operator with sprase matrix multiply * Update the URL of mklml library. test=develop * Disable MKLML implematation when using no-linux. test=develop * optimize bp with mkl sparse matrix test=develop * tmp add fused_emb_seq layer * Add the support of padding_idx attribute. test=develop * add padding_idx support test=develop * implement grad refer lego test=develop --- .../fused/fused_embedding_seq_pool_op.cc | 6 ++ .../fused/fused_embedding_seq_pool_op.h | 57 ++++++++++--------- .../unittests/test_fused_emb_seq_pool_op.py | 51 +++++++++++++---- 3 files changed, 77 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index 3ee962d37b..9110099013 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -78,6 +78,12 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker { "are supported, sum computes the weighted sum of the " "embedding results for each row.") .SetDefault("sum"); + AddAttr("padding_idx", + "(int64, default -1) " + "If the value is -1, it makes no effect to lookup. " + "Otherwise the given value indicates padding the output " + "with zeros whenever lookup encounters it in Ids.") + .SetDefault(kNoPadding); // NOTE(minqiyang): grad_inplace is an temporal attribute, // please do NOT set this attribute in python layer. AddAttr("grad_inplace", diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h index e624b2ffdb..3fffdf7e02 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -33,12 +33,15 @@ using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; +constexpr int64_t kNoPadding = -1; + #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + !defined(__OSX__) template void prepare_csr_data(const std::vector &offset, const int64_t *ids_data, const size_t idx_width, - T *csr_vals, int *csr_colmuns, int *csr_row_idx) { + T *csr_vals, int *csr_colmuns, int *csr_row_idx, + int64_t padding_idx = kNoPadding) { int val_idx = 0; int row_idx = 0; csr_row_idx[0] = 0; @@ -52,9 +55,11 @@ void prepare_csr_data(const std::vector &offset, // construct a map for creating csr for (size_t j = offset[i]; j < offset[i + 1]; ++j) { - unsigned int word_idx = - static_cast(ids_data[idx + j * idx_width]); - ++ids_map[word_idx]; + auto ids_value = ids_data[idx + j * idx_width]; + if (ids_value != padding_idx) { + unsigned int word_idx = static_cast(ids_value); + ++ids_map[word_idx]; + } } VLOG(4) << "====sequence %d====" << i; @@ -124,8 +129,8 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel { FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims()); const auto &ids_lod = ids_t->lod(); // in run time, the LoD of ids must be 1 - PADDLE_ENFORCE(ids_lod.size(), 1UL, - "The LoD level of Input(Ids) must be 1"); + PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL, + "The LoD level of Input(Ids) must be 1"); int64_t batch_size = ids_lod[0].size() - 1; // in run time, the shape from Ids -> output // should be [seq_length, 1] -> [batch_size, last_dim] @@ -133,7 +138,8 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel { if (combiner_type == "sum") { #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + !defined(__OSX__) + int64_t padding_idx = context.Attr("padding_idx"); auto output = output_t->mutable_data(context.GetPlace()); int64_t table_height = table_var->dims()[0]; int64_t table_width = table_var->dims()[1]; @@ -151,7 +157,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel { auto csr_colmuns = csr_colmuns_t.mutable_data(context.GetPlace()); auto csr_row_idx = csr_row_idx_t.mutable_data(context.GetPlace()); prepare_csr_data(offset, ids_t->data(), idx_width, csr_vals, - csr_colmuns, csr_row_idx); + csr_colmuns, csr_row_idx, padding_idx); const char transa = 'N'; const T alpha = 1.0; @@ -226,18 +232,19 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel { } } else { #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ - !defined(__OSX__) && !defined(PADDLE_WITH_CUDA) + !defined(__OSX__) auto *ids = context.Input("Ids"); auto *d_output = context.Input(framework::GradVarName("Out")); auto *d_table = context.Output(framework::GradVarName("W")); + int64_t padding_idx = context.Attr("padding_idx"); d_table->Resize(table_dim); auto *d_table_data = d_table->mutable_data(context.GetPlace()); memset(d_table_data, 0, d_table->numel() * sizeof(T)); const auto &ids_lod = ids->lod(); - PADDLE_ENFORCE(ids_lod.size(), 1UL, - "The LoD level of Input(Ids) must be 1"); + PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL, + "The LoD level of Input(Ids) must be 1"); const std::vector offset = ids_lod[0]; auto len = ids->numel(); int idx_width = len / offset.back(); @@ -251,23 +258,21 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel { auto csr_colmuns = csr_colmuns_t.mutable_data(context.GetPlace()); auto csr_row_idx = csr_row_idx_t.mutable_data(context.GetPlace()); prepare_csr_data(offset, ids->data(), idx_width, csr_vals, - csr_colmuns, csr_row_idx); + csr_colmuns, csr_row_idx, padding_idx); auto *d_output_data = d_output->data(); - const char transa = 'T'; - const T alpha = 1.0; - const T beta = 0.0; - const char matdescra[] = {'G', 'L', 'N', 'C'}; - - const int m = batch_size * idx_width; - const int n = table_dim[1]; - const int k = table_dim[1]; - auto blas = math::GetBlas(context); - blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals, - (const int *)csr_colmuns, (const int *)csr_row_idx, - (const int *)csr_row_idx + 1, d_output_data, &n, &beta, - d_table_data, &n); + int width = static_cast(table_dim[1]); + int num_seq = batch_size * idx_width; + LOG(INFO) << "num seq = " << num_seq << " width = " << width; + for (int i = 0; i < num_seq; ++i) { + for (int j = csr_row_idx[i]; j < csr_row_idx[i + 1]; ++j) { + unsigned int word_idx = csr_colmuns[j]; + T val = csr_vals[j]; + blas.AXPY(width, val, d_output_data + i * width, + d_table_data + word_idx * width); + } + } #else LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now"; #endif diff --git a/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py b/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py index 69c550a4ea..09523d65b4 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py @@ -22,25 +22,25 @@ import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid.op import Operator import paddle.compat as cpt +import paddle.version as ver class TestFusedEmbeddingSeqPoolOp(OpTest): def setUp(self): self.op_type = "fused_embedding_seq_pool" self.emb_size = 2 - table = np.random.random((17, self.emb_size)).astype("float32") - ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]], - [[16], [1]]]).astype("int64") - merged_ids = np.array([4, 2, 16]).astype("int64") - ids_expand = np.expand_dims(ids, axis=1) + self.table = np.random.random((17, self.emb_size)).astype("float32") + self.ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]], + [[16], [1]]]).astype("int64") + ids_expand = np.expand_dims(self.ids, axis=1) self.lod = [[3, 1]] self.attrs = {'is_sparse': True} - self.inputs = {'W': table, 'Ids': (ids_expand, self.lod)} + self.inputs = {'W': self.table, 'Ids': (ids_expand, self.lod)} self.outputs = { 'Out': np.reshape( np.array([ - table[[4, 3]] + table[[4, 3]] + table[[2, 1]], - table[[16, 1]] + self.table[[4, 3]] + self.table[[4, 3]] + + self.table[[2, 1]], self.table[[16, 1]] ]), [len(self.lod[0]), 2 * self.emb_size]) } @@ -48,12 +48,41 @@ class TestFusedEmbeddingSeqPoolOp(OpTest): self.check_output() def test_check_grad(self): - if fluid.core.is_compiled_with_mkldnn( - ) and not fluid.core.is_compiled_with_cuda( - ) and 'Linux' in platform.platform(): + if ver.mkl() == "ON" and 'Linux' in platform.platform(): self.attrs = {'is_sparse': False} self.check_grad(['W'], 'Out', no_grad_set=('Ids')) +class TestLookupTableOpWithPadding(TestFusedEmbeddingSeqPoolOp): + def test_check_output(self): + if ver.mkl() == "ON" and 'Linux' in platform.platform(): + ids = np.squeeze(self.ids, axis=2) + padding_idx = np.random.choice(ids.flatten(), 1)[0] + output = list() + index = 0 + for count in self.lod[0]: + arr = ids[index:count + index] + out = np.reshape(self.table[arr.flatten()], + [arr.shape[0], arr.shape[1], self.emb_size]) + idx = np.argwhere(arr == padding_idx) + for item in idx: + out[item[0], item[1], :] = np.zeros(self.emb_size) + output.append(np.sum(out, 0)) + index += count + self.outputs = { + 'Out': np.reshape( + np.array(output), [len(self.lod[0]), 2 * self.emb_size]) + } + self.attrs = {'padding_idx': int(padding_idx)} + self.check_output() + + def test_check_grad(self): + if ver.mkl() == "ON" and 'Linux' in platform.platform(): + ids = np.squeeze(self.ids, axis=2) + padding_idx = np.random.choice(ids.flatten(), 1)[0] + self.attrs = {'padding_idx': int(padding_idx), 'is_sparse': False} + self.check_grad(['W'], 'Out', no_grad_set=('Ids')) + + if __name__ == "__main__": unittest.main() -- GitLab