diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index 001b62a6b94d60570adea26e058840104c0241eb..0b7cd3375671d58464dac93458ec6659add8b730 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -346,6 +346,7 @@ void hl_matrix_multi_binary_cross_entropy(real* output,
   CHECK_NOTNULL(output);
   CHECK_NOTNULL(entropy);
   CHECK_NOTNULL(csr_mat);
+  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);
   int n_threads = 1024;
   int blocks = (dimM + n_threads - 1) / n_threads;
   dim3 threads(n_threads);
@@ -385,6 +386,7 @@ void hl_matrix_multi_binary_cross_entropy_bp(real* output,
   CHECK_NOTNULL(output);
   CHECK_NOTNULL(grad);
   CHECK_NOTNULL(csr_mat);
+  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);
   int n_threads = 1024;
   int blocks = (dimM + n_threads - 1) / n_threads;
   dim3 threads(n_threads);
@@ -763,7 +765,7 @@ __global__ void KeMatrixAddSharedBias(real* A,
   int dim = N / channel;
   if (index < M * N) {
     int i = index % N;
-    i = i / dim; 
+    i = i / dim;
     A[index] += scale * B[i];
   }
 }
@@ -791,7 +793,7 @@ __global__ void KeMatrixCollectSharedBias(real *B,
                                           const int dim,
                                           const int limit,
                                           real scale) {
-  if (dim < limit) { 
+  if (dim < limit) {
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     if (index < channel) {
       real sum = 0.0;
diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp
index 900981d1e7d36c8eb2f2677c7455eab153503ef2..2f85dd3c3b69d21cffede49b001298c6629900a6 100644
--- a/paddle/gserver/layers/CostLayer.cpp
+++ b/paddle/gserver/layers/CostLayer.cpp
@@ -465,10 +465,7 @@ void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
   MatrixPtr value = nullptr;
   if (label.ids) {
     CHECK(!label.value);
-    value = Matrix::createSparseMatrix(
-        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
-        NO_VALUE, SPARSE_CSR, false, useGpu_);
-    label.idsToSparseMatrix(value);
+    value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
   } else {
     CHECK(label.value);
     value = label.value;
@@ -491,10 +488,7 @@ void MultiBinaryLabelCrossEntropy::backwardImp(
   MatrixPtr value = nullptr;
   if (label.ids) {
     CHECK(!value);
-    value = Matrix::createSparseMatrix(
-        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
-        NO_VALUE, SPARSE_CSR, false, useGpu_);
-    label.idsToSparseMatrix(value);
+    value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
   } else {
     CHECK(label.value);
     value = label.value;
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index f19c14f56925ac3768045fd3c3afd787db43fbb5..f3cd2b4faf0c173cbb4997aac1a00ebba3027c92 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -528,7 +528,7 @@ TEST(Layer, multi_cross) {
   }
 }
 
-TEST(Layer, multi_binary_label) {
+TEST(Layer, multi_binary_label_sparse_mat) {
   TestConfig config;
   config.layerConfig.set_type("multi_binary_label_cross_entropy");
   config.biasSize = 0;
@@ -544,6 +544,22 @@
   }
 }
 
+TEST(layer, multi_binary_label_id) {
+  TestConfig config;
+  config.layerConfig.set_type("multi_binary_label_cross_entropy");
+  config.biasSize = 0;
+
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 50, 0});
+  config.inputDefs.push_back({INPUT_LABEL, "layer_1", 10, 0});
+  config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
+                  /* trans */ false, useGpu);
+  }
+}
+
 TEST(Layer, multi_cross_with_selfnorm) {
   TestConfig config;
   config.layerConfig.set_type("multi_class_cross_entropy_with_selfnorm");
diff --git a/paddle/math/CpuSparseMatrix.cpp b/paddle/math/CpuSparseMatrix.cpp
index 842efdbe3d77ec3443374f62df5c520252aa7ce4..64ee124a5613a99ac3d7ff36897e4f2d0489ad51 100644
--- a/paddle/math/CpuSparseMatrix.cpp
+++ b/paddle/math/CpuSparseMatrix.cpp
@@ -409,9 +409,6 @@ void CpuSparseMatrix::setRow(size_t row, size_t colNum,
   if (format_ == SPARSE_CSR) {
     CHECK_LT(row, height_);
     CHECK(NULL != cols);
-    for (size_t i = row; i < height_; i++) {
-      CHECK_EQ(rows_[i + 1], rows_[i]);
-    }
     if (0 == row) {
       rows_[row] = 0;
     }
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 9acc6005532fc06482a5d038aafe8af0de54579f..5ee8fbebfcfbe9696f7836b6b1c88e724551da8e 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1269,37 +1269,37 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
 }
 
 void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
-  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
-  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
-
-  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
-  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
-  CHECK(height_ == output_ptr->height_ && width_ == 1
-        && output_ptr->width_ == label_ptr->getWidth()
-        && output_ptr->height_ == label_ptr->getHeight())
+  GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
+  auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
+  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == outputPtr->height_ && width_ == 1
+        && outputPtr->width_ == labelPtr->getWidth()
+        && outputPtr->height_ == labelPtr->getHeight())
       << "Matrix dimensions are not equal";
 
-  real* output_d = output_ptr->data_;
+  real* output_d = outputPtr->data_;
   real* entropy_d = data_;
-  hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get();
+  hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
   hl_matrix_multi_binary_cross_entropy(
-      output_d, entropy_d, mat_d, height_, output_ptr->width_);
+      output_d, entropy_d, mat_d, height_, outputPtr->width_);
 }
 
 void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix &output, Matrix &label) {
-  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
-  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
-
-  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
-  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
-  CHECK(height_ == output_ptr->height_ && width_ == output_ptr->width_
-        && output_ptr->width_ == label_ptr->getWidth()
-        && output_ptr->height_ == label_ptr->getHeight())
+  GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
+  auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
+  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == outputPtr->height_ && width_ == outputPtr->width_
+        && outputPtr->width_ == labelPtr->getWidth()
+        && outputPtr->height_ == labelPtr->getHeight())
       << "Matrix dimensions are not equal";
 
