提交 b3be7358 编写于 作者: X xutianbing

Daoyuan's comments.

上级 bc5d7bb6
...@@ -71,24 +71,17 @@ public: ...@@ -71,24 +71,17 @@ public:
public: public:
BufferArg(ValueType valueType, BufferArg(ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED, ArgType argType = UNSPECIFIED)
bool trans = false)
: buf_(nullptr), : buf_(nullptr),
valueType_(valueType), valueType_(valueType),
shape_(shape), shape_(shape),
argType_(argType), argType_(argType) {}
trans_(trans) {}
BufferArg(void* buf, BufferArg(void* buf,
ValueType valueType, ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED, ArgType argType = UNSPECIFIED)
bool trans = false) : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}
: buf_(buf),
valueType_(valueType),
shape_(shape),
argType_(argType),
trans_(trans) {}
BufferArg(void* buf, ValueType valueType) BufferArg(void* buf, ValueType valueType)
: buf_(buf), valueType_(valueType) {} : buf_(buf), valueType_(valueType) {}
...@@ -98,8 +91,7 @@ public: ...@@ -98,8 +91,7 @@ public:
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))), const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(2), shape_(2),
argType_(argType), argType_(argType) {
trans_(matrix.isTransposed()) {
bufferType_ = TENSOR_NORMAL; bufferType_ = TENSOR_NORMAL;
shape_.setDim(0, matrix.getHeight()); shape_.setDim(0, matrix.getHeight());
shape_.setDim(1, matrix.getWidth()); shape_.setDim(1, matrix.getWidth());
...@@ -112,8 +104,7 @@ public: ...@@ -112,8 +104,7 @@ public:
const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))), const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(shape), shape_(shape),
argType_(argType), argType_(argType) {
trans_(matrix.isTransposed()) {
bufferType_ = TENSOR_NORMAL; bufferType_ = TENSOR_NORMAL;
CHECK_EQ(matrix.getElementCnt(), shape.getElements()); CHECK_EQ(matrix.getElementCnt(), shape.getElements());
} }
...@@ -145,7 +136,7 @@ public: ...@@ -145,7 +136,7 @@ public:
// CHECK(deviceType_ == DType); // CHECK(deviceType_ == DType);
CHECK_EQ((size_t)2, shape_.ndims()); CHECK_EQ((size_t)2, shape_.ndims());
return typename Tensor<real, DType>::Matrix( return typename Tensor<real, DType>::Matrix(
reinterpret_cast<real*>(buf_), shape_[0], shape_[1], trans_); reinterpret_cast<real*>(buf_), shape_[0], shape_[1]);
} }
template <typename VType, DeviceType DType> template <typename VType, DeviceType DType>
...@@ -169,7 +160,6 @@ public: ...@@ -169,7 +160,6 @@ public:
ValueType valueType() const { return valueType_; } ValueType valueType() const { return valueType_; }
BufferType bufferType() const { return bufferType_; } BufferType bufferType() const { return bufferType_; }
const TensorShape& shape() const { return shape_; } const TensorShape& shape() const { return shape_; }
bool isTransposed() const { return trans_; }
bool isSparseArg() const { return TENSOR_SPARSE == bufferType_; } bool isSparseArg() const { return TENSOR_SPARSE == bufferType_; }
bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; } bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; }
virtual size_t numElements() const { return shape_.getElements(); } virtual size_t numElements() const { return shape_.getElements(); }
...@@ -183,7 +173,6 @@ protected: ...@@ -183,7 +173,6 @@ protected:
TensorShape shape_; TensorShape shape_;
BufferType bufferType_{TENSOR_UNKNOWN}; BufferType bufferType_{TENSOR_UNKNOWN};
ArgType argType_{UNSPECIFIED}; ArgType argType_{UNSPECIFIED};
bool trans_{false};
// todo(tianbing), add deviceType_ // todo(tianbing), add deviceType_
// leading dimensions. The size is dims_.size() // leading dimensions. The size is dims_.size()
// Dims lds_; // Dims lds_;
...@@ -277,9 +266,8 @@ public: ...@@ -277,9 +266,8 @@ public:
size_t nnz, size_t nnz,
SparseFormat format, SparseFormat format,
SparseValueType type, SparseValueType type,
ArgType argType = UNSPECIFIED, ArgType argType = UNSPECIFIED)
bool trans = false) : BufferArg(buf, valueType, shape, argType),
: BufferArg(buf, valueType, shape, argType, trans),
row_(row), row_(row),
col_(col), col_(col),
nnz_(nnz), nnz_(nnz),
...@@ -302,9 +290,8 @@ public: ...@@ -302,9 +290,8 @@ public:
size_t nnz, size_t nnz,
SparseFormat format, SparseFormat format,
SparseValueType type, SparseValueType type,
ArgType argType = UNSPECIFIED, ArgType argType = UNSPECIFIED)
bool trans = false) : BufferArg(valueType, shape, argType),
: BufferArg(valueType, shape, argType, trans),
/// len of row_ : height + 1 (CSR), buf_ == nullptr /// len of row_ : height + 1 (CSR), buf_ == nullptr
row_(format == SPARSE_CSR row_(format == SPARSE_CSR
? BufferArg(VALUE_TYPE_INT32, TensorShape{shape[0] + 1}) ? BufferArg(VALUE_TYPE_INT32, TensorShape{shape[0] + 1})
...@@ -343,7 +330,7 @@ public: ...@@ -343,7 +330,7 @@ public:
nnz_, nnz_,
type_, type_,
format_, format_,
trans_); false);
} }
~SparseMatrixArg() {} ~SparseMatrixArg() {}
......
...@@ -64,22 +64,14 @@ public: ...@@ -64,22 +64,14 @@ public:
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size)); gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size));
cpuInputs_.emplace_back( cpuInputs_.emplace_back(std::make_shared<BufferArg>(
std::make_shared<BufferArg>(cpuMemory_.back()->getBuf(), cpuMemory_.back()->getBuf(), input.valueType(), input.shape()));
input.valueType(), gpuInputs_.emplace_back(std::make_shared<BufferArg>(
input.shape(), gpuMemory_.back()->getBuf(), input.valueType(), input.shape()));
UNSPECIFIED,
input.isTransposed()));
gpuInputs_.emplace_back(
std::make_shared<BufferArg>(gpuMemory_.back()->getBuf(),
input.valueType(),
input.shape(),
UNSPECIFIED,
input.isTransposed()));
} }
// output need only contains shape, do not contains data. // output need only contains shape, do not contains data.
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) { void addOutputs(const BufferArg& output, ArgType argType = ADD_TO) {
size_t size = size_t size =
output.shape().getElements() * sizeOfValuType(output.valueType()); output.shape().getElements() * sizeOfValuType(output.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));
...@@ -89,16 +81,14 @@ public: ...@@ -89,16 +81,14 @@ public:
cpuMemory_.back()->getBuf(), cpuMemory_.back()->getBuf(),
output.valueType(), output.valueType(),
output.shape(), output.shape(),
// todo(tianbing), argType = output.getArgType(), but default ASSIGN_TO // todo(tianbing), argType = output.getArgType(), but default ADD_TO
argType, argType));
output.isTransposed()));
gpuOutputs_.emplace_back(std::make_shared<BufferArg>( gpuOutputs_.emplace_back(std::make_shared<BufferArg>(
gpuMemory_.back()->getBuf(), gpuMemory_.back()->getBuf(),
output.valueType(), output.valueType(),
output.shape(), output.shape(),
// todo(tianbing), argType = output.getArgType(), but default ASSIGN_TO // todo(tianbing), argType = output.getArgType(), but default ADD_TO
argType, argType));
output.isTransposed()));
} }
/// add and init output sparse matrix /// add and init output sparse matrix
...@@ -107,15 +97,13 @@ public: ...@@ -107,15 +97,13 @@ public:
output.shape()[1], output.shape()[1],
output.nnz(), output.nnz(),
output.dataType(), output.dataType(),
output.dataFormat(), output.dataFormat());
output.isTransposed());
gpuSparse_ = std::make_shared<GpuSparseMatrix>(output.shape()[0], gpuSparse_ = std::make_shared<GpuSparseMatrix>(output.shape()[0],
output.shape()[1], output.shape()[1],
output.nnz(), output.nnz(),
output.dataType(), output.dataType(),
output.dataFormat(), output.dataFormat());
output.isTransposed());
/// init sparse matrix /// init sparse matrix
hl_stream_t stream(HPPL_STREAM_1); hl_stream_t stream(HPPL_STREAM_1);
...@@ -154,15 +142,13 @@ public: ...@@ -154,15 +142,13 @@ public:
input.shape()[1], input.shape()[1],
input.nnz(), input.nnz(),
input.dataType(), input.dataType(),
input.dataFormat(), input.dataFormat());
input.isTransposed());
gpuSparse_ = std::make_shared<GpuSparseMatrix>(input.shape()[0], gpuSparse_ = std::make_shared<GpuSparseMatrix>(input.shape()[0],
input.shape()[1], input.shape()[1],
input.nnz(), input.nnz(),
input.dataType(), input.dataType(),
input.dataFormat(), input.dataFormat());
input.isTransposed());
/// init sparse matrix /// init sparse matrix
hl_stream_t stream(HPPL_STREAM_1); hl_stream_t stream(HPPL_STREAM_1);
......
...@@ -46,21 +46,11 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -46,21 +46,11 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
const CpuMatrix& b, const CpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT,
CHECK(!out.isTransposed()) << "Not supported"; bool aTrans,
bool bTrans,
bool cTrans) {
CHECK_EQ(out.getValueType(), FLOAT_VALUE); CHECK_EQ(out.getValueType(), FLOAT_VALUE);
CHECK(!a.isTransposed() || !b.isTransposed())
<< "Not support both a and b are transpose matrices";
size_t height = out.getHeight();
size_t width = out.getWidth();
size_t aRow = !a.isTransposed() ? a.getHeight() : a.getWidth();
size_t aCol = !a.isTransposed() ? a.getWidth() : a.getHeight();
size_t bRow = !b.isTransposed() ? b.getHeight() : b.getWidth();
size_t bCol = !b.isTransposed() ? b.getWidth() : b.getHeight();
/// C = A * B, for matrix format
CHECK(aCol == bRow && aRow == height && bCol == width);
if (scaleT == 0) { if (scaleT == 0) {
out.zeroMem(); out.zeroMem();
} }
...@@ -69,12 +59,14 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -69,12 +59,14 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real* C = out.getValue(); real* C = out.getValue();
int* rows = out.getRows(); int* rows = out.getRows();
int* cols = out.getCols(); int* cols = out.getCols();
size_t width = out.getWidth();
size_t height = out.getHeight();
/// SPARSE_CSC, {a any, b not trans} /// SPARSE_CSC, {a any, b not trans}
if (out.getFormat() == SPARSE_CSC) { if (out.getFormat() == SPARSE_CSC) {
/// b not trans and a any /// b not trans and a any
CHECK(!b.isTransposed()); CHECK(!bTrans);
size_t m = !a.isTransposed() ? a.getWidth() : a.getHeight(); size_t m = !aTrans ? a.getWidth() : a.getHeight();
for (size_t i = 0; i < width; i++) { for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i); size_t start = out.getColStartIdx(i);
size_t end = out.getColStartIdx(i + 1); size_t end = out.getColStartIdx(i + 1);
...@@ -82,9 +74,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -82,9 +74,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real sum = 0; real sum = 0;
size_t rowIdx = rows[j]; size_t rowIdx = rows[j];
for (size_t k = 0; k < m; k++) { for (size_t k = 0; k < m; k++) {
sum += sum += (!aTrans ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
(!a.isTransposed() ? A[rowIdx * m + k] : A[k * height + rowIdx]) * B[k * width + i];
B[k * width + i];
} }
C[j] = scaleAB * sum + scaleT * C[j]; C[j] = scaleAB * sum + scaleT * C[j];
} }
...@@ -95,7 +86,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -95,7 +86,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
/// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans} /// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans}
if (out.getFormat() == SPARSE_CSR) { if (out.getFormat() == SPARSE_CSR) {
/// a and b can not both transpose /// a and b can not both transpose
CHECK(!(a.isTransposed() && b.isTransposed())); CHECK(!(aTrans && bTrans));
size_t m = a.getWidth(); size_t m = a.getWidth();
for (size_t i = 0; i < height; i++) { for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i); size_t start = out.getRowStartIdx(i);
...@@ -104,9 +95,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -104,9 +95,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real sum = 0; real sum = 0;
size_t colIdx = cols[j]; size_t colIdx = cols[j];
for (size_t k = 0; k < m; k++) { for (size_t k = 0; k < m; k++) {
sum += sum += (!aTrans ? A[i * m + k] : A[k * height + i]) *
(!a.isTransposed() ? A[i * m + k] : A[k * height + i]) * (!bTrans ? B[k * width + colIdx] : B[colIdx * m + k]);
(!b.isTransposed() ? B[k * width + colIdx] : B[colIdx * m + k]);
} }
C[j] = scaleAB * sum + scaleT * C[j]; C[j] = scaleAB * sum + scaleT * C[j];
} }
...@@ -120,25 +110,15 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -120,25 +110,15 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
const CpuMatrix& b, const CpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT,
CHECK(!out.isTransposed()) << "out matrix transpose not supported"; bool aTrans,
CBLAS_TRANSPOSE aTrans = a.isTransposed() ? CblasTrans : CblasNoTrans; bool bTrans,
size_t aRow = a.isTransposed() ? a.getWidth() : a.getHeight(); bool cTrans) {
size_t aCol = a.isTransposed() ? a.getHeight() : a.getWidth(); GEMM(aTrans ? CblasTrans : CblasNoTrans,
CBLAS_TRANSPOSE bTrans = b.isTransposed() ? CblasTrans : CblasNoTrans; bTrans ? CblasTrans : CblasNoTrans,
size_t bRow = b.isTransposed() ? b.getWidth() : b.getHeight();
size_t bCol = b.isTransposed() ? b.getHeight() : b.getWidth();
/// C = A * B, for matrix format
CHECK_EQ(aCol, bRow);
CHECK_EQ(aRow, out.getHeight());
CHECK_EQ(bCol, out.getWidth());
GEMM(aTrans,
bTrans,
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
aCol, !aTrans ? a.getWidth() : a.getHeight(),
scaleAB, scaleAB,
a.getData(), a.getData(),
a.getStride(), a.getStride(),
...@@ -154,21 +134,12 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -154,21 +134,12 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuSparseMatrix& a, const CpuSparseMatrix& a,
const CpuMatrix& b, const CpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT,
CHECK(!out.isTransposed()) << "Not supported"; bool aTrans,
CHECK(!b.isTransposed()) << "Not supported"; bool bTrans,
CHECK(scaleT == 0 || scaleT == 1) << "Not support"; bool cTrans) {
CHECK_EQ(scaleAB, static_cast<real>(1.0)) << "Not supported"; CHECK_EQ(a.getFormat(), SPARSE_CSR)
CHECK_EQ(a.getFormat(), SPARSE_CSR) << "Not supported"; << "Not supported SPARSE_CSR format for a";
if (!a.isTransposed()) {
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
b.getWidth() == out.getWidth());
} else {
CHECK(b.getHeight() == a.getHeight() && a.getWidth() == out.getHeight() &&
b.getWidth() == out.getWidth());
}
if (scaleT == 0) { if (scaleT == 0) {
out.zeroMem(); out.zeroMem();
} }
...@@ -185,9 +156,9 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -185,9 +156,9 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const int start = a.getRowStartIdx(i); const int start = a.getRowStartIdx(i);
const int end = a.getRowStartIdx(i + 1); const int end = a.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) { for (int j = start; j < end; ++j) {
vecAddTo(!a.isTransposed() ? out.getRow(i) : out.getRow(cols[j]), vecAddTo(!aTrans ? out.getRow(i) : out.getRow(cols[j]),
!a.isTransposed() ? const_cast<CpuMatrix&>(b).getRow(cols[j]) !aTrans ? const_cast<CpuMatrix&>(b).getRow(cols[j])
: const_cast<CpuMatrix&>(b).getRow(i), : const_cast<CpuMatrix&>(b).getRow(i),
(a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0, (a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0,
out.getWidth()); out.getWidth());
} }
...@@ -199,19 +170,10 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -199,19 +170,10 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
const CpuSparseMatrix& b, const CpuSparseMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT,
CHECK(!out.trans_) << "Not supported"; bool aTrans,
CHECK(!a.isTransposed()) << "Not supported"; bool bTrans,
CHECK(scaleT == 0 || scaleT == 1); bool cTrans) {
CHECK_EQ(scaleAB, static_cast<real>(1.0));
if (!b.isTransposed()) { /// b is not Transpose
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
b.getWidth() == out.getWidth());
} else {
CHECK(b.getHeight() == out.getWidth() && a.getHeight() == out.getHeight() &&
b.getWidth() == a.getWidth());
}
if (scaleT == 0) { if (scaleT == 0) {
out.zeroMem(); out.zeroMem();
} }
...@@ -227,8 +189,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -227,8 +189,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
int start = b.getColStartIdx(j); int start = b.getColStartIdx(j);
int end = b.getColStartIdx(j + 1); int end = b.getColStartIdx(j + 1);
for (int i = start; i < end; ++i) { for (int i = start; i < end; ++i) {
colVecAddTo(!b.isTransposed() ? C + j : C + rows[i], colVecAddTo(!bTrans ? C + j : C + rows[i],
!b.isTransposed() ? A + rows[i] : A + j, !bTrans ? A + rows[i] : A + j,
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i], (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
...@@ -244,8 +206,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -244,8 +206,8 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
int start = b.getRowStartIdx(j); int start = b.getRowStartIdx(j);
int end = b.getRowStartIdx(j + 1); int end = b.getRowStartIdx(j + 1);
for (int i = start; i < end; ++i) { for (int i = start; i < end; ++i) {
colVecAddTo(!b.isTransposed() ? C + cols[i] : C + j, colVecAddTo(!bTrans ? C + cols[i] : C + j,
!b.isTransposed() ? A + j : A + cols[i], !bTrans ? A + j : A + cols[i],
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i], (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
...@@ -270,16 +232,43 @@ public: ...@@ -270,16 +232,43 @@ public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
alpha_ = config.get<real>("scaleAB"); alpha_ = config.get<real>("scaleAB");
beta_ = config.get<real>("scaleT"); beta_ = config.get<real>("scaleT");
aTrans_ = config.get<bool>("aTrans");
bTrans_ = config.get<bool>("bTrans");
cTrans_ = config.get<bool>("cTrans");
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK(!cTrans_) << "output matrix should not be transposed";
CHECK(!aTrans_ || !bTrans_)
<< "Not support both a and b are transpose matrices";
CHECK_EQ((size_t)2, inputs.size()); CHECK_EQ((size_t)2, inputs.size());
CHECK_EQ((size_t)1, outputs.size()); CHECK_EQ((size_t)1, outputs.size());
CHECK(inputs[0].data() && inputs[1].data() && outputs[0].data()); CHECK(inputs[0].data() && inputs[1].data() && outputs[0].data());
CHECK_EQ(inputs[0].shape().ndims(), (size_t)2); CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2); CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
size_t aRow = !aTrans_ ? inputs[0].shape()[0] : inputs[0].shape()[1];
size_t aCol = !aTrans_ ? inputs[0].shape()[1] : inputs[0].shape()[0];
size_t bRow = !bTrans_ ? inputs[1].shape()[0] : inputs[1].shape()[1];
size_t bCol = !bTrans_ ? inputs[1].shape()[1] : inputs[1].shape()[0];
/// C = A * B, or C += A * B, for matrix format
CHECK_EQ(aCol, bRow);
CHECK_EQ(aRow, outputs[0].shape()[0]);
CHECK_EQ(bCol, outputs[0].shape()[1]);
/// only support C = A * B or C += A * B
CHECK_EQ(alpha_, static_cast<real>(1.0));
CHECK((beta_ == 0 && outputs[0].getArgType() == ASSIGN_TO) ||
(beta_ == 1 && outputs[0].getArgType() == ADD_TO));
/// support dense = not both sparse * sparse
/// or sparse = dense * dense
CHECK((!outputs[0].isSparseArg() &&
!(inputs[0].isSparseArg() && inputs[1].isSparseArg())) ||
(outputs[0].isSparseArg() && !inputs[0].isSparseArg() &&
!inputs[1].isSparseArg()));
auto outMat = outputs[0].matrix<Device>(); auto outMat = outputs[0].matrix<Device>();
/// matrix = matrix * matrix /// matrix = matrix * matrix
...@@ -289,29 +278,40 @@ public: ...@@ -289,29 +278,40 @@ public:
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, alpha_,
beta_); beta_,
aTrans_,
bTrans_,
cTrans_);
return; return;
} }
/// matrix = matrix * sparse matrix /// matrix = matrix * sparse matrix
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() && if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) { !outputs[0].isSparseArg()) {
CHECK(!aTrans_) << "Not supported a transpose";
MulOp<Device>(outMat, MulOp<Device>(outMat,
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].sparse().SparseMatrix<Device>(), inputs[1].sparse().SparseMatrix<Device>(),
alpha_, alpha_,
beta_); beta_,
aTrans_,
bTrans_,
cTrans_);
return; return;
} }
/// matrix = sparse matrix * matrix /// matrix = sparse matrix * matrix
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() && if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) { !outputs[0].isSparseArg()) {
CHECK(!bTrans_) << "Not supported b transpose";
MulOp<Device>(outMat, MulOp<Device>(outMat,
inputs[0].sparse().SparseMatrix<Device>(), inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, alpha_,
beta_); beta_,
aTrans_,
bTrans_,
cTrans_);
return; return;
} }
...@@ -319,18 +319,14 @@ public: ...@@ -319,18 +319,14 @@ public:
auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>(); auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() && if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
outputs[0].isSparseArg()) { outputs[0].isSparseArg()) {
/*
LOG(INFO) << "input0";
inputs[0].matrix<Device>().print(std::cout);
LOG(INFO) << "input1";
inputs[1].matrix<Device>().print(std::cout);
LOG(INFO) << "output sparse matrix";
outSparseMat.print(std::cout); */
MulOp<Device>(outSparseMat, MulOp<Device>(outSparseMat,
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, alpha_,
beta_); beta_,
aTrans_,
bTrans_,
cTrans_);
return; return;
} }
} }
...@@ -338,6 +334,9 @@ public: ...@@ -338,6 +334,9 @@ public:
private: private:
real alpha_; real alpha_;
real beta_; real beta_;
bool aTrans_;
bool bTrans_;
bool cTrans_;
}; };
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc); REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
......
...@@ -26,55 +26,79 @@ void MulOp(CpuMatrix& out, ...@@ -26,55 +26,79 @@ void MulOp(CpuMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
const CpuMatrix& b, const CpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT); real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
template <DeviceType DType> template <DeviceType DType>
void MulOp(CpuMatrix& out, void MulOp(CpuMatrix& out,
const CpuSparseMatrix& a, const CpuSparseMatrix& a,
const CpuMatrix& b, const CpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT); real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
template <DeviceType DType> template <DeviceType DType>
void MulOp(CpuMatrix& out, void MulOp(CpuMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
const CpuSparseMatrix& b, const CpuSparseMatrix& b,
real scaleAB, real scaleAB,
real scaleT); real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
template <DeviceType DType> template <DeviceType DType>
void MulOp(CpuSparseMatrix& out, void MulOp(CpuSparseMatrix& out,
const CpuMatrix& a, const CpuMatrix& a,
const CpuMatrix& b, const CpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT); real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuMatrix& out, void MulOp(GpuMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
const GpuMatrix& b, const GpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT); real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuMatrix& out, void MulOp(GpuMatrix& out,
const GpuSparseMatrix& a, const GpuSparseMatrix& a,
const GpuMatrix& b, const GpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT); real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuMatrix& out, void MulOp(GpuMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
const GpuSparseMatrix& b, const GpuSparseMatrix& b,
real scaleAB, real scaleAB,
real scaleT); real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
template <DeviceType DType> template <DeviceType DType>
void MulOp(GpuSparseMatrix& out, void MulOp(GpuSparseMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
const GpuMatrix& b, const GpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT); real scaleT,
bool aTrans,
bool bTrans,
bool cTrans);
} // namespace paddle } // namespace paddle
...@@ -27,38 +27,22 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -27,38 +27,22 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
const GpuMatrix& b, const GpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT,
CHECK(!out.isTransposed()) << "Transpose not supported for out matrix"; bool aTrans,
if (!a.isTransposed() && !b.isTransposed()) { bool bTrans,
/// a : M * K, b: K * N bool cTrans) {
CHECK(out.getWidth() == b.getWidth() && CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
out.getHeight() == a.getHeight() &&
a.getWidth() == b.getHeight());
} else if (a.isTransposed() && !b.isTransposed()) {
/// a : K * M, b : K * N
CHECK(out.getWidth() == b.getWidth() &&
out.getHeight() == a.getWidth() &&
a.getHeight() == b.getHeight());
} else if (!a.isTransposed() && b.isTransposed()) {
/// a: M * K, b : N * K
CHECK(out.getWidth() == b.getHeight() &&
out.getHeight() == a.getHeight() &&
a.getWidth() == b.getWidth());
} else {
LOG(FATAL) << "Not support for both a and b are Transposed Matrices";
}
real* aData = const_cast<real*>(a.getData()); real* aData = const_cast<real*>(a.getData());
real* bData = const_cast<real*>(b.getData()); real* bData = const_cast<real*>(b.getData());
real* outData = const_cast<real*>(out.getData()); real* outData = const_cast<real*>(out.getData());
hl_matrix_mul(aData, hl_matrix_mul(aData,
!a.isTransposed() ? HPPL_OP_N : HPPL_OP_T, !aTrans ? HPPL_OP_N : HPPL_OP_T,
bData, bData,
!b.isTransposed() ? HPPL_OP_N : HPPL_OP_T, !bTrans ? HPPL_OP_N : HPPL_OP_T,
outData, outData,
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
!a.isTransposed() ? a.getWidth() : a.getHeight(), !aTrans ? a.getWidth() : a.getHeight(),
scaleAB, scaleAB,
scaleT, scaleT,
a.getStride(), a.getStride(),
...@@ -75,27 +59,19 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -75,27 +59,19 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuSparseMatrix& a, const GpuSparseMatrix& a,
const GpuMatrix& b, const GpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
CHECK(out.isContiguous()); CHECK(out.isContiguous());
CHECK(b.isContiguous()); CHECK(b.isContiguous());
CHECK(b.useGpu_) << "Matrix type are not equal"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
CHECK(!out.isTransposed() && !b.isTransposed()) << "not supported";
if (!a.isTransposed()) {
/// a: M * K, b: K * N
CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getHeight()
&& a.getWidth() == b.getHeight()) << "Matrix dimensions are not equal";
} else {
/// a: K * M, transpose, b: K * N
CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getWidth()
&& a.getHeight() == b.getHeight()) << "Matrix dimensions are not equal";
}
hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
hl_sparse_matrix_s aData = a.sMatrix_.get(); hl_sparse_matrix_s aData = a.sMatrix_.get();
real* bData = const_cast<real*>(b.getData()); real* bData = const_cast<real*>(b.getData());
real* outData = const_cast<real*>(out.getData()); real* outData = const_cast<real*>(out.getData());
hl_matrix_csr_mul_dense(aData, hl_matrix_csr_mul_dense(aData,
aTrans, aTrans ? HPPL_OP_T : HPPL_OP_N,
bData, bData,
HPPL_OP_N, HPPL_OP_N,
outData, outData,
...@@ -115,25 +91,14 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -115,25 +91,14 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
const GpuSparseMatrix& b, const GpuSparseMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
CHECK(out.isContiguous()); CHECK(out.isContiguous());
CHECK(a.isContiguous()); CHECK(a.isContiguous());
CHECK(a.useGpu_) << "Matrix type are not equal"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
if (!b.isTransposed()) {
/// a : M * K, b : K * N
CHECK(out.getWidth() == b.getWidth() &&
out.getHeight() == a.getHeight() &&
a.getWidth() == b.getHeight())
<< "Matrix dimensions are not equal";
} else {
/// a : M * K, b : N * K, transpose
CHECK(out.getWidth() == b.getHeight() &&
out.getHeight() == a.getHeight() &&
a.getWidth() == b.getWidth())
<< "Matrix dimensions are not equal";
}
hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
hl_sparse_matrix_s bData = b.sMatrix_.get(); hl_sparse_matrix_s bData = b.sMatrix_.get();
real* aData = const_cast<real*>(a.getData()); real* aData = const_cast<real*>(a.getData());
real* outData = const_cast<real*>(out.getData()); real* outData = const_cast<real*>(out.getData());
...@@ -142,7 +107,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -142,7 +107,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
hl_matrix_dense_mul_csc(aData, hl_matrix_dense_mul_csc(aData,
HPPL_OP_N, HPPL_OP_N,
bData, bData,
bTrans, bTrans ? HPPL_OP_T : HPPL_OP_N,
outData, outData,
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
...@@ -153,7 +118,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out, ...@@ -153,7 +118,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
hl_matrix_dense_mul_csr(aData, hl_matrix_dense_mul_csr(aData,
HPPL_OP_N, HPPL_OP_N,
bData, bData,
bTrans, bTrans ? HPPL_OP_T : HPPL_OP_N,
outData, outData,
out.getHeight(), out.getHeight(),
out.getWidth(), out.getWidth(),
...@@ -168,35 +133,26 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out, ...@@ -168,35 +133,26 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
const GpuMatrix& a, const GpuMatrix& a,
const GpuMatrix& b, const GpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT,
bool aTrans,
bool bTrans,
bool cTrans) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match"; CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
CHECK(!out.isTransposed()) << "Transpose is not supported for out matrix";
if (!a.isTransposed() && !b.isTransposed()) {
CHECK(out.getHeight() == a.getHeight() &&
out.getWidth() == b.getWidth() &&
a.getWidth() == b.getHeight());
} else if (a.isTransposed() && !b.isTransposed()) {
CHECK(out.getHeight() == a.getWidth() &&
out.getWidth() == b.getWidth() &&
a.getHeight() == b.getHeight());
} else if (!a.isTransposed() && b.isTransposed()) {
CHECK(out.getHeight() == a.getHeight() &&
out.getWidth() == b.getHeight() &&
a.getWidth() == b.getWidth());
} else {
LOG(FATAL) << "Not support for both a and b are Transposed Matrices";
}
hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
int dimK = !b.isTransposed() ? b.getHeight() : b.getWidth();
real* aData = const_cast<real*>(a.getData()); real* aData = const_cast<real*>(a.getData());
real* bData = const_cast<real*>(b.getData()); real* bData = const_cast<real*>(b.getData());
hl_sparse_matrix_s outData = out.sMatrix_.get(); hl_sparse_matrix_s outData = out.sMatrix_.get();
hl_sparse_matrix_mul(aData, aTrans, bData, bTrans, outData, hl_sparse_matrix_mul(aData,
out.getHeight(), out.getWidth(), dimK, scaleAB, scaleT); aTrans ? HPPL_OP_T : HPPL_OP_N,
bData,
bTrans ? HPPL_OP_T : HPPL_OP_N,
outData,
out.getHeight(),
out.getWidth(),
!bTrans ? b.getHeight() : b.getWidth(),
scaleAB,
scaleT);
} }
} // namespace paddle } // namespace paddle
...@@ -39,18 +39,21 @@ void testFuncDDDMatrix( ...@@ -39,18 +39,21 @@ void testFuncDDDMatrix(
size_t widthC = dimN; size_t widthC = dimN;
// init Test object // init Test object
FunctionCompare test("MulOp", FunctionCompare test("MulOp",
FuncConfig().set("scaleAB", alpha).set("scaleT", beta)); FuncConfig()
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", transa)
.set("bTrans", transb)
.set("cTrans", false));
// prepare input arguments // prepare input arguments
/// matrix A : HA * WA /// matrix A : HA * WA
test.addInputs(BufferArg( test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightA, widthA}));
VALUE_TYPE_FLOAT, TensorShape{heightA, widthA}, UNSPECIFIED, transa));
/// matrix B: HB * WB /// matrix B: HB * WB
test.addInputs(BufferArg( test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightB, widthB}));
VALUE_TYPE_FLOAT, TensorShape{heightB, widthB}, UNSPECIFIED, transb));
/// output matrix C: HC * WC /// output matrix C: HC * WC
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightC, widthC}), test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{heightC, widthC}),
ADD_TO); beta == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function // run Function
test.run(); test.run();
} }
...@@ -88,21 +91,22 @@ void testFuncDSparseDMatrix( ...@@ -88,21 +91,22 @@ void testFuncDSparseDMatrix(
real beta = 1.0; real beta = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", FunctionCompare test("MulOp",
FuncConfig().set("scaleAB", alpha).set("scaleT", beta)); FuncConfig()
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
// prepare input arguments // prepare input arguments
/// sparse matrix A : M * K /// sparse matrix A : M * K
test.addInputs(SparseMatrixArg(VALUE_TYPE_FLOAT, test.addInputs(SparseMatrixArg(
TensorShape{dimM, dimK}, VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}, nnz, FORMAT, FLOAT_VALUE));
nnz,
FORMAT,
FLOAT_VALUE,
UNSPECIFIED,
false));
/// matrix B: K * N /// matrix B: K * N
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimK, dimN})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimK, dimN}));
/// output matrix C: M * N /// output matrix C: M * N
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}), ADD_TO); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}),
beta == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function // run Function
test.run(); test.run();
} }
...@@ -138,22 +142,23 @@ void testFuncDDSparseMatrix( ...@@ -138,22 +142,23 @@ void testFuncDDSparseMatrix(
real beta = 1.0; real beta = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", FunctionCompare test("MulOp",
FuncConfig().set("scaleAB", alpha).set("scaleT", beta)); FuncConfig()
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
// prepare input arguments // prepare input arguments
/// matrix A : M * K /// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
/// matrix B: K * N /// matrix B: K * N
test.addInputs(SparseMatrixArg(VALUE_TYPE_FLOAT, test.addInputs(SparseMatrixArg(
TensorShape{dimK, dimN}, VALUE_TYPE_FLOAT, TensorShape{dimK, dimN}, nnz, FORMAT, FLOAT_VALUE));
nnz,
FORMAT,
FLOAT_VALUE,
UNSPECIFIED,
false));
/// output matrix C: M * N /// output matrix C: M * N
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}), ADD_TO); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}),
beta == 1.0 ? ADD_TO : ASSIGN_TO);
// run Function // run Function
test.run(); test.run();
} }
...@@ -189,7 +194,12 @@ void testFuncSparseDDMatrix( ...@@ -189,7 +194,12 @@ void testFuncSparseDDMatrix(
real beta = 1.0; real beta = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", FunctionCompare test("MulOp",
FuncConfig().set("scaleAB", alpha).set("scaleT", beta)); FuncConfig()
.set("scaleAB", alpha)
.set("scaleT", beta)
.set("aTrans", false)
.set("bTrans", false)
.set("cTrans", false));
// prepare input arguments // prepare input arguments
/// matrix A : M * K /// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
...@@ -198,14 +208,10 @@ void testFuncSparseDDMatrix( ...@@ -198,14 +208,10 @@ void testFuncSparseDDMatrix(
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimK, dimN})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimK, dimN}));
/// output sparse matrix C: M * N /// output sparse matrix C: M * N
test.addOutputs(SparseMatrixArg(VALUE_TYPE_FLOAT, test.addOutputs(
TensorShape{dimM, dimN}, SparseMatrixArg(
nnz, VALUE_TYPE_FLOAT, TensorShape{dimM, dimN}, nnz, FORMAT, FLOAT_VALUE),
FORMAT, beta == 1.0 ? ADD_TO : ASSIGN_TO);
FLOAT_VALUE,
UNSPECIFIED,
false),
ADD_TO);
// run Function // run Function
test.run(); test.run();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册