提交 5b1a5c11 编写于 作者: X xutianbing

Address Daoyuan's review comments.

上级 999cd14a
...@@ -70,7 +70,7 @@ public: ...@@ -70,7 +70,7 @@ public:
} }
// The output only needs to carry a shape; it does not contain data. // The output only needs to carry a shape; it does not contain data.
void addOutputs(const BufferArg& output, ArgType argType = ADD_TO) { void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size = size_t size =
output.shape().getElements() * sizeOfValuType(output.valueType()); output.shape().getElements() * sizeOfValuType(output.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
......
...@@ -49,8 +49,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -49,8 +49,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans) {
bool cTrans) {
CHECK_EQ(out.getValueType(), FLOAT_VALUE); CHECK_EQ(out.getValueType(), FLOAT_VALUE);
if (scaleT == 0) { if (scaleT == 0) {
out.zeroMem(); out.zeroMem();
...@@ -114,8 +113,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -114,8 +113,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans) {
bool cTrans) {
GEMM(aTrans ? CblasTrans : CblasNoTrans, GEMM(aTrans ? CblasTrans : CblasNoTrans,
bTrans ? CblasTrans : CblasNoTrans, bTrans ? CblasTrans : CblasNoTrans,
out.getHeight(), out.getHeight(),
...@@ -139,8 +137,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -139,8 +137,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans) {
bool cTrans) {
if (scaleT == 0) { if (scaleT == 0) {
out.zeroMem(); out.zeroMem();
} }
...@@ -174,8 +171,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -174,8 +171,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans) {
bool cTrans) {
if (scaleT == 0) { if (scaleT == 0) {
out.zeroMem(); out.zeroMem();
} }
...@@ -222,10 +218,10 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -222,10 +218,10 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
/** /**
* mul operator * mul operator
* out = scaleT * out + scaleAB * (in1 * in2) * out = scaleT * out + scaleAB * (A * B)
* here, scaleT in {0, 1}, scaleAB == 1, * here, scaleT in {0, 1}, scaleAB == 1,
* out = in1 (A) * in2 (B), ASSIGN_TO * out = A * B, ASSIGN_TO
* out += in1 (A) * in2 (B), ADD_TO * out += A * B, ADD_TO
* *
* *
* \param outputs[0] output matrix (out), M * N, * \param outputs[0] output matrix (out), M * N,
...@@ -253,15 +249,11 @@ template <DeviceType Device> ...@@ -253,15 +249,11 @@ template <DeviceType Device>
class MulFunc : public FunctionBase { class MulFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
alpha_ = config.get<real>("scaleAB");
beta_ = config.get<real>("scaleT");
aTrans_ = config.get<bool>("aTrans"); aTrans_ = config.get<bool>("aTrans");
bTrans_ = config.get<bool>("bTrans"); bTrans_ = config.get<bool>("bTrans");
cTrans_ = config.get<bool>("cTrans");
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK(!cTrans_) << "output matrix should not be transposed";
CHECK(!aTrans_ || !bTrans_) CHECK(!aTrans_ || !bTrans_)
<< "Not support both a and b are transpose matrices"; << "Not support both a and b are transpose matrices";
...@@ -281,10 +273,8 @@ public: ...@@ -281,10 +273,8 @@ public:
CHECK_EQ(aRow, outputs[0].shape()[0]); CHECK_EQ(aRow, outputs[0].shape()[0]);
CHECK_EQ(bCol, outputs[0].shape()[1]); CHECK_EQ(bCol, outputs[0].shape()[1]);
/// only support C = A * B or C += A * B /// only support C = A * B (ASSIGN_TO) or C += A * B (ADD_TO)
CHECK_EQ(alpha_, static_cast<real>(1.0)); real scaleT = (outputs[0].getArgType() == ADD_TO) ? 1.0 : 0.0;
CHECK((beta_ == 0 && outputs[0].getArgType() == ASSIGN_TO) ||
(beta_ == 1 && outputs[0].getArgType() == ADD_TO));
/// supports dense = A * B, where A and B are not both sparse, /// supports dense = A * B, where A and B are not both sparse,
/// or sparse = dense * dense /// or sparse = dense * dense
...@@ -300,11 +290,10 @@ public: ...@@ -300,11 +290,10 @@ public:
MulOp<Device>(outMat, MulOp<Device>(outMat,
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, 1.0, // scaleAB
beta_, scaleT,
aTrans_, aTrans_,
bTrans_, bTrans_);
cTrans_);
return; return;
} }
...@@ -315,11 +304,10 @@ public: ...@@ -315,11 +304,10 @@ public:
MulOp<Device>(outMat, MulOp<Device>(outMat,
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].sparse().SparseMatrix<Device>(), inputs[1].sparse().SparseMatrix<Device>(),
alpha_, 1.0, // scaleAB
beta_, scaleT,
aTrans_, aTrans_,
bTrans_, bTrans_);
cTrans_);
return; return;
} }
...@@ -332,11 +320,10 @@ public: ...@@ -332,11 +320,10 @@ public:
MulOp<Device>(outMat, MulOp<Device>(outMat,
inputs[0].sparse().SparseMatrix<Device>(), inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, 1.0, // scaleAB
beta_, scaleT,
aTrans_, aTrans_,
bTrans_, bTrans_);
cTrans_);
return; return;
} }
...@@ -347,21 +334,17 @@ public: ...@@ -347,21 +334,17 @@ public:
MulOp<Device>(outSparseMat, MulOp<Device>(outSparseMat,
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, 1.0, // scaleAB
beta_, scaleT,
aTrans_, aTrans_,
bTrans_, bTrans_);
cTrans_);
return; return;
} }
} }
private: private:
real alpha_;
real beta_;
bool aTrans_; bool aTrans_;
bool bTrans_; bool bTrans_;
bool cTrans_;
}; };
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc); REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
......
...@@ -27,8 +27,7 @@ void MulOp(CpuMatrix& out, ...@@ -27,8 +27,7 @@ void MulOp(CpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans);
bool cTrans);
/// CPU, dense matrix (+)= sparse matrix * dense matrix /// CPU, dense matrix (+)= sparse matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
...@@ -38,8 +37,7 @@ void MulOp(CpuMatrix& out, ...@@ -38,8 +37,7 @@ void MulOp(CpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans);
bool cTrans);
/// CPU, dense matrix (+)= dense matrix * sparse matrix /// CPU, dense matrix (+)= dense matrix * sparse matrix
template <DeviceType DType> template <DeviceType DType>
...@@ -49,8 +47,7 @@ void MulOp(CpuMatrix& out, ...@@ -49,8 +47,7 @@ void MulOp(CpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans);
bool cTrans);
/// CPU, sparse matrix (+)= dense matrix * dense matrix /// CPU, sparse matrix (+)= dense matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
...@@ -60,8 +57,7 @@ void MulOp(CpuSparseMatrix& out, ...@@ -60,8 +57,7 @@ void MulOp(CpuSparseMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans);
bool cTrans);
/// GPU, dense matrix (+)= dense matrix * dense matrix /// GPU, dense matrix (+)= dense matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
...@@ -71,8 +67,7 @@ void MulOp(GpuMatrix& out, ...@@ -71,8 +67,7 @@ void MulOp(GpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans);
bool cTrans);
/// GPU, dense matrix (+)= sparse matrix * dense matrix /// GPU, dense matrix (+)= sparse matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
...@@ -82,8 +77,7 @@ void MulOp(GpuMatrix& out, ...@@ -82,8 +77,7 @@ void MulOp(GpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans);
bool cTrans);
/// GPU, dense matrix (+)= dense matrix * sparse matrix /// GPU, dense matrix (+)= dense matrix * sparse matrix
template <DeviceType DType> template <DeviceType DType>
...@@ -93,8 +87,8 @@ void MulOp(GpuMatrix& out, ...@@ -93,8 +87,8 @@ void MulOp(GpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans);
bool cTrans);
/// GPU, sparse matrix (+)= dense matrix * dense matrix /// GPU, sparse matrix (+)= dense matrix * dense matrix
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuSparseMatrix& out, void MulOp(GpuSparseMatrix& out,
...@@ -103,7 +97,6 @@ void MulOp(GpuSparseMatrix& out, ...@@ -103,7 +97,6 @@ void MulOp(GpuSparseMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans);
bool cTrans);
} // namespace paddle } // namespace paddle
...@@ -26,8 +26,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -26,8 +26,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans) {
bool cTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
hl_matrix_mul(const_cast<real*>(a.getData()), hl_matrix_mul(const_cast<real*>(a.getData()),
!aTrans ? HPPL_OP_N : HPPL_OP_T, !aTrans ? HPPL_OP_N : HPPL_OP_T,
...@@ -52,8 +51,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -52,8 +51,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans) {
bool cTrans) {
CHECK(out.isContiguous()); CHECK(out.isContiguous());
CHECK(b.isContiguous()); CHECK(b.isContiguous());
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
...@@ -77,8 +75,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -77,8 +75,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans) {
bool cTrans) {
CHECK(out.isContiguous()); CHECK(out.isContiguous());
CHECK(a.isContiguous()); CHECK(a.isContiguous());
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
...@@ -116,8 +113,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out, ...@@ -116,8 +113,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
real scaleAB, real scaleAB,
real scaleT, real scaleT,
bool aTrans, bool aTrans,
bool bTrans, bool bTrans) {
bool cTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
hl_sparse_matrix_mul(const_cast<real*>(a.getData()), hl_sparse_matrix_mul(const_cast<real*>(a.getData()),
aTrans ? HPPL_OP_T : HPPL_OP_N, aTrans ? HPPL_OP_T : HPPL_OP_N,
......
...@@ -27,8 +27,7 @@ using namespace paddle; // NOLINT ...@@ -27,8 +27,7 @@ using namespace paddle; // NOLINT
*/ */
void testFuncDDDMatrix( void testFuncDDDMatrix(
bool transa, bool transb, size_t dimM, size_t dimN, size_t dimK) { bool transa, bool transb, size_t dimM, size_t dimN, size_t dimK) {
real alpha = 1.0; real scaleT = 1.0;
real beta = 1.0;
size_t heightA = (transa == false) ? dimM : dimK; size_t heightA = (transa == false) ? dimM : dimK;
size_t widthA = (transa == false) ? dimK : dimM; size_t widthA = (transa == false) ? dimK : dimM;
size_t heightB = (transb == false) ? dimK : dimN; size_t heightB = (transb == false) ? dimK : dimN;
...@@ -36,13 +35,8 @@ void testFuncDDDMatrix( ...@@ -36,13 +35,8 @@ void testFuncDDDMatrix(
size_t heightC = dimM; size_t heightC = dimM;
size_t widthC = dimN; size_t widthC = dimN;
// init Test object // init Test object
FunctionCompare test("MulOp", FunctionCompare test(
FuncConfig() "MulOp", FuncConfig().set("aTrans", transa).set("bTrans", transb));
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", transa)
.set("bTrans", transb)
.set("cTrans", false));
// prepare input arguments // prepare input arguments
/// matrix A : HA * WA /// matrix A : HA * WA
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightA, widthA})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightA, widthA}));
...@@ -51,7 +45,7 @@ void testFuncDDDMatrix( ...@@ -51,7 +45,7 @@ void testFuncDDDMatrix(
/// output matrix C: HC * WC /// output matrix C: HC * WC
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightC, widthC}), test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightC, widthC}),
beta == 1.0 ? ADD_TO : ASSIGN_TO); scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function // run Function
test.run(); test.run();
} }
...@@ -85,16 +79,10 @@ TEST(MulOp, DDDMatrixMul) { ...@@ -85,16 +79,10 @@ TEST(MulOp, DDDMatrixMul) {
*/ */
void testFuncDSparseDMatrix( void testFuncDSparseDMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real alpha = 1.0; real scaleT = 1.0;
real beta = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", FunctionCompare test("MulOp",
FuncConfig() FuncConfig().set("aTrans", false).set("bTrans", false));
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
// prepare input arguments // prepare input arguments
/// sparse matrix A : M * K /// sparse matrix A : M * K
test.addInputs(SparseMatrixArg( test.addInputs(SparseMatrixArg(
...@@ -104,7 +92,7 @@ void testFuncDSparseDMatrix( ...@@ -104,7 +92,7 @@ void testFuncDSparseDMatrix(
/// output matrix C: M * N /// output matrix C: M * N
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}), test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}),
beta == 1.0 ? ADD_TO : ASSIGN_TO); scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function // run Function
test.run(); test.run();
} }
...@@ -136,16 +124,10 @@ TEST(MuLOp, DSparseDMul) { ...@@ -136,16 +124,10 @@ TEST(MuLOp, DSparseDMul) {
*/ */
void testFuncDDSparseMatrix( void testFuncDDSparseMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real alpha = 1.0; real scaleT = 1.0;
real beta = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", FunctionCompare test("MulOp",
FuncConfig() FuncConfig().set("aTrans", false).set("bTrans", false));
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
// prepare input arguments // prepare input arguments
/// matrix A : M * K /// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
...@@ -156,7 +138,7 @@ void testFuncDDSparseMatrix( ...@@ -156,7 +138,7 @@ void testFuncDDSparseMatrix(
/// output matrix C: M * N /// output matrix C: M * N
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}), test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}),
beta == 1.0 ? ADD_TO : ASSIGN_TO); scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function // run Function
test.run(); test.run();
} }
...@@ -188,16 +170,10 @@ TEST(MulOp, DDSparseMul) { ...@@ -188,16 +170,10 @@ TEST(MulOp, DDSparseMul) {
*/ */
void testFuncSparseDDMatrix( void testFuncSparseDDMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real alpha = 1.0; real scaleT = 1.0;
real beta = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", FunctionCompare test("MulOp",
FuncConfig() FuncConfig().set("aTrans", false).set("bTrans", false));
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
// prepare input arguments // prepare input arguments
/// matrix A : M * K /// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
...@@ -209,7 +185,7 @@ void testFuncSparseDDMatrix( ...@@ -209,7 +185,7 @@ void testFuncSparseDDMatrix(
test.addOutputs( test.addOutputs(
SparseMatrixArg( SparseMatrixArg(
VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}, nnz, FORMAT, FLOAT_VALUE), VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}, nnz, FORMAT, FLOAT_VALUE),
beta == 1.0 ? ADD_TO : ASSIGN_TO); scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function // run Function
test.run(); test.run();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册