Commit ecbff689 authored by: T tianbingsz, committed by: GitHub

Merge pull request #1147 from tianbingsz/paddle_func_sparse

Matrix::MUL operators implemented and tested with Daoyuan's Paddle Function, SparseMatrixArg, and Function Test
......@@ -32,14 +32,20 @@ const SparseMatrixArg& BufferArg::sparse() const {
SparseMatrixArg::SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType)
: BufferArg(sparse, argType),
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32),
nnz_(sparse.getElementCnt()),
format_(static_cast<SparseDataFormat>(sparse.getFormat())),
type_(static_cast<SparseDataType>(sparse.getValueType())) {
bufferType_ = TENSOR_SPARSE;
}
SparseMatrixArg::SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType)
: BufferArg(sparse, argType),
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32),
nnz_(sparse.getElementCnt()),
format_(static_cast<SparseDataFormat>(sparse.getFormat())),
type_(static_cast<SparseDataType>(sparse.getValueType())) {
bufferType_ = TENSOR_SPARSE;
}
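As a hedged illustration (not part of this change) of how the two constructors above are used, the sketch below wraps a CpuSparseMatrix into a SparseMatrixArg; it only reuses calls that appear elsewhere in this diff (the CpuSparseMatrix(height, width, nnz, valueType, format) constructor, randomizeUniform(), isSparseArg(), nnz(), getElementCnt()):
CpuSparseMatrix mat(/*height=*/16, /*width=*/32, /*nnz=*/8, FLOAT_VALUE, SPARSE_CSR);
mat.randomizeUniform();
SparseMatrixArg arg(mat, ASSIGN_TO);
CHECK(arg.isSparseArg());                  // bufferType_ is now TENSOR_SPARSE
CHECK_EQ(arg.nnz(), mat.getElementCnt());  // nnz_ taken from the wrapped matrix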
......
......@@ -30,13 +30,6 @@ enum BufferType {
TENSOR_SPARSE = 4
};
enum SparseDataType {
SPARSE_NO_VALUE = 0, // no value pointer needed; all values are 1
SPARSE_FLOAT_VALUE = 1
};
enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
class BufferArg;
class SequenceArg;
class SparseMatrixArg;
......@@ -79,19 +72,21 @@ public:
BufferArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: buf_(nullptr),
valueType_(valueType),
shape_(shape),
argType_(argType) {}
: buf_(nullptr), valueType_(valueType), shape_(shape), argType_(argType) {
bufferType_ = TENSOR_NORMAL;
}
BufferArg(void* buf,
ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}
: buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {
bufferType_ = TENSOR_NORMAL;
}
BufferArg(void* buf, ValueType valueType)
: buf_(buf), valueType_(valueType) {}
BufferArg(void* buf, ValueType valueType) : buf_(buf), valueType_(valueType) {
bufferType_ = TENSOR_NORMAL;
}
BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED)
: buf_(
......@@ -167,8 +162,9 @@ public:
ValueType valueType() const { return valueType_; }
BufferType bufferType() const { return bufferType_; }
const TensorShape& shape() const { return shape_; }
bool isSparse() const { return (TENSOR_SPARSE == bufferType_); }
bool isSparseArg() const { return TENSOR_SPARSE == bufferType_; }
bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; }
virtual size_t numElements() const { return shape_.getElements(); }
const SequenceArg& sequence() const;
const SparseMatrixArg& sparse() const;
......@@ -179,6 +175,7 @@ protected:
TensorShape shape_;
BufferType bufferType_{TENSOR_UNKNOWN};
ArgType argType_{UNSPECIFIED};
// TODO(tianbing), add deviceType_
// leading dimensions. The size is dims_.size()
// Dims lds_;
};
......@@ -191,6 +188,7 @@ class SequenceIdArg : public BufferArg {
public:
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
bufferType_ = TENSOR_SEQUENCE_ID;
CHECK_EQ(shape_.ndims(), (size_t)1);
CHECK_GT(shape_[0], 1);
numSeqs_ = shape_[0] - 1;
......@@ -228,7 +226,9 @@ public:
SequenceArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {}
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {
bufferType_ = TENSOR_SEQUENCE_DATA;
}
SequenceArg(void* buf,
ValueType valueType,
......@@ -269,31 +269,75 @@ public:
const BufferArg& row,
const BufferArg& col,
size_t nnz,
SparseDataFormat format,
SparseDataType type,
SparseFormat format,
SparseValueType type,
ArgType argType = UNSPECIFIED)
: BufferArg(buf, valueType, shape, argType),
row_(row),
col_(col),
nnz_(nnz),
format_(format),
type_(type) {
format_(static_cast<SparseDataFormat>(format)),
type_(static_cast<SparseDataType>(type)) {
bufferType_ = TENSOR_SPARSE;
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
CHECK_EQ(shape_.ndims(), (size_t)2);
CHECK_EQ(row_.shape().ndims(), (size_t)1);
CHECK_EQ(col_.shape().ndims(), (size_t)1);
if (format == SPARSE_CSR_FORMAT) {
if (format_ == T_SPARSE_CSR) {
CHECK_EQ(nnz, col.shape()[0]);
} else if (format == SPARSE_CSC_FORMAT) {
} else if (format_ == T_SPARSE_CSC) {
CHECK_EQ(nnz, row.shape()[0]);
}
}
SparseMatrixArg(ValueType valueType,
const TensorShape& shape,
size_t nnz,
SparseFormat format,
SparseValueType type,
ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType),
row_(BufferArg(nullptr, VALUE_TYPE_INT32)),
col_(BufferArg(nullptr, VALUE_TYPE_INT32)),
nnz_(nnz),
format_(static_cast<SparseDataFormat>(format)),
type_(static_cast<SparseDataType>(type)) {
bufferType_ = TENSOR_SPARSE;
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
CHECK_EQ(shape_.ndims(), (size_t)2);
/// len of row_ : height + 1 (CSR) or nnz (CSC), buf_ == nullptr
row_ = (format_ == T_SPARSE_CSR
? BufferArg(VALUE_TYPE_INT32, TensorShape{shape_[0] + 1})
: BufferArg(VALUE_TYPE_INT32, TensorShape{nnz}));
/// len of col_ : width + 1 (CSC) or nnz (CSR), buf_ == nullptr
col_ = (format_ == T_SPARSE_CSR
? BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})
: BufferArg(VALUE_TYPE_INT32, TensorShape{shape_[1] + 1}));
}
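To read the shape-only constructor above concretely (a sketch, not part of the diff): a 3 x 4 CSR argument with nnz = 5 gives row_ the shape {height + 1} = {4} and col_ the shape {nnz} = {5}; for CSC the roles swap. Both index buffers stay shape-only (buf_ == nullptr), as the comments note.
SparseMatrixArg csrArg(VALUE_TYPE_FLOAT, TensorShape{3, 4}, /*nnz=*/5, SPARSE_CSR, FLOAT_VALUE);
SparseMatrixArg cscArg(VALUE_TYPE_FLOAT, TensorShape{3, 4}, /*nnz=*/5, SPARSE_CSC, FLOAT_VALUE);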
SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);
SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);
template <DeviceType DType>
typename Tensor<real, DType>::SparseMatrix SparseMatrix() const {
CHECK(buf_);
CHECK(valueType_ == DataType<real>::value);
// CHECK(deviceType_ == DType);
CHECK_EQ(2, shape_.ndims());
return typename Tensor<real, DType>::SparseMatrix(
reinterpret_cast<real*>(buf_),
reinterpret_cast<int*>(row_.data()),
reinterpret_cast<int*>(col_.data()),
shape_[0],
shape_[1],
nnz_,
static_cast<SparseValueType>(type_),
static_cast<SparseFormat>(format_),
false);
}
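A hedged usage note for the accessor above (the name sparseArg is a placeholder): on CPU it rebuilds a CpuSparseMatrix view over the same value/row/col buffers of the argument, without copying (trans is passed as false):
auto cpuView = sparseArg.SparseMatrix<DEVICE_TYPE_CPU>();
CHECK_EQ(cpuView.getHeight(), sparseArg.shape()[0]);
CHECK_EQ(cpuView.getElementCnt(), sparseArg.nnz());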
~SparseMatrixArg() {}
void* getRowBuf() const { return row_.data(); }
......@@ -302,6 +346,8 @@ public:
size_t nnz() const { return nnz_; }
size_t numElements() const override { return nnz_; }
SparseDataFormat dataFormat() const { return format_; }
SparseDataType dataType() const { return type_; }
......
......@@ -26,6 +26,7 @@ if(WITH_TESTING)
add_simple_unittest(FunctionTest)
add_simple_unittest(ContextProjectionOpTest)
add_simple_unittest(PadOpTest)
add_simple_unittest(MulOpTest)
endif()
endif()
......
......@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
#include "paddle/math/Vector.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/math/tests/TensorCheck.h"
#include "paddle/testing/TestUtil.h"
......@@ -69,7 +70,7 @@ public:
}
// the output only needs to contain the shape, not the data.
void addOutputs(const BufferArg& output) {
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size =
output.shape().getElements() * sizeOfValuType(output.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
......@@ -79,12 +80,40 @@ public:
std::make_shared<BufferArg>(cpuMemory_.back()->getBuf(),
output.valueType(),
output.shape(),
ASSIGN_TO));
argType));
gpuOutputs_.emplace_back(
std::make_shared<BufferArg>(gpuMemory_.back()->getBuf(),
output.valueType(),
output.shape(),
ASSIGN_TO));
argType));
}
/// add and init output sparse matrix
void addOutputs(const SparseMatrixArg& output, ArgType argType = ASSIGN_TO) {
cpuSparse_ = std::make_shared<CpuSparseMatrix>(
output.shape()[0],
output.shape()[1],
output.nnz(),
static_cast<SparseValueType>(output.dataType()),
static_cast<SparseFormat>(output.dataFormat()));
gpuSparse_ = std::make_shared<GpuSparseMatrix>(
output.shape()[0],
output.shape()[1],
output.nnz(),
static_cast<SparseValueType>(output.dataType()),
static_cast<SparseFormat>(output.dataFormat()));
/// init sparse matrix
hl_stream_t stream(HPPL_STREAM_1);
cpuSparse_->randomizeUniform();
gpuSparse_->copyFrom(*cpuSparse_, stream);
hl_stream_synchronize(stream);
cpuOutputs_.emplace_back(
std::make_shared<SparseMatrixArg>(*cpuSparse_, argType));
gpuOutputs_.emplace_back(
std::make_shared<SparseMatrixArg>(*gpuSparse_, argType));
}
void addInputs(const SequenceArg& input) {
......@@ -107,10 +136,36 @@ public:
// TODO: needs to be implemented.
}
void addInputs(const SparseMatrixArg& input) {
cpuSparse_ = std::make_shared<CpuSparseMatrix>(
input.shape()[0],
input.shape()[1],
input.nnz(),
static_cast<SparseValueType>(input.dataType()),
static_cast<SparseFormat>(input.dataFormat()));
gpuSparse_ = std::make_shared<GpuSparseMatrix>(
input.shape()[0],
input.shape()[1],
input.nnz(),
static_cast<SparseValueType>(input.dataType()),
static_cast<SparseFormat>(input.dataFormat()));
/// init sparse matrix
hl_stream_t stream(HPPL_STREAM_1);
cpuSparse_->randomizeUniform();
gpuSparse_->copyFrom(*cpuSparse_, stream);
hl_stream_synchronize(stream);
cpuInputs_.emplace_back(std::make_shared<SparseMatrixArg>(*cpuSparse_));
gpuInputs_.emplace_back(std::make_shared<SparseMatrixArg>(*gpuSparse_));
}
void run() {
// prepare cpu/gpu arguments
initInputs();
initOutputs();
// function calculate
auto callFunction = [](FunctionBase* function,
std::vector<BufferArgPtr>& inputs,
......@@ -129,7 +184,7 @@ public:
callFunction(cpuFunc_.get(), cpuInputs_, cpuOutputs_);
callFunction(gpuFunc_.get(), gpuInputs_, gpuOutputs_);
// check outputs and inouts
// check outputs
compareOutputs();
}
......@@ -140,6 +195,10 @@ public:
protected:
void initInputs() {
for (size_t i = 0; i < cpuInputs_.size(); i++) {
if (cpuInputs_[i]->isSparseArg()) {
continue; /// sparse matrix already initialized
}
initArg(*cpuInputs_[i]);
// TODO: Need a BufferCopy used to copy from one BufferArg to another.
......@@ -152,14 +211,32 @@ protected:
}
}
void initOutputs() {
for (size_t i = 0; i < cpuOutputs_.size(); i++) {
if (cpuOutputs_[i]->isSparseArg()) {
continue; /// sparse matrix already initialized
}
initArg(*cpuOutputs_[i]);
// TODO: Need a BufferCopy used to copy from one BufferArg to another.
CpuVector cpuVector(cpuOutputs_[i]->shape().getElements(),
(real*)cpuOutputs_[i]->data());
GpuVector gpuVector(gpuOutputs_[i]->shape().getElements(),
(real*)gpuOutputs_[i]->data());
gpuVector.copyFrom(cpuVector);
}
}
void compareOutputs() {
for (size_t i = 0; i < cpuOutputs_.size(); i++) {
// TODO, Need a BufferCheck used to compare the two buffers.
auto cpu = cpuOutputs_[i];
auto gpu = gpuOutputs_[i];
CpuVector cpuVector(cpu->shape().getElements(), (real*)cpu->data());
GpuVector gpuVector(cpu->shape().getElements(), (real*)gpu->data());
const auto cpu = cpuOutputs_[i];
const auto gpu = gpuOutputs_[i];
CHECK_EQ(cpu->numElements(), gpu->numElements());
CpuVector cpuVector(cpu->numElements(), (real*)cpu->data());
GpuVector gpuVector(gpu->numElements(), (real*)gpu->data());
autotest::TensorCheckErr(cpuVector, gpuVector);
}
}
......@@ -195,6 +272,8 @@ protected:
std::vector<BufferArgPtr> cpuOutputs_;
std::vector<BufferArgPtr> gpuInputs_;
std::vector<BufferArgPtr> gpuOutputs_;
std::shared_ptr<CpuSparseMatrix> cpuSparse_;
std::shared_ptr<GpuSparseMatrix> gpuSparse_;
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MulOp.h"
/// todo(tianbing), delete it
#include <iostream>
#include "paddle/math/MathFunctions.h"
#include "paddle/math/SIMDFunctions.h"
#include "paddle/utils/ThreadLocal.h"
#ifndef PADDLE_TYPE_DOUBLE
#define GEMM paddle::gemm<float>
#else
#define GEMM paddle::gemm<double>
#endif
namespace {
inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
for (unsigned int i = 0; i < len; ++i) {
a[i] += (1.0 == scaleB) ? b[i] : scaleB * b[i];
}
}
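/// colVecAddTo below computes a[i * aWidth] += c * b[i * bWidth] for i in
/// [0, len), i.e., it accumulates one strided column of b into one strided
/// column of a.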
inline void colVecAddTo(
real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
for (unsigned int i = 0; i < len; ++i) {
a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c;
}
}
} // namespace
namespace paddle {
/// sparse matrix (+)= dense matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
const CpuMatrix& a,
const CpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans) {
CHECK_EQ(out.getValueType(), FLOAT_VALUE);
if (scaleT == 0) {
out.zeroMem();
}
const real* A = a.getData();
const real* B = b.getData();
real* C = out.getValue();
int* rows = out.getRows();
int* cols = out.getCols();
size_t width = out.getWidth();
size_t height = out.getHeight();
/// SPARSE_CSC, {a any, b not trans}
if (out.getFormat() == SPARSE_CSC) {
/// b not trans and a any
CHECK(!bTrans);
size_t m = !aTrans ? a.getWidth() : a.getHeight();
for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i);
size_t end = out.getColStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t rowIdx = rows[j];
for (size_t k = 0; k < m; k++) {
sum += (!aTrans ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
B[k * width + i];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
return;
}
/// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans}
if (out.getFormat() == SPARSE_CSR) {
/// a and b cannot both be transposed
CHECK(!(aTrans && bTrans));
size_t m = a.getWidth();
for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i);
size_t end = out.getRowStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < m; k++) {
sum += (!aTrans ? A[i * m + k] : A[k * height + i]) *
(!bTrans ? B[k * width + colIdx] : B[colIdx * m + k]);
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
return;
}
}
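Both branches above apply the same per-entry update; as a plain summary (an editorial note, not code from the PR), each stored entry j of out at position (r, c) is computed as:
/// C[j] = scaleT * C[j] + scaleAB * sum_k op(A)[r][k] * op(B)[k][c],
/// where op() applies the optional transpose, r = rows[j] (CSC) or the
/// current row i (CSR), and c = the current column i (CSC) or cols[j] (CSR).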
/// dense matrix (+)= dense matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuMatrix& a,
const CpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans) {
GEMM(aTrans ? CblasTrans : CblasNoTrans,
bTrans ? CblasTrans : CblasNoTrans,
out.getHeight(),
out.getWidth(),
!aTrans ? a.getWidth() : a.getHeight(),
scaleAB,
a.getData(),
a.getStride(),
b.getData(),
b.getStride(),
scaleT,
out.getData(),
out.getStride());
}
/// dense matrix (+)= sparse matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuSparseMatrix& a,
const CpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans) {
if (scaleT == 0) {
out.zeroMem();
}
const real* B = b.getData();
real* C = out.getData();
if (out.getWidth() % 32 == 0) {
CHECK_EQ((size_t)B % 32, 0UL);
CHECK_EQ((size_t)C % 32, 0UL);
}
int* cols = a.getCols();
real* values = a.getValue();
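  /// For each stored entry (i, cols[j]) of the CSR matrix a: when a is not
  /// transposed, out.row(i) += value * b.row(cols[j]); when aTrans is set,
  /// the destination and source row indices swap.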
for (size_t i = 0; i < a.getHeight(); ++i) {
const int start = a.getRowStartIdx(i);
const int end = a.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
vecAddTo(!aTrans ? out.getRow(i) : out.getRow(cols[j]),
!aTrans ? const_cast<CpuMatrix&>(b).getRow(cols[j])
: const_cast<CpuMatrix&>(b).getRow(i),
(a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0,
out.getWidth());
}
}
}
/// dense matrix (+)= dense matrix * sparse matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuMatrix& a,
const CpuSparseMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans) {
if (scaleT == 0) {
out.zeroMem();
}
real* A = const_cast<real*>(a.getData());
real* B = const_cast<real*>(b.getValue());
real* C = out.getData();
int* rows = b.getRows();
int* cols = b.getCols();
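  /// For each stored entry (rows[i], j) of b in CSC format (bTrans == false):
  /// out.col(j) += B[i] * a.col(rows[i]); with bTrans the two column indices
  /// swap. The CSR branch below performs the symmetric update using cols[].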
/// SPARSE_CSC format
if (b.getFormat() == SPARSE_CSC) {
for (size_t j = 0; j < b.getWidth(); ++j) {
int start = b.getColStartIdx(j);
int end = b.getColStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(!bTrans ? C + j : C + rows[i],
!bTrans ? A + rows[i] : A + j,
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
out.getHeight(),
out.getWidth(),
a.getWidth());
}
}
return;
}
/// SPARSE_CSR format
if (b.getFormat() == SPARSE_CSR) {
for (size_t j = 0; j < b.getHeight(); ++j) {
int start = b.getRowStartIdx(j);
int end = b.getRowStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(!bTrans ? C + cols[i] : C + j,
!bTrans ? A + j : A + cols[i],
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
out.getHeight(),
out.getWidth(),
a.getWidth());
}
}
return;
}
}
/**
* mul operator
* out = scaleT * out + scaleAB * (A * B)
* here, scaleT in {0, 1}, scaleAB == 1,
* out = A * B, ASSIGN_TO
* out += A * B, ADD_TO
*
*
* \param outputs[0] output matrix (out), M * N,
* could be either Sparse or Dense Matrix
* M is num of rows, N is num of columns
* \param inputs[0] first input matrix (A), M * K (if non-trans)
* could be either Sparse or Dense Matrix
* M is num of rows, K is num of columns
* \param inputs[1] second input matrix (B), K * N (if non-trans)
* could be either Sparse or Dense Matrix
* K is num of rows, N is num of columns
*
 * Eight Mul operators are supported in total, covering both GPU and CPU devices.
 * For each device, four Mul operators are available:
 * 1. dense (out) = dense (A) * dense (B)
 * 2. dense (out) = sparse (A) * dense (B)
 *    the sparse matrix only supports the SPARSE_CSR format
 * 3. dense (out) = dense (A) * sparse (B)
 *    the sparse matrix supports both the SPARSE_CSC and SPARSE_CSR formats
 * 4. sparse (out) = dense (A) * dense (B)
 *    the sparse matrix supports both the SPARSE_CSC and SPARSE_CSR formats
*
*/
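A minimal calling sketch for this Function, under stated assumptions: it relies on the "MulOp-CPU" key registered by REGISTER_TYPED_FUNC at the end of this file, on FunctionBase::funcRegistrar_, and on BufferArgs::addArg overloads for dense matrices whose exact signatures are not verified against this revision; M, K, N are placeholder sizes.
FunctionBase* mul = FunctionBase::funcRegistrar_.createByType("MulOp-CPU");
mul->init(FuncConfig().set("aTrans", false).set("bTrans", false));
CpuMatrix a(M, K), b(K, N), c(M, N);
a.randomizeUniform();
b.randomizeUniform();
BufferArgs inputs, outputs;
inputs.addArg(a);
inputs.addArg(b);
outputs.addArg(c, ASSIGN_TO);  // ASSIGN_TO -> scaleT == 0, i.e. c = a * b
mul->calc(inputs, outputs);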
template <DeviceType Device>
class MulFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
aTrans_ = config.get<bool>("aTrans");
bTrans_ = config.get<bool>("bTrans");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK(!aTrans_ || !bTrans_)
        << "a and b cannot both be transposed";
CHECK_EQ((size_t)2, inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK(inputs[0].data() && inputs[1].data() && outputs[0].data());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
size_t aRow = !aTrans_ ? inputs[0].shape()[0] : inputs[0].shape()[1];
size_t aCol = !aTrans_ ? inputs[0].shape()[1] : inputs[0].shape()[0];
size_t bRow = !bTrans_ ? inputs[1].shape()[0] : inputs[1].shape()[1];
size_t bCol = !bTrans_ ? inputs[1].shape()[1] : inputs[1].shape()[0];
/// shape check for C = A * B, or C += A * B
CHECK_EQ(aCol, bRow);
CHECK_EQ(aRow, outputs[0].shape()[0]);
CHECK_EQ(bCol, outputs[0].shape()[1]);
/// only support C = A * B (ASSIGN_TO) or C += A * B (ADD_TO)
real scaleT = (outputs[0].getArgType() == ADD_TO) ? 1.0 : 0.0;
/// supported cases: a dense output with at most one sparse input,
/// or a sparse output with two dense inputs
CHECK((!outputs[0].isSparseArg() &&
!(inputs[0].isSparseArg() && inputs[1].isSparseArg())) ||
(outputs[0].isSparseArg() && !inputs[0].isSparseArg() &&
!inputs[1].isSparseArg()));
auto outMat = outputs[0].matrix<Device>();
/// dense matrix = dense matrix * dense matrix
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(outMat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
1.0, // scaleAB
scaleT,
aTrans_,
bTrans_);
return;
}
/// dense matrix = dense matrix * sparse matrix
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
CHECK(!aTrans_) << "Transpose of a is not supported";
MulOp<Device>(outMat,
inputs[0].matrix<Device>(),
inputs[1].sparse().SparseMatrix<Device>(),
1.0, // scaleAB
scaleT,
aTrans_,
bTrans_);
return;
}
/// dense matrix = sparse matrix * dense matrix
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
CHECK(!bTrans_) << "Transpose of b is not supported";
CHECK_EQ(inputs[0].sparse().dataFormat(), T_SPARSE_CSR)
    << "Only the SPARSE_CSR format is supported for sparse matrix a";
MulOp<Device>(outMat,
inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(),
1.0, // scaleAB
scaleT,
aTrans_,
bTrans_);
return;
}
/// sparse matrix = dense matrix * dense matrix
auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
outputs[0].isSparseArg()) {
MulOp<Device>(outSparseMat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
1.0, // scaleAB
scaleT,
aTrans_,
bTrans_);
return;
}
}
private:
bool aTrans_;
bool bTrans_;
};
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
namespace paddle {
/// CPU, dense matrix (+)= dense matrix * dense matrix
template <DeviceType DType>
void MulOp(CpuMatrix& out,
const CpuMatrix& a,
const CpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans);
/// CPU, dense matrix (+)= sparse matrix * dense matrix
template <DeviceType DType>
void MulOp(CpuMatrix& out,
const CpuSparseMatrix& a,
const CpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans);
/// CPU, dense matrix (+)= dense matrix * sparse matrix
template <DeviceType DType>
void MulOp(CpuMatrix& out,
const CpuMatrix& a,
const CpuSparseMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans);
/// CPU, sparse matrix (+)= dense matrix * dense matrix
template <DeviceType DType>
void MulOp(CpuSparseMatrix& out,
const CpuMatrix& a,
const CpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans);
/// GPU, dense matrix (+)= dense matrix * dense matrix
template <DeviceType DType>
void MulOp(GpuMatrix& out,
const GpuMatrix& a,
const GpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans);
/// GPU, dense matrix (+)= sparse matrix * dense matrix
template <DeviceType DType>
void MulOp(GpuMatrix& out,
const GpuSparseMatrix& a,
const GpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans);
/// GPU, dense matrix (+)= dense matrix * sparse matrix
template <DeviceType DType>
void MulOp(GpuMatrix& out,
const GpuMatrix& a,
const GpuSparseMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans);
/// GPU, sparse matrix (+)= dense matrix * dense matrix
template <DeviceType DType>
void MulOp(GpuSparseMatrix& out,
const GpuMatrix& a,
const GpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "MulOp.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
namespace paddle {
/// dense matrix (+)= dense matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& a,
const GpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device types do not match";
hl_matrix_mul(const_cast<real*>(a.getData()),
!aTrans ? HPPL_OP_N : HPPL_OP_T,
const_cast<real*>(b.getData()),
!bTrans ? HPPL_OP_N : HPPL_OP_T,
const_cast<real*>(out.getData()),
out.getHeight(),
out.getWidth(),
!aTrans ? a.getWidth() : a.getHeight(),
scaleAB,
scaleT,
a.getStride(),
b.getStride(),
out.getStride());
}
/// dense matrix (+)= sparse matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuSparseMatrix& a,
const GpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans) {
CHECK(out.isContiguous());
CHECK(b.isContiguous());
CHECK(a.useGpu_ && b.useGpu_) << "matrix device types do not match";
hl_matrix_csr_mul_dense(a.sMatrix_.get(),
aTrans ? HPPL_OP_T : HPPL_OP_N,
const_cast<real*>(b.getData()),
HPPL_OP_N,
const_cast<real*>(out.getData()),
out.getHeight(),
out.getWidth(),
b.getHeight(),
scaleAB,
scaleT);
}
/// dense matrix (+)= dense matrix * sparse matrix
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& a,
const GpuSparseMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans) {
CHECK(out.isContiguous());
CHECK(a.isContiguous());
CHECK(a.useGpu_ && b.useGpu_) << "matrix device types do not match";
if (b.format_ == SPARSE_CSC) {
hl_matrix_dense_mul_csc(const_cast<real*>(a.getData()),
HPPL_OP_N,
b.sMatrix_.get(),
bTrans ? HPPL_OP_T : HPPL_OP_N,
const_cast<real*>(out.getData()),
out.getHeight(),
out.getWidth(),
a.getWidth(),
scaleAB,
scaleT);
} else {
hl_matrix_dense_mul_csr(const_cast<real*>(a.getData()),
HPPL_OP_N,
b.sMatrix_.get(),
bTrans ? HPPL_OP_T : HPPL_OP_N,
const_cast<real*>(out.getData()),
out.getHeight(),
out.getWidth(),
a.getWidth(),
scaleAB,
scaleT);
}
}
/// sparse matrix (+)= dense matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
const GpuMatrix& a,
const GpuMatrix& b,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device types do not match";
hl_sparse_matrix_mul(const_cast<real*>(a.getData()),
aTrans ? HPPL_OP_T : HPPL_OP_N,
const_cast<real*>(b.getData()),
bTrans ? HPPL_OP_T : HPPL_OP_N,
out.sMatrix_.get(),
out.getHeight(),
out.getWidth(),
!bTrans ? b.getHeight() : b.getWidth(),
scaleAB,
scaleT);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/math/tests/test_matrixUtil.h"
#include "paddle/testing/TestUtil.h"
using namespace paddle; // NOLINT
/**
* C += A * B, A, B, C dense matrix
* dense = dense * dense
*/
void testFuncDDDMatrix(
bool transa, bool transb, size_t dimM, size_t dimN, size_t dimK) {
real scaleT = 1.0;
size_t heightA = (transa == false) ? dimM : dimK;
size_t widthA = (transa == false) ? dimK : dimM;
size_t heightB = (transb == false) ? dimK : dimN;
size_t widthB = (transb == false) ? dimN : dimK;
size_t heightC = dimM;
size_t widthC = dimN;
// init Test object
FunctionCompare test(
"MulOp", FuncConfig().set("aTrans", transa).set("bTrans", transb));
// prepare input arguments
/// matrix A : HA * WA
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightA, widthA}));
/// matrix B: HB * WB
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightB, widthB}));
/// output matrix C: HC * WC
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightC, widthC}),
scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function
test.run();
}
TEST(MulOp, DDDMatrixMul) {
LOG(INFO) << "function test for dense = dense * dense matrix";
for (const auto transa : {false, true}) {
for (const auto transb : {false, true}) {
for (const auto dimM : {1, 10, 100}) {
for (const auto dimN : {1, 10}) {
for (const auto dimK : {8}) {
if (transa && transb) {
continue;
}
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
<< " transa=" << transa << " transb=" << transb
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK;
testFuncDDDMatrix(transa, transb, dimM, dimN, dimK);
}
}
}
}
}
}
/**
* C += A * B, B, C dense, A sparse
* dense = sparse * dense
*/
void testFuncDSparseDMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real scaleT = 1.0;
// init Test object
FunctionCompare test("MulOp",
FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments
/// sparse matrix A : M * K
test.addInputs(SparseMatrixArg(
VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}, nnz, FORMAT, FLOAT_VALUE));
/// matrix B: K * N
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimK, dimN}));
/// output matrix C: M * N
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}),
scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function
test.run();
}
TEST(MulOp, DSparseDMul) {
LOG(INFO) << "function test for dense = sparse * dense matrix";
for (const auto dimM : {10, 100, 1000}) {
for (const auto dimN : {10, 100}) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSR}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
<< " nnz=" << std::setw(5) << nnz
<< " format=" << std::setw(5) << FORMAT;
testFuncDSparseDMatrix(dimM, dimN, dimK, nnz, FORMAT);
}
}
}
}
}
}
/**
* C += A * B, A, C dense, B sparse
* dense = dense * sparse
*/
void testFuncDDSparseMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real scaleT = 1.0;
// init Test object
FunctionCompare test("MulOp",
FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments
/// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
/// matrix B: K * N
test.addInputs(SparseMatrixArg(
VALUE_TYPE_FLOAT, TensorShape{dimK, dimN}, nnz, FORMAT, FLOAT_VALUE));
/// output matrix C: M * N
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}),
scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function
test.run();
}
TEST(MulOp, DDSparseMul) {
LOG(INFO) << "function test for dense = dense * sparse matrix";
for (const auto dimM : {10, 100, 1000}) {
for (const auto dimN : {10, 100}) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSR, SPARSE_CSC}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
<< " nnz=" << std::setw(5) << nnz
<< " format=" << std::setw(5) << FORMAT;
testFuncDDSparseMatrix(dimM, dimN, dimK, nnz, FORMAT);
}
}
}
}
}
}
/**
 * C += A * B, C sparse, A and B dense
* sparse = dense * dense
*/
void testFuncSparseDDMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real scaleT = 1.0;
// init Test object
FunctionCompare test("MulOp",
FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments
/// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
/// matrix B: K * N
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimK, dimN}));
/// output sparse matrix C: M * N
test.addOutputs(
SparseMatrixArg(
VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}, nnz, FORMAT, FLOAT_VALUE),
scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function
test.run();
}
TEST(MulOp, SparseDDMul) {
LOG(INFO) << "function test for sparse = dense * dense matrix";
for (const auto dimM : {10, 100, 1000}) {
for (const auto dimN : {10, 100}) {
for (const auto dimK : {3, 10}) {
for (const auto nnz : {3, 10}) {
for (const auto FORMAT : {SPARSE_CSC, SPARSE_CSR}) {
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK
<< " nnz=" << std::setw(5) << nnz
<< " format=" << std::setw(5) << FORMAT;
testFuncSparseDDMatrix(dimM, dimN, dimK, nnz, FORMAT);
}
}
}
}
}
}
......@@ -31,6 +31,10 @@ enum DeviceType {
DEVICE_TYPE_GPU = 2
};
enum SparseDataType { T_NO_VALUE = 0, T_FLOAT_VALUE = 1 };
enum SparseDataFormat { T_SPARSE_CSR = 0, T_SPARSE_CSC = 1 };
inline int sizeOfValuType(ValueType valueType) {
if (valueType == VALUE_TYPE_INT32) {
return 4;
......@@ -87,6 +91,29 @@ struct MatrixT<int, DEVICE_TYPE_GPU> {
using type = void; // Not implemented
};
template <typename VType, DeviceType Device>
struct SparseMatrixT;
template <>
struct SparseMatrixT<real, DEVICE_TYPE_CPU> {
using type = CpuSparseMatrix;
};
template <>
struct SparseMatrixT<real, DEVICE_TYPE_GPU> {
using type = GpuSparseMatrix;
};
template <>
struct SparseMatrixT<int, DEVICE_TYPE_CPU> {
using type = void; // Not implemented
};
template <>
struct SparseMatrixT<int, DEVICE_TYPE_GPU> {
using type = void; // Not implemented
};
template <typename VType, DeviceType Device>
struct VectorT;
......@@ -114,8 +141,9 @@ struct VectorT<int, DEVICE_TYPE_GPU> {
template <typename VType, DeviceType DType>
struct Tensor {
typedef typename detail::MatrixT<VType, DType>::type Matrix;
typedef typename detail::VectorT<VType, DType>::type Vector;
typedef typename detail::MatrixT<VType, DType>::type Matrix;
typedef typename detail::SparseMatrixT<VType, DType>::type SparseMatrix;
};
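For example (a sketch, assuming <type_traits> is available in the translation unit), the new SparseMatrix alias resolves to the concrete matrix types as follows:
static_assert(std::is_same<Tensor<real, DEVICE_TYPE_CPU>::SparseMatrix,
                           CpuSparseMatrix>::value,
              "real CPU tensors map to CpuSparseMatrix");
static_assert(std::is_same<Tensor<real, DEVICE_TYPE_GPU>::SparseMatrix,
                           GpuSparseMatrix>::value,
              "real GPU tensors map to GpuSparseMatrix");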
} // namespace paddle
......@@ -31,6 +31,7 @@ limitations under the License. */
namespace paddle {
/// TODO(tianbing), move to paddle/function/TensorType.h
enum SparseValueType { NO_VALUE = 0, FLOAT_VALUE = 1 };
/**
......@@ -56,6 +57,7 @@ enum SparseValueType { NO_VALUE = 0, FLOAT_VALUE = 1 };
* value [1, 1, 2, 2, 5]
* @endcode
*/
/// TODO(tianbing), move to paddle/function/TensorType.h
enum SparseFormat { SPARSE_CSR = 0, SPARSE_CSC = 1 };
class Matrix;
......
......@@ -177,7 +177,6 @@ GpuSparseMatrix::GpuSparseMatrix(real* value,
hl_sparse_matrix_s_ptr tmp2(tmp, hl_destruct_sparse_matrix);
sMatrix_ = tmp2;
}
LOG(INFO) << "weight to matrix ";
}
}
......
......@@ -30,6 +30,17 @@ void checkMatrixEqual(const MatrixPtr& a, const MatrixPtr& b) {
}
}
void checkSMatrixEqual(const CpuSparseMatrix& a, const CpuSparseMatrix& b) {
ASSERT_EQ(a.getWidth(), b.getWidth());
ASSERT_EQ(a.getHeight(), b.getHeight());
ASSERT_EQ(a.isTransposed(), b.isTransposed());
ASSERT_EQ(a.getFormat(), b.getFormat());
ASSERT_EQ(a.getElementCnt(), b.getElementCnt());
for (size_t r = 0; r < a.getElementCnt(); ++r) {
ASSERT_FLOAT_EQ(a.getValue()[r], b.getValue()[r]);
}
}
void checkSMatrixEqual(const CpuSparseMatrixPtr& a,
const CpuSparseMatrixPtr& b) {
ASSERT_EQ(a->getWidth(), b->getWidth());
......@@ -73,6 +84,36 @@ void checkSMatrixEqual2(const CpuSparseMatrixPtr& a,
}
}
void checkSMatrixEqual2Dense(const CpuSparseMatrix& a, const CpuMatrix& b) {
ASSERT_EQ(a.getWidth(), b.getWidth());
ASSERT_EQ(a.getHeight(), b.getHeight());
ASSERT_EQ(a.isTransposed(), b.isTransposed());
if (a.getFormat() == SPARSE_CSC) {
int* rows = a.getRows();
for (size_t i = 0; i < a.getWidth(); i++) {
for (size_t j = a.getColStartIdx(i); j < a.getColStartIdx(i + 1); j++) {
if (a.getValueType() == FLOAT_VALUE) {
ASSERT_FLOAT_EQ(a.getValue()[j], b.getElement(rows[j], i));
} else {
ASSERT_FLOAT_EQ(1.0, b.getElement(rows[j], i));
}
}
}
} else {
int* cols = a.getCols();
for (size_t i = 0; i < a.getHeight(); i++) {
for (size_t j = a.getRowStartIdx(i); j < a.getRowStartIdx(i + 1); j++) {
if (a.getValueType() == FLOAT_VALUE) {
ASSERT_FLOAT_EQ(a.getValue()[j], b.getElement(i, cols[j]));
} else {
ASSERT_FLOAT_EQ(1.0, b.getElement(i, cols[j]));
}
}
}
}
}
void checkSMatrixEqual2Dense(const CpuSparseMatrixPtr& a,
const CpuMatrixPtr& b) {
ASSERT_EQ(a->getWidth(), b->getWidth());
......