提交 5b1a5c11 编写于 作者: X xutianbing

Daoyuan's comments.

上级 999cd14a
......@@ -70,7 +70,7 @@ public:
}
// The output argument only needs to carry a shape; it does not need to contain data.
void addOutputs(const BufferArg& output, ArgType argType = ADD_TO) {
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size =
output.shape().getElements() * sizeOfValuType(output.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
......
......@@ -49,8 +49,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
bool bTrans) {
CHECK_EQ(out.getValueType(), FLOAT_VALUE);
if (scaleT == 0) {
out.zeroMem();
......@@ -114,8 +113,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
bool bTrans) {
GEMM(aTrans ? CblasTrans : CblasNoTrans,
bTrans ? CblasTrans : CblasNoTrans,
out.getHeight(),
......@@ -139,8 +137,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
bool bTrans) {
if (scaleT == 0) {
out.zeroMem();
}
......@@ -174,8 +171,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
bool bTrans) {
if (scaleT == 0) {
out.zeroMem();
}
......@@ -222,10 +218,10 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
/**
* mul operator
* out = scaleT * out + scaleAB * (in1 * in2)
* out = scaleT * out + scaleAB * (A * B)
* here, scaleT in {0, 1}, scaleAB == 1,
* out = in1 (A) * in2 (B), ASSIGN_TO
* out += in1 (A) * in2 (B), ADD_TO
* out = A * B, ASSIGN_TO
* out += A * B, ADD_TO
*
*
* \param outputs[0] output matrix (out), M * N,
......@@ -253,15 +249,11 @@ template <DeviceType Device>
class MulFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
alpha_ = config.get<real>("scaleAB");
beta_ = config.get<real>("scaleT");
aTrans_ = config.get<bool>("aTrans");
bTrans_ = config.get<bool>("bTrans");
cTrans_ = config.get<bool>("cTrans");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK(!cTrans_) << "output matrix should not be transposed";
CHECK(!aTrans_ || !bTrans_)
<< "Not support both a and b are transpose matrices";
......@@ -281,10 +273,8 @@ public:
CHECK_EQ(aRow, outputs[0].shape()[0]);
CHECK_EQ(bCol, outputs[0].shape()[1]);
/// only support C = A * B or C += A * B
CHECK_EQ(alpha_, static_cast<real>(1.0));
CHECK((beta_ == 0 && outputs[0].getArgType() == ASSIGN_TO) ||
(beta_ == 1 && outputs[0].getArgType() == ADD_TO));
/// only support C = A * B (ASSIGN_TO) or C += A * B (ADD_TO)
real scaleT = (outputs[0].getArgType() == ADD_TO) ? 1.0 : 0.0;
/// supported cases: dense = A * B, where A and B are not both sparse,
/// or sparse = dense * dense
......@@ -300,11 +290,10 @@ public:
MulOp<Device>(outMat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
beta_,
1.0, // scaleAB
scaleT,
aTrans_,
bTrans_,
cTrans_);
bTrans_);
return;
}
......@@ -315,11 +304,10 @@ public:
MulOp<Device>(outMat,
inputs[0].matrix<Device>(),
inputs[1].sparse().SparseMatrix<Device>(),
alpha_,
beta_,
1.0, // scaleAB
scaleT,
aTrans_,
bTrans_,
cTrans_);
bTrans_);
return;
}
......@@ -332,11 +320,10 @@ public:
MulOp<Device>(outMat,
inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
beta_,
1.0, // scaleAB
scaleT,
aTrans_,
bTrans_,
cTrans_);
bTrans_);
return;
}
......@@ -347,21 +334,17 @@ public:
MulOp<Device>(outSparseMat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
beta_,
1.0, // scaleAB
scaleT,
aTrans_,
bTrans_,
cTrans_);
bTrans_);
return;
}
}
private:
real alpha_;
real beta_;
bool aTrans_;
bool bTrans_;
bool cTrans_;
};
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
......
......@@ -27,8 +27,7 @@ void MulOp(CpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
bool bTrans);
/// CPU, dense matrix (+)= sparse matrix * dense matrix
template <DeviceType DType>
......@@ -38,8 +37,7 @@ void MulOp(CpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
bool bTrans);
/// CPU, dense matrix (+)= dense matrix * sparse matrix
template <DeviceType DType>
......@@ -49,8 +47,7 @@ void MulOp(CpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
bool bTrans);
/// CPU, sparse matrix (+)= dense matrix * dense matrix
template <DeviceType DType>
......@@ -60,8 +57,7 @@ void MulOp(CpuSparseMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
bool bTrans);
/// GPU, dense matrix (+)= dense matrix * dense matrix
template <DeviceType DType>
......@@ -71,8 +67,7 @@ void MulOp(GpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
bool bTrans);
/// GPU, dense matrix (+)= sparse matrix * dense matrix
template <DeviceType DType>
......@@ -82,8 +77,7 @@ void MulOp(GpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
bool bTrans);
/// GPU, dense matrix (+)= dense matrix * sparse matrix
template <DeviceType DType>
......@@ -93,8 +87,8 @@ void MulOp(GpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
bool bTrans);
/// GPU, sparse matrix (+)= dense matrix * dense matrix
template <DeviceType DType>
void MulOp(GpuSparseMatrix& out,
......@@ -103,7 +97,6 @@ void MulOp(GpuSparseMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
bool bTrans);
} // namespace paddle
......@@ -26,8 +26,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
bool bTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
hl_matrix_mul(const_cast<real*>(a.getData()),
!aTrans ? HPPL_OP_N : HPPL_OP_T,
......@@ -52,8 +51,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
bool bTrans) {
CHECK(out.isContiguous());
CHECK(b.isContiguous());
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
......@@ -77,8 +75,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
bool bTrans) {
CHECK(out.isContiguous());
CHECK(a.isContiguous());
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
......@@ -116,8 +113,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
real scaleAB,
real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
bool bTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
hl_sparse_matrix_mul(const_cast<real*>(a.getData()),
aTrans ? HPPL_OP_T : HPPL_OP_N,
......
......@@ -27,8 +27,7 @@ using namespace paddle; // NOLINT
*/
void testFuncDDDMatrix(
bool transa, bool transb, size_t dimM, size_t dimN, size_t dimK) {
real alpha = 1.0;
real beta = 1.0;
real scaleT = 1.0;
size_t heightA = (transa == false) ? dimM : dimK;
size_t widthA = (transa == false) ? dimK : dimM;
size_t heightB = (transb == false) ? dimK : dimN;
......@@ -36,13 +35,8 @@ void testFuncDDDMatrix(
size_t heightC = dimM;
size_t widthC = dimN;
// init Test object
FunctionCompare test("MulOp",
FuncConfig()
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", transa)
.set("bTrans", transb)
.set("cTrans", false));
FunctionCompare test(
"MulOp", FuncConfig().set("aTrans", transa).set("bTrans", transb));
// prepare input arguments
/// matrix A : HA * WA
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightA, widthA}));
......@@ -51,7 +45,7 @@ void testFuncDDDMatrix(
/// output matrix C: HC * WC
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightC, widthC}),
beta == 1.0 ? ADD_TO : ASSIGN_TO);
scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function
test.run();
}
......@@ -85,16 +79,10 @@ TEST(MulOp, DDDMatrixMul) {
*/
void testFuncDSparseDMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real alpha = 1.0;
real beta = 1.0;
real scaleT = 1.0;
// init Test object
FunctionCompare test("MulOp",
FuncConfig()
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments
/// sparse matrix A : M * K
test.addInputs(SparseMatrixArg(
......@@ -104,7 +92,7 @@ void testFuncDSparseDMatrix(
/// output matrix C: M * N
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}),
beta == 1.0 ? ADD_TO : ASSIGN_TO);
scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function
test.run();
}
......@@ -136,16 +124,10 @@ TEST(MuLOp, DSparseDMul) {
*/
void testFuncDDSparseMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real alpha = 1.0;
real beta = 1.0;
real scaleT = 1.0;
// init Test object
FunctionCompare test("MulOp",
FuncConfig()
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments
/// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
......@@ -156,7 +138,7 @@ void testFuncDDSparseMatrix(
/// output matrix C: M * N
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}),
beta == 1.0 ? ADD_TO : ASSIGN_TO);
scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function
test.run();
}
......@@ -188,16 +170,10 @@ TEST(MulOp, DDSparseMul) {
*/
void testFuncSparseDDMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real alpha = 1.0;
real beta = 1.0;
real scaleT = 1.0;
// init Test object
FunctionCompare test("MulOp",
FuncConfig()
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments
/// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
......@@ -209,7 +185,7 @@ void testFuncSparseDDMatrix(
test.addOutputs(
SparseMatrixArg(
VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}, nnz, FORMAT, FLOAT_VALUE),
beta == 1.0 ? ADD_TO : ASSIGN_TO);
scaleT == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function
test.run();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册