提交 93c85c93 编写于 作者: 翟飞跃 提交者: Tao Luo

Implement FusedEmbeddingSeqPoolGradKernel with cblas_saxpy (#19770)

* Implement the operator with sparse matrix multiplication

* Update the URL of mklml library.

test=develop

* Disable the MKLML implementation on non-Linux platforms.

test=develop

* optimize bp with mkl sparse matrix
test=develop

* tmp add fused_emb_seq layer

* Add the support of padding_idx attribute.

test=develop

* add padding_idx support
test=develop

* implement grad refer lego
test=develop
上级 2729c174
...@@ -78,6 +78,12 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -78,6 +78,12 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"are supported, sum computes the weighted sum of the " "are supported, sum computes the weighted sum of the "
"embedding results for each row.") "embedding results for each row.")
.SetDefault("sum"); .SetDefault("sum");
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding);
// NOTE(minqiyang): grad_inplace is an temporal attribute, // NOTE(minqiyang): grad_inplace is an temporal attribute,
// please do NOT set this attribute in python layer. // please do NOT set this attribute in python layer.
AddAttr<bool>("grad_inplace", AddAttr<bool>("grad_inplace",
......
...@@ -33,12 +33,15 @@ using LoDTensor = framework::LoDTensor; ...@@ -33,12 +33,15 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
constexpr int64_t kNoPadding = -1;
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA) !defined(__OSX__)
template <typename T> template <typename T>
void prepare_csr_data(const std::vector<uint64_t> &offset, void prepare_csr_data(const std::vector<uint64_t> &offset,
const int64_t *ids_data, const size_t idx_width, const int64_t *ids_data, const size_t idx_width,
T *csr_vals, int *csr_colmuns, int *csr_row_idx) { T *csr_vals, int *csr_colmuns, int *csr_row_idx,
int64_t padding_idx = kNoPadding) {
int val_idx = 0; int val_idx = 0;
int row_idx = 0; int row_idx = 0;
csr_row_idx[0] = 0; csr_row_idx[0] = 0;
...@@ -52,10 +55,12 @@ void prepare_csr_data(const std::vector<uint64_t> &offset, ...@@ -52,10 +55,12 @@ void prepare_csr_data(const std::vector<uint64_t> &offset,
// construct a map for creating csr // construct a map for creating csr
for (size_t j = offset[i]; j < offset[i + 1]; ++j) { for (size_t j = offset[i]; j < offset[i + 1]; ++j) {
unsigned int word_idx = auto ids_value = ids_data[idx + j * idx_width];
static_cast<unsigned int>(ids_data[idx + j * idx_width]); if (ids_value != padding_idx) {
unsigned int word_idx = static_cast<unsigned int>(ids_value);
++ids_map[word_idx]; ++ids_map[word_idx];
} }
}
VLOG(4) << "====sequence %d====" << i; VLOG(4) << "====sequence %d====" << i;
for (std::map<int, int>::const_iterator it = ids_map.begin(); for (std::map<int, int>::const_iterator it = ids_map.begin();
...@@ -124,7 +129,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> { ...@@ -124,7 +129,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims()); FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
const auto &ids_lod = ids_t->lod(); const auto &ids_lod = ids_t->lod();
// in run time, the LoD of ids must be 1 // in run time, the LoD of ids must be 1
PADDLE_ENFORCE(ids_lod.size(), 1UL, PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
"The LoD level of Input(Ids) must be 1"); "The LoD level of Input(Ids) must be 1");
int64_t batch_size = ids_lod[0].size() - 1; int64_t batch_size = ids_lod[0].size() - 1;
// in run time, the shape from Ids -> output // in run time, the shape from Ids -> output
...@@ -133,7 +138,8 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> { ...@@ -133,7 +138,8 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
if (combiner_type == "sum") { if (combiner_type == "sum") {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA) !defined(__OSX__)
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto output = output_t->mutable_data<T>(context.GetPlace()); auto output = output_t->mutable_data<T>(context.GetPlace());
int64_t table_height = table_var->dims()[0]; int64_t table_height = table_var->dims()[0];
int64_t table_width = table_var->dims()[1]; int64_t table_width = table_var->dims()[1];
...@@ -151,7 +157,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> { ...@@ -151,7 +157,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace()); auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace()); auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
prepare_csr_data<T>(offset, ids_t->data<int64_t>(), idx_width, csr_vals, prepare_csr_data<T>(offset, ids_t->data<int64_t>(), idx_width, csr_vals,
csr_colmuns, csr_row_idx); csr_colmuns, csr_row_idx, padding_idx);
const char transa = 'N'; const char transa = 'N';
const T alpha = 1.0; const T alpha = 1.0;
...@@ -226,17 +232,18 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -226,17 +232,18 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
} }
} else { } else {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \ #if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA) !defined(__OSX__)
auto *ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W")); auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
d_table->Resize(table_dim); d_table->Resize(table_dim);
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace()); auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
memset(d_table_data, 0, d_table->numel() * sizeof(T)); memset(d_table_data, 0, d_table->numel() * sizeof(T));
const auto &ids_lod = ids->lod(); const auto &ids_lod = ids->lod();
PADDLE_ENFORCE(ids_lod.size(), 1UL, PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
"The LoD level of Input(Ids) must be 1"); "The LoD level of Input(Ids) must be 1");
const std::vector<uint64_t> offset = ids_lod[0]; const std::vector<uint64_t> offset = ids_lod[0];
auto len = ids->numel(); auto len = ids->numel();
...@@ -251,23 +258,21 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -251,23 +258,21 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace()); auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace()); auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
prepare_csr_data<T>(offset, ids->data<int64_t>(), idx_width, csr_vals, prepare_csr_data<T>(offset, ids->data<int64_t>(), idx_width, csr_vals,
csr_colmuns, csr_row_idx); csr_colmuns, csr_row_idx, padding_idx);
auto *d_output_data = d_output->data<T>(); auto *d_output_data = d_output->data<T>();
const char transa = 'T';
const T alpha = 1.0;
const T beta = 0.0;
const char matdescra[] = {'G', 'L', 'N', 'C'};
const int m = batch_size * idx_width;
const int n = table_dim[1];
const int k = table_dim[1];
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals, int width = static_cast<int>(table_dim[1]);
(const int *)csr_colmuns, (const int *)csr_row_idx, int num_seq = batch_size * idx_width;
(const int *)csr_row_idx + 1, d_output_data, &n, &beta, LOG(INFO) << "num seq = " << num_seq << " width = " << width;
d_table_data, &n); for (int i = 0; i < num_seq; ++i) {
for (int j = csr_row_idx[i]; j < csr_row_idx[i + 1]; ++j) {
unsigned int word_idx = csr_colmuns[j];
T val = csr_vals[j];
blas.AXPY(width, val, d_output_data + i * width,
d_table_data + word_idx * width);
}
}
#else #else
LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now"; LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
#endif #endif
......
...@@ -22,25 +22,25 @@ import paddle.fluid.core as core ...@@ -22,25 +22,25 @@ import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
import paddle.compat as cpt import paddle.compat as cpt
import paddle.version as ver
class TestFusedEmbeddingSeqPoolOp(OpTest): class TestFusedEmbeddingSeqPoolOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fused_embedding_seq_pool" self.op_type = "fused_embedding_seq_pool"
self.emb_size = 2 self.emb_size = 2
table = np.random.random((17, self.emb_size)).astype("float32") self.table = np.random.random((17, self.emb_size)).astype("float32")
ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]], self.ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]],
[[16], [1]]]).astype("int64") [[16], [1]]]).astype("int64")
merged_ids = np.array([4, 2, 16]).astype("int64") ids_expand = np.expand_dims(self.ids, axis=1)
ids_expand = np.expand_dims(ids, axis=1)
self.lod = [[3, 1]] self.lod = [[3, 1]]
self.attrs = {'is_sparse': True} self.attrs = {'is_sparse': True}
self.inputs = {'W': table, 'Ids': (ids_expand, self.lod)} self.inputs = {'W': self.table, 'Ids': (ids_expand, self.lod)}
self.outputs = { self.outputs = {
'Out': np.reshape( 'Out': np.reshape(
np.array([ np.array([
table[[4, 3]] + table[[4, 3]] + table[[2, 1]], self.table[[4, 3]] + self.table[[4, 3]] +
table[[16, 1]] self.table[[2, 1]], self.table[[16, 1]]
]), [len(self.lod[0]), 2 * self.emb_size]) ]), [len(self.lod[0]), 2 * self.emb_size])
} }
...@@ -48,12 +48,41 @@ class TestFusedEmbeddingSeqPoolOp(OpTest): ...@@ -48,12 +48,41 @@ class TestFusedEmbeddingSeqPoolOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
if fluid.core.is_compiled_with_mkldnn( if ver.mkl() == "ON" and 'Linux' in platform.platform():
) and not fluid.core.is_compiled_with_cuda(
) and 'Linux' in platform.platform():
self.attrs = {'is_sparse': False} self.attrs = {'is_sparse': False}
self.check_grad(['W'], 'Out', no_grad_set=('Ids')) self.check_grad(['W'], 'Out', no_grad_set=('Ids'))
class TestLookupTableOpWithPadding(TestFusedEmbeddingSeqPoolOp):
def test_check_output(self):
if ver.mkl() == "ON" and 'Linux' in platform.platform():
ids = np.squeeze(self.ids, axis=2)
padding_idx = np.random.choice(ids.flatten(), 1)[0]
output = list()
index = 0
for count in self.lod[0]:
arr = ids[index:count + index]
out = np.reshape(self.table[arr.flatten()],
[arr.shape[0], arr.shape[1], self.emb_size])
idx = np.argwhere(arr == padding_idx)
for item in idx:
out[item[0], item[1], :] = np.zeros(self.emb_size)
output.append(np.sum(out, 0))
index += count
self.outputs = {
'Out': np.reshape(
np.array(output), [len(self.lod[0]), 2 * self.emb_size])
}
self.attrs = {'padding_idx': int(padding_idx)}
self.check_output()
def test_check_grad(self):
if ver.mkl() == "ON" and 'Linux' in platform.platform():
ids = np.squeeze(self.ids, axis=2)
padding_idx = np.random.choice(ids.flatten(), 1)[0]
self.attrs = {'padding_idx': int(padding_idx), 'is_sparse': False}
self.check_grad(['W'], 'Out', no_grad_set=('Ids'))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册