Commit 62be492e authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Fixed bfloat16 integration of LIBXSMM sparse mat-mul.

Change: 149617825
Parent 89df2a1b
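
Editorial note on the diff below: the two overloads of wrapper_libxsmm_spmdm_compute_generic_thread no longer take alpha/beta pointers from the caller. The old bfloat16 path received a const bfloat16* built at the call site from TL alpha(1.0) (bfloat16 bit pattern 0x3F80) and reinterpret_cast it to const uint16*; the fixed wrapper instead passes local uint16 constants 1 and 0 to libxsmm_spmdm_compute_bfloat16_thread, which suggests the LIBXSMM entry point expects plain integer values rather than bfloat16-encoded scalars. A standalone sketch of that bit-pattern difference (editorial illustration, not code from this commit):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bfloat16 keeps the top 16 bits of an IEEE-754 float (plain truncation is
// exact for values like 1.0f that bfloat16 represents exactly).
static uint16_t ToBfloat16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  // What the old caller effectively passed for alpha through the
  // reinterpret_cast: the bfloat16 encoding of 1.0.
  std::printf("bfloat16(1.0f) bits = 0x%04X\n",
              static_cast<unsigned>(ToBfloat16Bits(1.0f)));  // 0x3F80
  // What the fixed wrapper passes: the plain integer constant 1.
  std::printf("uint16 constant     = 0x%04X\n", 1u);          // 0x0001
  return 0;
}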
@@ -1522,21 +1522,24 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
 void wrapper_libxsmm_spmdm_compute_generic_thread(
     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
-    char transA, char transB, const bfloat16* alpha,
-    libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
-    const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
+    char transA, char transB, libxsmm_CSR_sparseslice* A_sparse,
+    const bfloat16* B, char transC, float* C, int block_id, int tid,
+    int nthreads) {
+  const uint16 alpha = 1;
+  const uint16 beta = 0;
   return libxsmm_spmdm_compute_bfloat16_thread(
-      handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
-      reinterpret_cast<const uint16*>(B), transC,
-      reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
+      handle, transA, transB, &alpha, A_sparse,
+      reinterpret_cast<const uint16*>(B), transC, &beta, C, block_id, tid,
+      nthreads);
 }
 void wrapper_libxsmm_spmdm_compute_generic_thread(
     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
-    char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
-    const float* B, char transC, const float* beta, float* C, int block_id,
-    int tid, int nthreads) {
-  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
-                                           A_sparse, B, transC, beta, C,
+    char transB, libxsmm_CSR_sparseslice* A_sparse, const float* B, char transC,
+    float* C, int block_id, int tid, int nthreads) {
+  const float alpha = 1.f;
+  const float beta = 0.f;
+  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, &alpha,
+                                           A_sparse, B, transC, &beta, C,
                                            block_id, tid, nthreads);
 }
@@ -1648,13 +1651,11 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
     while (true) {
       int work_item = cur_mult_block_number.fetch_add(1);
       if (work_item >= total_num_mult_blocks) break;
-      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
-      const TL beta(0.0);   // Stored in a variable so we can get a pointer
       wrapper_libxsmm_spmdm_compute_generic_thread(
           empty_type_wrapper<TL>{}, &entry->handle,
-          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
-          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
-          work_item, i, actual_num_threads);
+          (transpose_left ? 'T' : 'N'), 'N', entry->output_csr, right_data,
+          (transpose_output ? 'T' : 'N'), output_data, work_item, i,
+          actual_num_threads);
     }
   });
   // Put handle + CSR storage back into cache
@@ -1802,15 +1803,17 @@ inline void SparseMatMul<TL, TR>::Compute(
                           SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
 #endif
-REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
 REGISTER_SPARSE_MATMUL(float, bfloat16);
 REGISTER_SPARSE_MATMUL(bfloat16, float);
 #ifdef TENSORFLOW_USE_LIBXSMM
+REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
 REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
 #else
+REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
 REGISTER_SPARSE_MATMUL(float, float);
 #endif
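
Net effect of the registration change: the (bfloat16, bfloat16) SparseMatMul kernel, previously registered unconditionally with the default implementation, is now routed to the LIBXSMM-backed kernel when TENSORFLOW_USE_LIBXSMM is defined, with the default as the fallback. For orientation, the LIBXSMM registration macro presumably expands along these lines; only its SparseMatMulOp<TA, TB, LibxsmmSparseMatMul> tail is visible in this diff, so the rest of the body is an assumption:

// Assumed macro body (only the last line appears in this diff): registers a
// CPU SparseMatMul kernel for the (Ta, Tb) input-type pair, backed by the
// LIBXSMM-based functor instead of the default SparseMatMul.
#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);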