Commit b9203958 authored by Yihua Xu, committed by Tao Luo

Use sparse matrix to implement fused emb_seq_pool operator (#19064)

* Implement the operator with sparse matrix multiplication

* Update the URL of mklml library.

test=develop

* Disable the MKLML implementation on non-Linux platforms.

test=develop

* Ignore the deprecated status for Windows

test=develop
Parent 6e326ca2
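For context, the core idea behind the change: with the "sum" combiner, pooling the embedding rows of a sequence is equivalent to multiplying a sparse row of per-id occurrence counts (stored in CSR format) by the dense embedding table, which MKL's mkl_?csrmm can do in a single call. Below is a minimal standalone C++ sketch of that equivalence; it is illustrative only, and the names table, ids and offsets are not taken from the commit.

// Sketch: sum-pooled embedding lookup expressed as "count row" x "embedding table".
#include <cstdio>
#include <map>
#include <vector>

int main() {
  // Embedding table: 4 ids x 3 dims.
  std::vector<std::vector<float>> table = {
      {1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 1, 1}};
  // One batch with two sequences of ids: [0, 2, 0] and [1, 3].
  std::vector<int> ids = {0, 2, 0, 1, 3};
  std::vector<int> offsets = {0, 3, 5};  // LoD-style sequence boundaries.

  for (size_t s = 0; s + 1 < offsets.size(); ++s) {
    // Count id occurrences in this sequence; the counts are exactly the
    // non-zero values of one CSR row whose column indices are the ids.
    std::map<int, int> counts;
    for (int j = offsets[s]; j < offsets[s + 1]; ++j) ++counts[ids[j]];

    // Multiplying that sparse count row by the dense table gives the same
    // result as summing the embedding of every id in the sequence.
    std::vector<float> pooled(table[0].size(), 0.f);
    for (const auto &kv : counts) {
      for (size_t d = 0; d < pooled.size(); ++d) {
        pooled[d] += kv.second * table[kv.first][d];
      }
    }

    std::printf("sequence %zu:", s);
    for (float v : pooled) std::printf(" %g", v);
    std::printf("\n");
  }
  return 0;
}

In the diff below, prepare_csr_data builds one such count row per (sequence, id column) pair, and a single blas.CSRMM call multiplies all of them by the embedding table at once.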
@@ -39,6 +39,8 @@ if(WIN32)
set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
else(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations")
endif(WIN32)
find_package(CUDA QUIET)
......
@@ -43,7 +43,7 @@ IF(WIN32)
ELSE()
#TODO(intel-huying):
# Now enable Erf function in mklml library temporarily, it will be updated as offical version later.
-SET(MKLML_VER "Glibc225_vsErf_mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
+SET(MKLML_VER "csrmm2_mklml_lnx_2019.0.2" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.bj.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
......
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <map>
#include <string>
#include <vector>
@@ -22,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
@@ -31,6 +33,44 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
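// Builds the CSR representation used by the sparse multiply below: for every
// (sequence, id column) pair it emits one row whose column indices are the
// distinct ids seen in that sequence and whose values are their occurrence
// counts; csr_row_idx is the usual CSR row-pointer array.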
template <typename T>
void prepare_csr_data(const std::vector<uint64_t> &offset,
const int64_t *ids_data, const size_t idx_width,
T *csr_vals, int *csr_colmuns, int *csr_row_idx) {
int val_idx = 0;
int row_idx = 0;
csr_row_idx[0] = 0;
std::map<int, int> ids_map;
// for each sequence in batch
for (size_t i = 0; i < offset.size() - 1; ++i) {
for (size_t idx = 0; idx < idx_width; ++idx) {
ids_map.clear();
// construct a map for creating csr
for (size_t j = offset[i]; j < offset[i + 1]; ++j) {
unsigned int word_idx =
static_cast<unsigned int>(ids_data[idx + j * idx_width]);
++ids_map[word_idx];
}
VLOG(4) << "====sequence %d====" << i;
for (std::map<int, int>::const_iterator it = ids_map.begin();
it != ids_map.end(); ++it) {
VLOG(4) << it->first << " => " << it->second;
csr_vals[val_idx] = it->second;
csr_colmuns[val_idx] = it->first;
++val_idx;
}
csr_row_idx[row_idx + 1] = csr_row_idx[row_idx] + ids_map.size();
++row_idx;
}
}
}
#else
template <typename T>
struct EmbeddingVSumFunctor {
void operator()(const framework::ExecutionContext &context,
@@ -60,6 +100,7 @@ struct EmbeddingVSumFunctor {
}
}
};
#endif
inline int FusedEmbeddingSeqPoolLastDim(const framework::DDim &table_dims,
const framework::DDim &ids_dims) {
@@ -91,8 +132,44 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
output_t->Resize({batch_size, last_dim});
if (combiner_type == "sum") {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto output = output_t->mutable_data<T>(context.GetPlace());
int64_t table_height = table_var->dims()[0];
int64_t table_width = table_var->dims()[1];
auto weights = table_var->data<T>();
const std::vector<uint64_t> offset = ids_lod[0];
auto len = ids_t->numel();
int idx_width = len / offset.back();
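// idx_width is the number of id columns per time step; offset.back() is the
// total number of time steps across all sequences in the batch.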
Tensor csr_vals_t, csr_colmuns_t, csr_row_idx_t;
csr_vals_t.Resize({len});
csr_colmuns_t.Resize({len});
csr_row_idx_t.Resize({(batch_size + 1) * idx_width});
auto csr_vals = csr_vals_t.mutable_data<T>(context.GetPlace());
auto csr_colmuns = csr_colmuns_t.mutable_data<int>(context.GetPlace());
auto csr_row_idx = csr_row_idx_t.mutable_data<int>(context.GetPlace());
prepare_csr_data<T>(offset, ids_t->data<int64_t>(), idx_width, csr_vals,
csr_colmuns, csr_row_idx);
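// The arrays above describe a CSR matrix of shape
// (batch_size * idx_width) x table_height whose non-zeros are id counts.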
const char transa = 'N';
const T alpha = 1.0;
const T beta = 0.0;
const char matdescra[] = {'G', 'L', 'N', 'C'};
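// matdescra: 'G' marks a general sparse matrix and the trailing 'C' selects
// zero-based (C-style) indexing; the triangle/diagonal fields are ignored
// for general matrices.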
const int m = batch_size * idx_width;
const int n = table_width;
const int k = table_height;
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals,
(const int *)csr_colmuns, (const int *)csr_row_idx,
(const int *)csr_row_idx + 1, weights, &n, &beta, output, &n);
#else
EmbeddingVSumFunctor<T> functor;
functor(context, table_var, ids_t, output_t);
#endif
}
}
};
......
@@ -113,6 +113,12 @@ class Blas {
template <typename T>
void GEMM_FREE(T* data) const;
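// Sparse (CSR) x dense matrix multiply, following MKL's mkl_?csrmm
// calling convention.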
template <typename T>
void CSRMM(const char* transa, const int* m, const int* n, const int* k,
const T* alpha, const char* matdescra, const T* val,
const int* indx, const int* pntrb, const int* pntre, const T* b,
const int* ldb, const T* beta, T* c, const int* ldc) const;
#if !defined(PADDLE_WITH_CUDA)
template <typename T>
void MatMulWithHead(const framework::Tensor& mat_a,
@@ -239,6 +245,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template GEMM_FREE<T>(args...);
}
template <typename... ARGS>
void CSRMM(ARGS... args) const {
Base()->template CSRMM<T>(args...);
}
#if !defined(PADDLE_WITH_CUDA)
template <typename... ARGS>
void MatMulWithHead(ARGS... args) const {
......
@@ -128,6 +128,11 @@ struct CBlas<float> {
static void VMERF(ARGS... args) {
platform::dynload::vmsErf(args...);
}
template <typename... ARGS>
static void CSRMM(ARGS... args) {
platform::dynload::mkl_scsrmm(args...);
}
};
template <>
@@ -233,6 +238,11 @@ struct CBlas<double> {
static void VMERF(ARGS... args) {
platform::dynload::vmdErf(args...);
}
template <typename... ARGS>
static void CSRMM(ARGS... args) {
platform::dynload::mkl_dcsrmm(args...);
}
};
#else
@@ -748,6 +758,19 @@ void Blas<platform::CPUDeviceContext>::VMERF(int n, const T *a, T *y,
#endif
}
#ifdef PADDLE_WITH_MKLML
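// Forwards to mkl_scsrmm / mkl_dcsrmm through the CBlas<T> wrappers above.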
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::CSRMM(
const char *transa, const int *m, const int *n, const int *k,
const T *alpha, const char *matdescra, const T *val, const int *indx,
const int *pntrb, const int *pntre, const T *b, const int *ldb,
const T *beta, T *c, const int *ldc) const {
CBlas<T>::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b,
ldb, beta, c, ldc);
}
#endif
}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -88,6 +88,8 @@ extern void* mklml_dso_handle;
__macro(vdInv); \
__macro(vmsErf); \
__macro(vmdErf); \
__macro(mkl_scsrmm); \
__macro(mkl_dcsrmm); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......