提交 728defbe 编写于 作者: H Haonan

Copy the data when creating a sparse matrix in createSparseMatrix

上级 069d0004
...@@ -462,29 +462,49 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap, ...@@ -462,29 +462,49 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label, void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
Matrix& target) { Matrix& target) {
label.idsToSparseMatrix(output.getWidth(), useGpu_); MatrixPtr value = nullptr;
if (label.ids) {
CHECK(!label.value);
value = Matrix::createSparseMatrix(
label.ids->getSize(), output.getWidth(), label.ids->getSize(),
NO_VALUE, SPARSE_CSR, false, useGpu_);
label.idsToSparseMatrix(value);
} else {
CHECK(label.value);
value = label.value;
}
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) || if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
dynamic_cast<GpuSparseMatrix*>(label.value.get())) { dynamic_cast<GpuSparseMatrix*>(value.get())) {
target.multiBinaryLabelCrossEntropy(output, *label.value); target.multiBinaryLabelCrossEntropy(output, *value);
} else { } else {
Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(), Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(),
false, useGpu_); false, useGpu_);
targetPerDim_->binaryLabelCrossEntropy(output, *label.value); targetPerDim_->binaryLabelCrossEntropy(output, *value);
targetPerDim_->rowSum(target); targetPerDim_->rowSum(target);
} }
} }
void MultiBinaryLabelCrossEntropy::backwardImp( void MultiBinaryLabelCrossEntropy::backwardImp(
Matrix& output, Argument& label, Matrix& outputG) { Matrix& output, Argument& label, Matrix& outputG) {
label.idsToSparseMatrix(output.getWidth(), useGpu_); MatrixPtr value = nullptr;
if (label.ids) {
CHECK(!value);
value = Matrix::createSparseMatrix(
label.ids->getSize(), output.getWidth(), label.ids->getSize(),
NO_VALUE, SPARSE_CSR, false, useGpu_);
label.idsToSparseMatrix(value);
} else {
CHECK(label.value);
value = label.value;
}
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) || if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
dynamic_cast<GpuSparseMatrix*>(label.value.get())) { dynamic_cast<GpuSparseMatrix*>(value.get())) {
outputG.multiBinaryLabelCrossEntropyBp(output, *label.value); outputG.multiBinaryLabelCrossEntropyBp(output, *value);
} else { } else {
outputG.binaryLabelCrossEntropyBp(output, *label.value); outputG.binaryLabelCrossEntropyBp(output, *value);
} }
} }
......
...@@ -572,25 +572,41 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height, ...@@ -572,25 +572,41 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
} }
} }
void Argument::idsToSparseMatrix(int width, bool useGpu) { void Argument::idsToSparseMatrix(MatrixPtr sparse_mat) {
if (ids) { int height = ids->getSize();
CHECK(!value); int width = sparse_mat->getWidth();
int height = ids->getSize();
int nnz = height; CpuIVector cpu_ids(height);
auto rows = IVector::create(height + 1, useGpu); cpu_ids.copyFrom(*ids);
auto cols = IVector::create(nnz, useGpu); int *id_data = cpu_ids.getData();
rows->setElement(0, 0);
for (int i = 0; i < height; i ++) { int *rows = nullptr;
int id = ids->getElement(i); int *cols = nullptr;
CHECK_LT(id, width); if (sparse_mat->useGpu()) {
rows->setElement(i + 1, i + 1); auto gpu_sparse_mat =
cols->setElement(i, id); dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
} rows = gpu_sparse_mat->rows_;
value = Matrix::createSparseMatrix( cols = gpu_sparse_mat->cols_;
nullptr, rows->getData(), cols->getData(),
height, width, nnz, NO_VALUE, SPARSE_CSR, false, useGpu);
} else { } else {
CHECK(value); rows = sparse_mat->getRows();
cols = sparse_mat->getCols();
}
rows[0] = 0;
for (int i = 0; i < height; i ++) {
int id = id_data[i];
CHECK_LT(id, width);
rows[i + 1] = i + 1;
cols[i] = id;
}
if (sparse_mat->useGpu()) {
auto gpu_sparse_mat =
dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
hl_memcpy_csr_matrix(gpu_sparse_mat->sMatrix_.get(),
nullptr, rows, cols,
HPPL_STREAM_DEFAULT);
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
} }
} }
......
...@@ -289,11 +289,9 @@ struct Argument { ...@@ -289,11 +289,9 @@ struct Argument {
/**
 * @brief Convert the `ids` vector into the given sparse matrix: row i receives
 *        a single non-zero entry at column ids[i]. The `ids` vector itself is
 *        left untouched.
 * @param[out] sparse_mat  pre-allocated NO_VALUE CSR matrix (height == size of
 *                         `ids`) that receives the one-hot rows.
 */
void idsToSparseMatrix(MatrixPtr sparse_mat);
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册