-  real* output_d = output_ptr->data_;
+  real* output_d = outputPtr->data_;
   real* grad_d = data_;
-  hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get();
+  hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
   hl_matrix_multi_binary_cross_entropy_bp(
       output_d, grad_d, mat_d, height_, width_);
 }
diff --git a/paddle/math/Vector.cpp b/paddle/math/Vector.cpp
index 7553ea25e09d2f52f1f8b9205f954510b77cbfa9..23c9cacceaea2a2e265108b2467c1d21a2fe312f 100644
--- a/paddle/math/Vector.cpp
+++ b/paddle/math/Vector.cpp
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/utils/ThreadLocal.h"
 #include "paddle/utils/Thread.h"
 #include "paddle/utils/Flags.h"
+#include "Matrix.h"
 #include "hl_gpu.h"
 #include "hl_table_apply.h"
 
@@ -73,6 +74,31 @@ std::shared_ptr<VectorT<T>> VectorT<T>::create(size_t size,
   }
 }
 
+template <>
+MatrixPtr VectorT<real>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
+  LOG(FATAL) << "Wrong for real vector";
+  return nullptr;
+}
+
+template <>
+MatrixPtr VectorT<int>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
+  int height = getSize();
+  int width = idRange;
+  MatrixPtr mat = Matrix::createSparseMatrix(
+      height, idRange, height, NO_VALUE, SPARSE_CSR, false, useGpu);
+
+  CpuIVector cpuIds(height);
+  cpuIds.copyFrom(*this);
+  int *idData = cpuIds.getData();
+
+  for (int i = 0; i < height; i ++) {
+    const unsigned int id = idData[i];
+    CHECK_LT(id, width);
+    mat->setRow(i, 1, &id, nullptr);
+  }
+  return mat;
+}
+
 template <class T>
 GpuVectorT<T>::GpuVectorT(size_t size)
     : VectorT<T>(size, std::make_shared<GpuMemoryHandle>(sizeof(T) * size),
diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h
index ee0a83bf038f04ee9f7b3561639aa90da68a6e29..faf8186b6d10d7cbc14376ff3b6543d1303b2ab1 100644
--- a/paddle/math/Vector.h
+++ b/paddle/math/Vector.h
@@ -37,6 +37,8 @@ class BaseVector;
 
 class SyncThreadPool;
 
+class Matrix;
+
 template<class T>
 class BaseVector : public BaseMatrixT<T> {
 public:
@@ -155,6 +157,12 @@
     subVecFrom(src, interval.first, interval.second - interval.first);
   }
 
+  /**
+   * convert the vector to a sparse one_hot matrix of width idRange
+   * only applies to IVector
+   */
+  std::shared_ptr<Matrix> toOneHotSparseMatrix(size_t idRange, bool useGpu);
+
   /**
    * This function will crash if the size of src and dest is different.
    */
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index a41e21903f5604ec2cc255e23a807c2c6c8a7f4b..9c03695ba5055c4bdb3e7c578d3e352fbd6fae6f 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -2232,26 +2232,15 @@ void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
   MatrixPtr cpuGrad = std::make_shared<CpuMatrix>(numSamples, dim);
   MatrixPtr gpuGrad = std::make_shared<GpuMatrix>(numSamples, dim);
 
-  auto cpuRows = IVector::create(numSamples + 1, false);
-  auto cpuCols = IVector::create(numSamples, false);
-  auto gpuRows = IVector::create(numSamples + 1, true);
-  auto gpuCols = IVector::create(numSamples, true);
-  cpuRows->setElement(0, 0);
-  gpuRows->setElement(0, 0);
-  for (int i = 0; i < numSamples; i ++) {
-    int id = rand() % dim;  // NOLINT
-    cpuRows->setElement(i + 1, i + 1);
-    gpuRows->setElement(i + 1, i + 1);
-    cpuCols->setElement(i, id);
-    gpuCols->setElement(i, id);
-  }
-
   MatrixPtr cpuLabel = std::make_shared<CpuSparseMatrix>
-      (nullptr, cpuRows->getData(), cpuCols->getData(),
-       numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+      (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
   MatrixPtr gpuLabel = std::make_shared<GpuSparseMatrix>
-      (nullptr, gpuRows->getData(), gpuCols->getData(),
-       numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+      (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+  for (int i = 0; i < numSamples; i ++) {
+    const unsigned int id = rand() % dim;  // NOLINT
+    cpuLabel->setRow(i, 1, &id, nullptr);
+    gpuLabel->setRow(i, 1, &id, nullptr);
+  }
 
   output->randomizeUniform();
   cpuOutput->zeroMem();
@@ -2278,8 +2267,8 @@ void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
 }
 
 TEST(Matrix, multiBinaryCrossEntropy) {
-  for (auto numSamples : {1, 100, 500}) {
-    for (auto dim : {1000, 10000, 100000}) {
+  for (auto numSamples : {100, 1000, 10000}) {
+    for (auto dim : {100, 1000, 10000}) {
       VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
       testMultiBinaryLabelCrossEntropy(numSamples, dim);
     }
diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp
index 354d0ead071b3d3286ef69379e89f6301e74bfe4..42c74661d2b2cebe0c2f5f14d0970ab2f1fec866 100644
--- a/paddle/parameter/Argument.cpp
+++ b/paddle/parameter/Argument.cpp
@@ -572,42 +572,4 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
   }
 }
 
-void Argument::idsToSparseMatrix(MatrixPtr sparse_mat) {
-  int height = ids->getSize();
-  int width = sparse_mat->getWidth();
-
-  CpuIVector cpu_ids(height);
-  cpu_ids.copyFrom(*ids);
-  int *id_data = cpu_ids.getData();
-
-  int *rows = nullptr;
-  int *cols = nullptr;
-  if (sparse_mat->useGpu()) {
-    auto gpu_sparse_mat =
-        dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
-    rows = gpu_sparse_mat->rows_;
-    cols = gpu_sparse_mat->cols_;
-  } else {
-    rows = sparse_mat->getRows();
-    cols = sparse_mat->getCols();
-  }
-
-  rows[0] = 0;
-  for (int i = 0; i < height; i ++) {
-    int id = id_data[i];
-    CHECK_LT(id, width);
-    rows[i + 1] = i + 1;
-    cols[i] = id;
-  }
-
-  if (sparse_mat->useGpu()) {
-    auto gpu_sparse_mat =
-        dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
-    hl_memcpy_csr_matrix(gpu_sparse_mat->sMatrix_.get(),
-                         nullptr, rows, cols,
-                         HPPL_STREAM_DEFAULT);
-    hl_stream_synchronize(HPPL_STREAM_DEFAULT);
-  }
-}
-
 }  // namespace paddle
diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h
index 695033138b545e94af4eda2e3c389125acb08661..81ff9029bc4c8fca7adbabd7ae65caf7ac2f3c2a 100644
--- a/paddle/parameter/Argument.h
+++ b/paddle/parameter/Argument.h
@@ -286,12 +286,6 @@ struct Argument {
      sequence has sub-sequence degrades to a sequence.
    */
   void degradeSequence(const Argument& input, bool useGpu);
-
-  /*
-    @brief convert the ids vector to value as a sparse matrix
-    @param[out] the output sparse_mat (already allocated)
-  */
-  void idsToSparseMatrix(MatrixPtr sparse_mat);
 };
 
 }  // namespace paddle