diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp index c86e562d0e445604f65352a5db0b9d28e77d0825..900981d1e7d36c8eb2f2677c7455eab153503ef2 100644 --- a/paddle/gserver/layers/CostLayer.cpp +++ b/paddle/gserver/layers/CostLayer.cpp @@ -462,29 +462,49 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap, void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label, Matrix& target) { - label.idsToSparseMatrix(output.getWidth(), useGpu_); + MatrixPtr value = nullptr; + if (label.ids) { + CHECK(!label.value); + value = Matrix::createSparseMatrix( + label.ids->getSize(), output.getWidth(), label.ids->getSize(), + NO_VALUE, SPARSE_CSR, false, useGpu_); + label.idsToSparseMatrix(value); + } else { + CHECK(label.value); + value = label.value; + } - if (dynamic_cast(label.value.get()) || - dynamic_cast(label.value.get())) { - target.multiBinaryLabelCrossEntropy(output, *label.value); + if (dynamic_cast(value.get()) || + dynamic_cast(value.get())) { + target.multiBinaryLabelCrossEntropy(output, *value); } else { Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(), false, useGpu_); - targetPerDim_->binaryLabelCrossEntropy(output, *label.value); + targetPerDim_->binaryLabelCrossEntropy(output, *value); targetPerDim_->rowSum(target); } } void MultiBinaryLabelCrossEntropy::backwardImp( Matrix& output, Argument& label, Matrix& outputG) { - label.idsToSparseMatrix(output.getWidth(), useGpu_); + MatrixPtr value = nullptr; + if (label.ids) { + CHECK(!value); + value = Matrix::createSparseMatrix( + label.ids->getSize(), output.getWidth(), label.ids->getSize(), + NO_VALUE, SPARSE_CSR, false, useGpu_); + label.idsToSparseMatrix(value); + } else { + CHECK(label.value); + value = label.value; + } - if (dynamic_cast(label.value.get()) || - dynamic_cast(label.value.get())) { - outputG.multiBinaryLabelCrossEntropyBp(output, *label.value); + if (dynamic_cast(value.get()) || + dynamic_cast(value.get())) { + outputG.multiBinaryLabelCrossEntropyBp(output, *value); } else { - outputG.binaryLabelCrossEntropyBp(output, *label.value); + outputG.binaryLabelCrossEntropyBp(output, *value); } } diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp index a5a96742e4cadcbaab1dd73d6014f548b5c2efd3..354d0ead071b3d3286ef69379e89f6301e74bfe4 100644 --- a/paddle/parameter/Argument.cpp +++ b/paddle/parameter/Argument.cpp @@ -572,25 +572,41 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height, } } -void Argument::idsToSparseMatrix(int width, bool useGpu) { - if (ids) { - CHECK(!value); - int height = ids->getSize(); - int nnz = height; - auto rows = IVector::create(height + 1, useGpu); - auto cols = IVector::create(nnz, useGpu); - rows->setElement(0, 0); - for (int i = 0; i < height; i ++) { - int id = ids->getElement(i); - CHECK_LT(id, width); - rows->setElement(i + 1, i + 1); - cols->setElement(i, id); - } - value = Matrix::createSparseMatrix( - nullptr, rows->getData(), cols->getData(), - height, width, nnz, NO_VALUE, SPARSE_CSR, false, useGpu); +void Argument::idsToSparseMatrix(MatrixPtr sparse_mat) { + int height = ids->getSize(); + int width = sparse_mat->getWidth(); + + CpuIVector cpu_ids(height); + cpu_ids.copyFrom(*ids); + int *id_data = cpu_ids.getData(); + + int *rows = nullptr; + int *cols = nullptr; + if (sparse_mat->useGpu()) { + auto gpu_sparse_mat = + dynamic_cast(sparse_mat.get()); + rows = gpu_sparse_mat->rows_; + cols = gpu_sparse_mat->cols_; } else { - CHECK(value); + rows = sparse_mat->getRows(); + cols = sparse_mat->getCols(); + } + + rows[0] = 0; + for (int i = 0; i < height; i ++) { + int id = id_data[i]; + CHECK_LT(id, width); + rows[i + 1] = i + 1; + cols[i] = id; + } + + if (sparse_mat->useGpu()) { + auto gpu_sparse_mat = + dynamic_cast(sparse_mat.get()); + hl_memcpy_csr_matrix(gpu_sparse_mat->sMatrix_.get(), + nullptr, rows, cols, + HPPL_STREAM_DEFAULT); + hl_stream_synchronize(HPPL_STREAM_DEFAULT); } } diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h index 48e1551258fe738e4b7ac09bdc7e792ccba5ffa7..695033138b545e94af4eda2e3c389125acb08661 100644 --- a/paddle/parameter/Argument.h +++ b/paddle/parameter/Argument.h @@ -289,11 +289,9 @@ struct Argument { /* @brief convert the ids vector to value as a sparse matrix - the ids vector keeps valid - @param the matrix width (id range) - @useGpu + @param[out] the output sparse_mat (already allocated) */ - void idsToSparseMatrix(int width, bool useGpu); + void idsToSparseMatrix(MatrixPtr sparse_mat); }; } // namespace paddle