提交 999cd14a 编写于 作者: X xutianbing

Further address Daoyuan's comments, clean the code.

上级 b3be7358
...@@ -34,8 +34,8 @@ SparseMatrixArg::SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType) ...@@ -34,8 +34,8 @@ SparseMatrixArg::SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType)
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32), row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32), col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32),
nnz_(sparse.getElementCnt()), nnz_(sparse.getElementCnt()),
format_(sparse.getFormat()), format_(static_cast<SparseDataFormat>(sparse.getFormat())),
type_(sparse.getValueType()) { type_(static_cast<SparseDataType>(sparse.getValueType())) {
bufferType_ = TENSOR_SPARSE; bufferType_ = TENSOR_SPARSE;
} }
...@@ -44,8 +44,8 @@ SparseMatrixArg::SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType) ...@@ -44,8 +44,8 @@ SparseMatrixArg::SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType)
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32), row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32), col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32),
nnz_(sparse.getElementCnt()), nnz_(sparse.getElementCnt()),
format_(sparse.getFormat()), format_(static_cast<SparseDataFormat>(sparse.getFormat())),
type_(sparse.getValueType()) { type_(static_cast<SparseDataType>(sparse.getValueType())) {
bufferType_ = TENSOR_SPARSE; bufferType_ = TENSOR_SPARSE;
} }
......
...@@ -72,19 +72,21 @@ public: ...@@ -72,19 +72,21 @@ public:
BufferArg(ValueType valueType, BufferArg(ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
: buf_(nullptr), : buf_(nullptr), valueType_(valueType), shape_(shape), argType_(argType) {
valueType_(valueType), bufferType_ = TENSOR_NORMAL;
shape_(shape), }
argType_(argType) {}
BufferArg(void* buf, BufferArg(void* buf,
ValueType valueType, ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
: buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {} : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {
bufferType_ = TENSOR_NORMAL;
}
BufferArg(void* buf, ValueType valueType) BufferArg(void* buf, ValueType valueType) : buf_(buf), valueType_(valueType) {
: buf_(buf), valueType_(valueType) {} bufferType_ = TENSOR_NORMAL;
}
BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED) BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED)
: buf_( : buf_(
...@@ -173,7 +175,7 @@ protected: ...@@ -173,7 +175,7 @@ protected:
TensorShape shape_; TensorShape shape_;
BufferType bufferType_{TENSOR_UNKNOWN}; BufferType bufferType_{TENSOR_UNKNOWN};
ArgType argType_{UNSPECIFIED}; ArgType argType_{UNSPECIFIED};
// todo(tianbing), add deviceType_ // TODO(tianbing), add deviceType_
// leading dimensions. The size is dims_.size() // leading dimensions. The size is dims_.size()
// Dims lds_; // Dims lds_;
}; };
...@@ -186,6 +188,7 @@ class SequenceIdArg : public BufferArg { ...@@ -186,6 +188,7 @@ class SequenceIdArg : public BufferArg {
public: public:
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED) SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
: BufferArg(VALUE_TYPE_INT32, shape, argType) { : BufferArg(VALUE_TYPE_INT32, shape, argType) {
bufferType_ = TENSOR_SEQUENCE_ID;
CHECK_EQ(shape_.ndims(), (size_t)1); CHECK_EQ(shape_.ndims(), (size_t)1);
CHECK_GT(shape_[0], 1); CHECK_GT(shape_[0], 1);
numSeqs_ = shape_[0] - 1; numSeqs_ = shape_[0] - 1;
...@@ -223,7 +226,9 @@ public: ...@@ -223,7 +226,9 @@ public:
SequenceArg(ValueType valueType, SequenceArg(ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {} : BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {
bufferType_ = TENSOR_SEQUENCE_DATA;
}
SequenceArg(void* buf, SequenceArg(void* buf,
ValueType valueType, ValueType valueType,
...@@ -271,16 +276,16 @@ public: ...@@ -271,16 +276,16 @@ public:
row_(row), row_(row),
col_(col), col_(col),
nnz_(nnz), nnz_(nnz),
format_(format), format_(static_cast<SparseDataFormat>(format)),
type_(type) { type_(static_cast<SparseDataType>(type)) {
bufferType_ = TENSOR_SPARSE; bufferType_ = TENSOR_SPARSE;
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE)); CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
CHECK_EQ(shape_.ndims(), (size_t)2); CHECK_EQ(shape_.ndims(), (size_t)2);
CHECK_EQ(row_.shape().ndims(), (size_t)1); CHECK_EQ(row_.shape().ndims(), (size_t)1);
CHECK_EQ(col_.shape().ndims(), (size_t)1); CHECK_EQ(col_.shape().ndims(), (size_t)1);
if (format == SPARSE_CSR) { if (format_ == T_SPARSE_CSR) {
CHECK_EQ(nnz, col.shape()[0]); CHECK_EQ(nnz, col.shape()[0]);
} else if (format == SPARSE_CSC) { } else if (format_ == T_SPARSE_CSC) {
CHECK_EQ(nnz, row.shape()[0]); CHECK_EQ(nnz, row.shape()[0]);
} }
} }
...@@ -292,23 +297,23 @@ public: ...@@ -292,23 +297,23 @@ public:
SparseValueType type, SparseValueType type,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType), : BufferArg(valueType, shape, argType),
/// len of row_ : height + 1 (CSR), buf_ == nullptr row_(BufferArg(nullptr, VALUE_TYPE_INT32)),
row_(format == SPARSE_CSR col_(BufferArg(nullptr, VALUE_TYPE_INT32)),
? BufferArg(VALUE_TYPE_INT32, TensorShape{shape[0] + 1})
: BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})),
/// len of col_ : width + 1 (CSC), buf_ == nullptr
col_(format == SPARSE_CSR
? BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})
: BufferArg(VALUE_TYPE_INT32, TensorShape{shape[1] + 1})),
nnz_(nnz), nnz_(nnz),
format_(format), format_(static_cast<SparseDataFormat>(format)),
type_(type) { type_(static_cast<SparseDataType>(type)) {
bufferType_ = TENSOR_SPARSE; bufferType_ = TENSOR_SPARSE;
/// todo(tianbing)
/// valueType and shape_.ndims() == 2 need to check before
/// this constructor to make sure row_ and col_ are right
CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE)); CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
CHECK_EQ(shape_.ndims(), (size_t)2); CHECK_EQ(shape_.ndims(), (size_t)2);
/// len of row_ : height + 1 (CSR) or nnz (CSC), buf_ == nullptr
row_ = (format_ == T_SPARSE_CSR
? BufferArg(VALUE_TYPE_INT32, TensorShape{shape_[0] + 1})
: BufferArg(VALUE_TYPE_INT32, TensorShape{nnz}));
/// len of col_ : width + 1 (CSC) or nnz (CSR), buf_ == nullptr
col_ = (format_ == T_SPARSE_CSR
? BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})
: BufferArg(VALUE_TYPE_INT32, TensorShape{shape_[1] + 1}));
} }
SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED); SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);
...@@ -328,8 +333,8 @@ public: ...@@ -328,8 +333,8 @@ public:
shape_[0], shape_[0],
shape_[1], shape_[1],
nnz_, nnz_,
type_, static_cast<SparseValueType>(type_),
format_, static_cast<SparseFormat>(format_),
false); false);
} }
...@@ -343,16 +348,16 @@ public: ...@@ -343,16 +348,16 @@ public:
size_t numElements() const override { return nnz_; } size_t numElements() const override { return nnz_; }
SparseFormat dataFormat() const { return format_; } SparseDataFormat dataFormat() const { return format_; }
SparseValueType dataType() const { return type_; } SparseDataType dataType() const { return type_; }
private: private:
BufferArg row_; BufferArg row_;
BufferArg col_; BufferArg col_;
size_t nnz_; size_t nnz_;
SparseFormat format_; SparseDataFormat format_;
SparseValueType type_; SparseDataType type_;
}; };
} // namespace paddle } // namespace paddle
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include "Function.h" #include "Function.h"
#include "paddle/math/Matrix.h" #include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h" #include "paddle/math/SparseMatrix.h"
#include "paddle/math/Vector.h"
#include "paddle/math/tests/TensorCheck.h" #include "paddle/math/tests/TensorCheck.h"
#include "paddle/testing/TestUtil.h" #include "paddle/testing/TestUtil.h"
...@@ -77,33 +76,33 @@ public: ...@@ -77,33 +76,33 @@ public:
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size)); gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size));
cpuOutputs_.emplace_back(std::make_shared<BufferArg>( cpuOutputs_.emplace_back(
cpuMemory_.back()->getBuf(), std::make_shared<BufferArg>(cpuMemory_.back()->getBuf(),
output.valueType(), output.valueType(),
output.shape(), output.shape(),
// todo(tianbing), argType = output.getArgType(), but default ADD_TO argType));
argType)); gpuOutputs_.emplace_back(
gpuOutputs_.emplace_back(std::make_shared<BufferArg>( std::make_shared<BufferArg>(gpuMemory_.back()->getBuf(),
gpuMemory_.back()->getBuf(), output.valueType(),
output.valueType(), output.shape(),
output.shape(), argType));
// todo(tianbing), argType = output.getArgType(), but default ADD_TO
argType));
} }
/// add and init output sparse matrix /// add and init output sparse matrix
void addOutputs(const SparseMatrixArg& output, ArgType argType = ASSIGN_TO) { void addOutputs(const SparseMatrixArg& output, ArgType argType = ASSIGN_TO) {
cpuSparse_ = std::make_shared<CpuSparseMatrix>(output.shape()[0], cpuSparse_ = std::make_shared<CpuSparseMatrix>(
output.shape()[1], output.shape()[0],
output.nnz(), output.shape()[1],
output.dataType(), output.nnz(),
output.dataFormat()); static_cast<SparseValueType>(output.dataType()),
static_cast<SparseFormat>(output.dataFormat()));
gpuSparse_ = std::make_shared<GpuSparseMatrix>(output.shape()[0],
output.shape()[1], gpuSparse_ = std::make_shared<GpuSparseMatrix>(
output.nnz(), output.shape()[0],
output.dataType(), output.shape()[1],
output.dataFormat()); output.nnz(),
static_cast<SparseValueType>(output.dataType()),
static_cast<SparseFormat>(output.dataFormat()));
/// init sparse matrix /// init sparse matrix
hl_stream_t stream(HPPL_STREAM_1); hl_stream_t stream(HPPL_STREAM_1);
...@@ -138,17 +137,19 @@ public: ...@@ -138,17 +137,19 @@ public:
} }
void addInputs(const SparseMatrixArg& input) { void addInputs(const SparseMatrixArg& input) {
cpuSparse_ = std::make_shared<CpuSparseMatrix>(input.shape()[0], cpuSparse_ = std::make_shared<CpuSparseMatrix>(
input.shape()[1], input.shape()[0],
input.nnz(), input.shape()[1],
input.dataType(), input.nnz(),
input.dataFormat()); static_cast<SparseValueType>(input.dataType()),
static_cast<SparseFormat>(input.dataFormat()));
gpuSparse_ = std::make_shared<GpuSparseMatrix>(input.shape()[0],
input.shape()[1], gpuSparse_ = std::make_shared<GpuSparseMatrix>(
input.nnz(), input.shape()[0],
input.dataType(), input.shape()[1],
input.dataFormat()); input.nnz(),
static_cast<SparseValueType>(input.dataType()),
static_cast<SparseFormat>(input.dataFormat()));
/// init sparse matrix /// init sparse matrix
hl_stream_t stream(HPPL_STREAM_1); hl_stream_t stream(HPPL_STREAM_1);
......
...@@ -41,6 +41,7 @@ inline void colVecAddTo( ...@@ -41,6 +41,7 @@ inline void colVecAddTo(
} // namespace } // namespace
namespace paddle { namespace paddle {
/// sparse matrix (+)= dense matrix * dense matrix
template <> template <>
void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
...@@ -105,6 +106,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -105,6 +106,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
} }
} }
/// dense matrix (+)= dense matrix * dense matrix
template <> template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
...@@ -129,6 +131,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -129,6 +131,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
out.getStride()); out.getStride());
} }
/// dense matrix (+)= sparse matrix * dense matrix
template <> template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuSparseMatrix& a, const CpuSparseMatrix& a,
...@@ -138,8 +141,6 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -138,8 +141,6 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans,
bool cTrans) { bool cTrans) {
CHECK_EQ(a.getFormat(), SPARSE_CSR)
<< "Not supported SPARSE_CSR format for a";
if (scaleT == 0) { if (scaleT == 0) {
out.zeroMem(); out.zeroMem();
} }
...@@ -165,6 +166,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -165,6 +166,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
} }
} }
/// dense matrix (+)= dense matrix * sparse matrix
template <> template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
...@@ -183,7 +185,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -183,7 +185,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
int* rows = b.getRows(); int* rows = b.getRows();
int* cols = b.getCols(); int* cols = b.getCols();
/// b.getFormat() == SPARSE_CSC /// SPARSE_CSC format
if (b.getFormat() == SPARSE_CSC) { if (b.getFormat() == SPARSE_CSC) {
for (size_t j = 0; j < b.getWidth(); ++j) { for (size_t j = 0; j < b.getWidth(); ++j) {
int start = b.getColStartIdx(j); int start = b.getColStartIdx(j);
...@@ -200,7 +202,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -200,7 +202,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
return; return;
} }
/// b.getFormat() == SPARSE_CSR /// SPARSE_CSR format
if (b.getFormat() == SPARSE_CSR) { if (b.getFormat() == SPARSE_CSR) {
for (size_t j = 0; j < b.getHeight(); ++j) { for (size_t j = 0; j < b.getHeight(); ++j) {
int start = b.getRowStartIdx(j); int start = b.getRowStartIdx(j);
...@@ -220,11 +222,32 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -220,11 +222,32 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
/** /**
* mul operator * mul operator
* out = scaleT * out + scaleAB*(in1 * in2) * out = scaleT * out + scaleAB * (in1 * in2)
* here, scaleT in {0, 1}, scaleAB == 1,
* out = in1 (A) * in2 (B), ASSIGN_TO
* out += in1 (A) * in2 (B), ADD_TO
*
*
* \param outputs[0] output matrix (out), M * N,
* could be either Sparse or Dense Matrix
* M is num of rows, N is num of columns
* \param inputs[0] first input matrix (A), M * K (if non-trans)
* could be either Sparse or Dense Matrix
* M is num of rows, K is num of columns
* \param inputs[1] second input matrix (B), K * N (if non-trans)
* could be either Sparse or Dense Matrix
* K is num of rows, N is num of columns
*
* Supports eight Mul operators in total, across both GPU and CPU devices.
* For each device, four Mul operators are supported:
* 1. dense (out) = dense (A) * dense (B)
* 2. dense (out) = sparse (A) * dense (B)
* the sparse matrix only supports the SPARSE_CSR format
* 3. dense (out) = dense (A) * sparse (B)
* the sparse matrix supports the SPARSE_CSC and SPARSE_CSR formats
* 4. sparse (out) = dense (A) * dense (B)
* the sparse matrix supports the SPARSE_CSC and SPARSE_CSR formats
* *
* \param outputs[0] output matrix, M * N
* \param inputs[0] first input (sparse) matrix, M * K (if non-trans)
* \param inputs[1] second input matrix, K * N (if non-trans)
*/ */
template <DeviceType Device> template <DeviceType Device>
class MulFunc : public FunctionBase { class MulFunc : public FunctionBase {
...@@ -271,7 +294,7 @@ public: ...@@ -271,7 +294,7 @@ public:
!inputs[1].isSparseArg())); !inputs[1].isSparseArg()));
auto outMat = outputs[0].matrix<Device>(); auto outMat = outputs[0].matrix<Device>();
/// matrix = matrix * matrix /// dense matrix = dense matrix * dense matrix
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() && if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) { !outputs[0].isSparseArg()) {
MulOp<Device>(outMat, MulOp<Device>(outMat,
...@@ -285,7 +308,7 @@ public: ...@@ -285,7 +308,7 @@ public:
return; return;
} }
/// matrix = matrix * sparse matrix /// dense matrix = dense matrix * sparse matrix
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() && if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) { !outputs[0].isSparseArg()) {
CHECK(!aTrans_) << "Not supported a transpose"; CHECK(!aTrans_) << "Not supported a transpose";
...@@ -300,10 +323,12 @@ public: ...@@ -300,10 +323,12 @@ public:
return; return;
} }
/// matrix = sparse matrix * matrix /// dense matrix = sparse matrix * dense matrix
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() && if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) { !outputs[0].isSparseArg()) {
CHECK(!bTrans_) << "Not supported b transpose"; CHECK(!bTrans_) << "Not supported b transpose";
CHECK_EQ(inputs[0].sparse().dataFormat(), T_SPARSE_CSR)
<< "Only supported SPARSE_CSR format for sparse matrix a";
MulOp<Device>(outMat, MulOp<Device>(outMat,
inputs[0].sparse().SparseMatrix<Device>(), inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
...@@ -315,7 +340,7 @@ public: ...@@ -315,7 +340,7 @@ public:
return; return;
} }
/// sparse matrix = matrix * matrix /// sparse matrix = dense matrix * dense matrix
auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>(); auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() && if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
outputs[0].isSparseArg()) { outputs[0].isSparseArg()) {
......
...@@ -15,12 +15,11 @@ limitations under the License. */ ...@@ -15,12 +15,11 @@ limitations under the License. */
#pragma once #pragma once
#include "Function.h" #include "Function.h"
/// todo(tianbing), delete it
#include <iostream>
#include "paddle/math/Matrix.h" #include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h" #include "paddle/math/SparseMatrix.h"
namespace paddle { namespace paddle {
/// CPU, dense matrix (+)= dense matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(CpuMatrix& out, void MulOp(CpuMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
...@@ -31,6 +30,7 @@ void MulOp(CpuMatrix& out, ...@@ -31,6 +30,7 @@ void MulOp(CpuMatrix& out,
bool bTrans, bool bTrans,
bool cTrans); bool cTrans);
/// CPU, dense matrix (+)= sparse matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(CpuMatrix& out, void MulOp(CpuMatrix& out,
const CpuSparseMatrix& a, const CpuSparseMatrix& a,
...@@ -41,6 +41,7 @@ void MulOp(CpuMatrix& out, ...@@ -41,6 +41,7 @@ void MulOp(CpuMatrix& out,
bool bTrans, bool bTrans,
bool cTrans); bool cTrans);
/// CPU, dense matrix (+)= dense matrix * sparse matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(CpuMatrix& out, void MulOp(CpuMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
...@@ -51,6 +52,7 @@ void MulOp(CpuMatrix& out, ...@@ -51,6 +52,7 @@ void MulOp(CpuMatrix& out,
bool bTrans, bool bTrans,
bool cTrans); bool cTrans);
/// CPU, sparse matrix (+)= dense matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(CpuSparseMatrix& out, void MulOp(CpuSparseMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
...@@ -61,6 +63,7 @@ void MulOp(CpuSparseMatrix& out, ...@@ -61,6 +63,7 @@ void MulOp(CpuSparseMatrix& out,
bool bTrans, bool bTrans,
bool cTrans); bool cTrans);
/// GPU, dense matrix (+)= dense matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuMatrix& out, void MulOp(GpuMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
...@@ -71,6 +74,7 @@ void MulOp(GpuMatrix& out, ...@@ -71,6 +74,7 @@ void MulOp(GpuMatrix& out,
bool bTrans, bool bTrans,
bool cTrans); bool cTrans);
/// GPU, dense matrix (+)= sparse matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuMatrix& out, void MulOp(GpuMatrix& out,
const GpuSparseMatrix& a, const GpuSparseMatrix& a,
...@@ -81,6 +85,7 @@ void MulOp(GpuMatrix& out, ...@@ -81,6 +85,7 @@ void MulOp(GpuMatrix& out,
bool bTrans, bool bTrans,
bool cTrans); bool cTrans);
/// GPU, dense matrix (+)= dense matrix * sparse matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuMatrix& out, void MulOp(GpuMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
...@@ -90,7 +95,7 @@ void MulOp(GpuMatrix& out, ...@@ -90,7 +95,7 @@ void MulOp(GpuMatrix& out,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans,
bool cTrans); bool cTrans);
/// GPU, sparse matrix (+)= dense matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuSparseMatrix& out, void MulOp(GpuSparseMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
......
...@@ -18,10 +18,7 @@ limitations under the License. */ ...@@ -18,10 +18,7 @@ limitations under the License. */
#include "paddle/math/SparseMatrix.h" #include "paddle/math/SparseMatrix.h"
namespace paddle { namespace paddle {
/** /// dense matrix (+)= dense matrix * dense matrix
* out = scaleT * out + scaleAB * (a * b)
* out : output matrix, M * N
*/
template <> template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
...@@ -32,14 +29,11 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -32,14 +29,11 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
bool bTrans, bool bTrans,
bool cTrans) { bool cTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
real* aData = const_cast<real*>(a.getData()); hl_matrix_mul(const_cast<real*>(a.getData()),
real* bData = const_cast<real*>(b.getData());
real* outData = const_cast<real*>(out.getData());
hl_matrix_mul(aData,
!aTrans ? HPPL_OP_N : HPPL_OP_T, !aTrans ? HPPL_OP_N : HPPL_OP_T,
bData, const_cast<real*>(b.getData()),
!bTrans ? HPPL_OP_N : HPPL_OP_T, !bTrans ? HPPL_OP_N : HPPL_OP_T,
outData, const_cast<real*>(out.getData()),
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
!aTrans ? a.getWidth() : a.getHeight(), !aTrans ? a.getWidth() : a.getHeight(),
...@@ -50,10 +44,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -50,10 +44,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
out.getStride()); out.getStride());
} }
/** /// dense matrix (+)= sparse matrix * dense matrix
* out = scaleT * out + scaleAB * (a * b)
* out : M * N
*/
template <> template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuSparseMatrix& a, const GpuSparseMatrix& a,
...@@ -66,15 +57,11 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -66,15 +57,11 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
CHECK(out.isContiguous()); CHECK(out.isContiguous());
CHECK(b.isContiguous()); CHECK(b.isContiguous());
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
hl_matrix_csr_mul_dense(a.sMatrix_.get(),
hl_sparse_matrix_s aData = a.sMatrix_.get();
real* bData = const_cast<real*>(b.getData());
real* outData = const_cast<real*>(out.getData());
hl_matrix_csr_mul_dense(aData,
aTrans ? HPPL_OP_T : HPPL_OP_N, aTrans ? HPPL_OP_T : HPPL_OP_N,
bData, const_cast<real*>(b.getData()),
HPPL_OP_N, HPPL_OP_N,
outData, const_cast<real*>(out.getData()),
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
b.getHeight(), b.getHeight(),
...@@ -82,10 +69,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -82,10 +69,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
scaleT); scaleT);
} }
/** /// dense matrix (+)= dense matrix * sparse matrix
* out = scaleT * out + scaleAB * (a * b)
* out : M * N
*/
template <> template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
...@@ -99,27 +83,23 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -99,27 +83,23 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
CHECK(a.isContiguous()); CHECK(a.isContiguous());
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
hl_sparse_matrix_s bData = b.sMatrix_.get();
real* aData = const_cast<real*>(a.getData());
real* outData = const_cast<real*>(out.getData());
if (b.format_ == SPARSE_CSC) { if (b.format_ == SPARSE_CSC) {
hl_matrix_dense_mul_csc(aData, hl_matrix_dense_mul_csc(const_cast<real*>(a.getData()),
HPPL_OP_N, HPPL_OP_N,
bData, b.sMatrix_.get(),
bTrans ? HPPL_OP_T : HPPL_OP_N, bTrans ? HPPL_OP_T : HPPL_OP_N,
outData, const_cast<real*>(out.getData()),
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
a.getWidth(), a.getWidth(),
scaleAB, scaleAB,
scaleT); scaleT);
} else { } else {
hl_matrix_dense_mul_csr(aData, hl_matrix_dense_mul_csr(const_cast<real*>(a.getData()),
HPPL_OP_N, HPPL_OP_N,
bData, b.sMatrix_.get(),
bTrans ? HPPL_OP_T : HPPL_OP_N, bTrans ? HPPL_OP_T : HPPL_OP_N,
outData, const_cast<real*>(out.getData()),
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
a.getWidth(), a.getWidth(),
...@@ -128,6 +108,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -128,6 +108,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
} }
} }
/// sparse matrix (+)= dense matrix * dense matrix
template <> template <>
void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out, void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
...@@ -138,16 +119,11 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out, ...@@ -138,16 +119,11 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
bool bTrans, bool bTrans,
bool cTrans) { bool cTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
hl_sparse_matrix_mul(const_cast<real*>(a.getData()),
real* aData = const_cast<real*>(a.getData());
real* bData = const_cast<real*>(b.getData());
hl_sparse_matrix_s outData = out.sMatrix_.get();
hl_sparse_matrix_mul(aData,
aTrans ? HPPL_OP_T : HPPL_OP_N, aTrans ? HPPL_OP_T : HPPL_OP_N,
bData, const_cast<real*>(b.getData()),
bTrans ? HPPL_OP_T : HPPL_OP_N, bTrans ? HPPL_OP_T : HPPL_OP_N,
outData, out.sMatrix_.get(),
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
!bTrans ? b.getHeight() : b.getWidth(), !bTrans ? b.getHeight() : b.getWidth(),
......
...@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
/// todo(tianbing), delete
#include <iostream>
#include "FunctionTest.h" #include "FunctionTest.h"
#include "paddle/math/Matrix.h" #include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h" #include "paddle/math/SparseMatrix.h"
......
...@@ -31,6 +31,10 @@ enum DeviceType { ...@@ -31,6 +31,10 @@ enum DeviceType {
DEVICE_TYPE_GPU = 2 DEVICE_TYPE_GPU = 2
}; };
enum SparseDataType { T_NO_VALUE = 0, T_FLOAT_VALUE = 1 };
enum SparseDataFormat { T_SPARSE_CSR = 0, T_SPARSE_CSC = 1 };
inline int sizeOfValuType(ValueType valueType) { inline int sizeOfValuType(ValueType valueType) {
if (valueType == VALUE_TYPE_INT32) { if (valueType == VALUE_TYPE_INT32) {
return 4; return 4;
......
...@@ -31,6 +31,7 @@ limitations under the License. */ ...@@ -31,6 +31,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
/// TODO(tianbing), move to paddle/function/TensorType.h
enum SparseValueType { NO_VALUE = 0, FLOAT_VALUE = 1 }; enum SparseValueType { NO_VALUE = 0, FLOAT_VALUE = 1 };
/** /**
...@@ -56,6 +57,7 @@ enum SparseValueType { NO_VALUE = 0, FLOAT_VALUE = 1 }; ...@@ -56,6 +57,7 @@ enum SparseValueType { NO_VALUE = 0, FLOAT_VALUE = 1 };
* value [1, 1, 2, 2, 5] * value [1, 1, 2, 2, 5]
* @endcode * @endcode
*/ */
/// TODO(tianbing), move to paddle/function/TensorType.h
enum SparseFormat { SPARSE_CSR = 0, SPARSE_CSC = 1 }; enum SparseFormat { SPARSE_CSR = 0, SPARSE_CSC = 1 };
class Matrix; class Matrix;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册