提交 728defbe 编写于 作者: H Haonan

copy the data when createSparseMatrix

上级 069d0004
...@@ -462,29 +462,49 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap, ...@@ -462,29 +462,49 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label, void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
Matrix& target) { Matrix& target) {
label.idsToSparseMatrix(output.getWidth(), useGpu_); MatrixPtr value = nullptr;
if (label.ids) {
CHECK(!label.value);
value = Matrix::createSparseMatrix(
label.ids->getSize(), output.getWidth(), label.ids->getSize(),
NO_VALUE, SPARSE_CSR, false, useGpu_);
label.idsToSparseMatrix(value);
} else {
CHECK(label.value);
value = label.value;
}
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) || if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
dynamic_cast<GpuSparseMatrix*>(label.value.get())) { dynamic_cast<GpuSparseMatrix*>(value.get())) {
target.multiBinaryLabelCrossEntropy(output, *label.value); target.multiBinaryLabelCrossEntropy(output, *value);
} else { } else {
Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(), Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(),
false, useGpu_); false, useGpu_);
targetPerDim_->binaryLabelCrossEntropy(output, *label.value); targetPerDim_->binaryLabelCrossEntropy(output, *value);
targetPerDim_->rowSum(target); targetPerDim_->rowSum(target);
} }
} }
void MultiBinaryLabelCrossEntropy::backwardImp( void MultiBinaryLabelCrossEntropy::backwardImp(
Matrix& output, Argument& label, Matrix& outputG) { Matrix& output, Argument& label, Matrix& outputG) {
label.idsToSparseMatrix(output.getWidth(), useGpu_); MatrixPtr value = nullptr;
if (label.ids) {
CHECK(!value);
value = Matrix::createSparseMatrix(
label.ids->getSize(), output.getWidth(), label.ids->getSize(),
NO_VALUE, SPARSE_CSR, false, useGpu_);
label.idsToSparseMatrix(value);
} else {
CHECK(label.value);
value = label.value;
}
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) || if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
dynamic_cast<GpuSparseMatrix*>(label.value.get())) { dynamic_cast<GpuSparseMatrix*>(value.get())) {
outputG.multiBinaryLabelCrossEntropyBp(output, *label.value); outputG.multiBinaryLabelCrossEntropyBp(output, *value);
} else { } else {
outputG.binaryLabelCrossEntropyBp(output, *label.value); outputG.binaryLabelCrossEntropyBp(output, *value);
} }
} }
......
...@@ -572,25 +572,41 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height, ...@@ -572,25 +572,41 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
} }
} }
void Argument::idsToSparseMatrix(int width, bool useGpu) { void Argument::idsToSparseMatrix(MatrixPtr sparse_mat) {
if (ids) {
CHECK(!value);
int height = ids->getSize(); int height = ids->getSize();
int nnz = height; int width = sparse_mat->getWidth();
auto rows = IVector::create(height + 1, useGpu);
auto cols = IVector::create(nnz, useGpu); CpuIVector cpu_ids(height);
rows->setElement(0, 0); cpu_ids.copyFrom(*ids);
int *id_data = cpu_ids.getData();
int *rows = nullptr;
int *cols = nullptr;
if (sparse_mat->useGpu()) {
auto gpu_sparse_mat =
dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
rows = gpu_sparse_mat->rows_;
cols = gpu_sparse_mat->cols_;
} else {
rows = sparse_mat->getRows();
cols = sparse_mat->getCols();
}
rows[0] = 0;
for (int i = 0; i < height; i ++) { for (int i = 0; i < height; i ++) {
int id = ids->getElement(i); int id = id_data[i];
CHECK_LT(id, width); CHECK_LT(id, width);
rows->setElement(i + 1, i + 1); rows[i + 1] = i + 1;
cols->setElement(i, id); cols[i] = id;
} }
value = Matrix::createSparseMatrix(
nullptr, rows->getData(), cols->getData(), if (sparse_mat->useGpu()) {
height, width, nnz, NO_VALUE, SPARSE_CSR, false, useGpu); auto gpu_sparse_mat =
} else { dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
CHECK(value); hl_memcpy_csr_matrix(gpu_sparse_mat->sMatrix_.get(),
nullptr, rows, cols,
HPPL_STREAM_DEFAULT);
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
} }
} }
......
...@@ -289,11 +289,9 @@ struct Argument { ...@@ -289,11 +289,9 @@ struct Argument {
/*
 * @brief Convert the ids vector into a sparse label matrix
 *        (one nonzero per row, column = id).
 * @param[out] sparse_mat  the output sparse matrix; must already be
 *                         allocated by the caller.
 */
void idsToSparseMatrix(int width, bool useGpu); void idsToSparseMatrix(MatrixPtr sparse_mat);
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册