diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h
index 71e8f8e3a60c9ff340f36c5057a22cecc112fd48..6195e30b9974d3ad092b4cf604e6b74fa481835c 100644
--- a/paddle/cuda/include/hl_matrix.h
+++ b/paddle/cuda/include/hl_matrix.h
@@ -126,6 +126,36 @@ extern void hl_matrix_cross_entropy_bp(real* grad_d,
                                        int dimM,
                                        int dimN);
 
+/**
+ * @brief Matrix multi-binary label cross entropy
+ *
+ * @param[in]  output   input matrix (M x N).
+ * @param[out] entropy  output matrix (M x 1).
+ * @param[in]  mat      input sparse matrix.
+ * @param[in]  dimM     matrix height.
+ * @param[in]  dimN     matrix width.
+ */
+extern void hl_matrix_multi_binary_cross_entropy(real* output,
+                                                 real* entropy,
+                                                 hl_sparse_matrix_s mat,
+                                                 int dimM,
+                                                 int dimN);
+
+/**
+ * @brief Matrix multi-binary label cross entropy backprop
+ *
+ * @param[in]  output  input matrix (M x N).
+ * @param[out] grad    output matrix (M x N).
+ * @param[in]  mat     input sparse matrix.
+ * @param[in]  dimM    matrix height.
+ * @param[in]  dimN    matrix width.
+ */
+extern void hl_matrix_multi_binary_cross_entropy_bp(real* output,
+                                                    real* grad,
+                                                    hl_sparse_matrix_s mat,
+                                                    int dimM,
+                                                    int dimN);
+
 /**
  * @brief Matrix zero memory.
  *
diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h
index e37b1275432caae29b14e95658e3db291632a672..76cac2e57769301fee2e5979e2685976daf35441 100644
--- a/paddle/cuda/include/stub/hl_matrix_stub.h
+++ b/paddle/cuda/include/stub/hl_matrix_stub.h
@@ -57,6 +57,18 @@ inline void hl_matrix_cross_entropy_bp(real* grad_d,
                                        int dimM,
                                        int dimN) {}
 
+inline void hl_matrix_multi_binary_cross_entropy(real* output,
+                                                 real* entropy,
+                                                 hl_sparse_matrix_s mat,
+                                                 int dimM,
+                                                 int dimN) {}
+
+inline void hl_matrix_multi_binary_cross_entropy_bp(real* output,
+                                                    real* grad,
+                                                    hl_sparse_matrix_s mat,
+                                                    int dimM,
+                                                    int dimN) {}
+
 inline void hl_matrix_zero_mem(real* data, int num) {}
 
 inline void hl_param_relu_forward(real* output,
*/ #include "hl_matrix_ops.cuh" #include "hl_matrix_apply.cuh" #include "hl_sequence.h" +#include "hl_sparse.ph" #include "paddle/utils/Logging.h" #include "hl_device_functions.cuh" #include "hl_gpu_matrix_kernel.cuh" @@ -317,6 +318,83 @@ void hl_matrix_classification_error(real* A_d, CHECK_SYNC("hl_matrix_classification_error"); } +__global__ void KeMatrixMultiBinaryCrossEntropy(real* output, + real* entropy, + int* row, + int* col, + int dimM, + int dimN) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < dimM) { + for (int i = 0; i < dimN; i ++) { + entropy[index] -= log(1 - output[index * dimN + i]); + } + int *row_col = col + row[index]; + int col_num = row[index + 1] - row[index]; + for (int i = 0; i < col_num; i ++) { + real o = output[index * dimN + row_col[i]]; + entropy[index] -= log(o / (1 - o)); + } + } +} + +void hl_matrix_multi_binary_cross_entropy(real* output, + real* entropy, + hl_sparse_matrix_s csr_mat, + int dimM, + int dimN) { + CHECK_NOTNULL(output); + CHECK_NOTNULL(entropy); + CHECK_NOTNULL(csr_mat); + int n_threads = 1024; + int blocks = (dimM + n_threads - 1) / n_threads; + dim3 threads(n_threads); + dim3 grid(blocks); + hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix); + KeMatrixMultiBinaryCrossEntropy<<< grid, threads, 0, STREAM_DEFAULT >>> + (output, entropy, mat->csr_row, mat->csr_col, dimM, dimN); + CHECK_SYNC("hl_matrix_multi_binary_cross_entropy failed"); +} + +__global__ void KeMatrixMultiBinaryCrossEntropyBp(real* output, + real* grad, + int* row, + int* col, + int dimM, + int dimN) { + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (row_idx < dimM) { + for (int i = 0; i < dimN; i ++) { + int index = row_idx * dimN + i; + grad[index] += 1.0 / (1 - output[index]); + } + int col_num = row[row_idx + 1] - row[row_idx]; + int *row_col = col + row[row_idx]; + for (int i = 0; i < col_num; i ++) { + int index = row_idx * dimN + row_col[i]; + grad[index] -= 1.0 / (output[index] * (1 - output[index])); + } + } +} + +void hl_matrix_multi_binary_cross_entropy_bp(real* output, + real* grad, + hl_sparse_matrix_s csr_mat, + int dimM, + int dimN) { + CHECK_NOTNULL(output); + CHECK_NOTNULL(grad); + CHECK_NOTNULL(csr_mat); + int n_threads = 1024; + int blocks = (dimM + n_threads - 1) / n_threads; + dim3 threads(n_threads); + dim3 grid(blocks); + hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix); + KeMatrixMultiBinaryCrossEntropyBp<<< grid, threads, 0, STREAM_DEFAULT >>> + (output, grad, mat->csr_row, mat->csr_col, dimM, dimN); + CHECK_SYNC("hl_matrix_multi_binary_cross_entropy_bp failed"); +} + __global__ void KeMatrixCrossEntropy(real* O, real* E, int* label, diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp index 949788be497874a5bb34e49e11bdc8ba3205ba61..c86e562d0e445604f65352a5db0b9d28e77d0825 100644 --- a/paddle/gserver/layers/CostLayer.cpp +++ b/paddle/gserver/layers/CostLayer.cpp @@ -462,6 +462,8 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap, void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label, Matrix& target) { + label.idsToSparseMatrix(output.getWidth(), useGpu_); + if (dynamic_cast(label.value.get()) || dynamic_cast(label.value.get())) { target.multiBinaryLabelCrossEntropy(output, *label.value); @@ -476,6 +478,8 @@ void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label, void MultiBinaryLabelCrossEntropy::backwardImp( Matrix& output, Argument& label, Matrix& outputG) { + label.idsToSparseMatrix(output.getWidth(), useGpu_); + 
diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp
index 949788be497874a5bb34e49e11bdc8ba3205ba61..c86e562d0e445604f65352a5db0b9d28e77d0825 100644
--- a/paddle/gserver/layers/CostLayer.cpp
+++ b/paddle/gserver/layers/CostLayer.cpp
@@ -462,6 +462,8 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
 
 void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
                                               Matrix& target) {
+  label.idsToSparseMatrix(output.getWidth(), useGpu_);
+
   if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
       dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
     target.multiBinaryLabelCrossEntropy(output, *label.value);
@@ -476,6 +478,8 @@ void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
 
 void MultiBinaryLabelCrossEntropy::backwardImp(
     Matrix& output, Argument& label, Matrix& outputG) {
+  label.idsToSparseMatrix(output.getWidth(), useGpu_);
+
   if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
       dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
     outputG.multiBinaryLabelCrossEntropyBp(output, *label.value);
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index e7e07e9e69dc7a1b51211364dad7043bdcbaf4c3..f19c14f56925ac3768045fd3c3afd787db43fbb5 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -538,9 +538,10 @@ TEST(Layer, multi_binary_label) {
   config.layerConfig.add_inputs();
   config.layerConfig.add_inputs();
 
-  // Not support GPU now
-  testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
-                /* trans */ false, /* useGpu */ false);
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
+                  /* trans */ false, useGpu);
+  }
 }
 
 TEST(Layer, multi_cross_with_selfnorm) {
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 950c3bb6cca28ad4e9c10bc984898c9d643478c4..9acc6005532fc06482a5d038aafe8af0de54579f 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1268,6 +1268,42 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
   }
 }
 
+void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
+  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
+  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
+  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == output_ptr->height_ && width_ == 1 &&
+        output_ptr->width_ == label_ptr->getWidth() &&
+        output_ptr->height_ == label_ptr->getHeight())
+      << "Matrix dimensions are not equal";
+
+  real* output_d = output_ptr->data_;
+  real* entropy_d = data_;
+  hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get();
+  hl_matrix_multi_binary_cross_entropy(
+      output_d, entropy_d, mat_d, height_, output_ptr->width_);
+}
+
+void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
+  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
+  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
+  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == output_ptr->height_ && width_ == output_ptr->width_ &&
+        output_ptr->width_ == label_ptr->getWidth() &&
+        output_ptr->height_ == label_ptr->getHeight())
+      << "Matrix dimensions are not equal";
+
+  real* output_d = output_ptr->data_;
+  real* grad_d = data_;
+  hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get();
+  hl_matrix_multi_binary_cross_entropy_bp(
+      output_d, grad_d, mat_d, height_, width_);
+}
+
 /**
  * CpuMatrix
  */
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 700be7590240c357fffcd90992fc50e4e80d9137..6c3c4804d2fc67a3378c61e8b8ff1e3c0087dd83 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -1303,6 +1303,10 @@ public:
                       const size_t numChannels,
                       const real ratioH,
                       const real ratioW);
+
+  void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
+
+  void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
 };
 
 class CpuMatrix : public Matrix {
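The test below compares the new GPU path against the existing CPU implementation. As a semantic reference only (a standalone sketch, not part of this patch, using plain arrays in place of Paddle's matrix classes), the per-row computation both paths perform is:

#include <cmath>

// Sketch of the per-row multi-binary cross entropy over a CSR label
// matrix, mirroring KeMatrixMultiBinaryCrossEntropy above.
//   output : dimM x dimN activations, each strictly inside (0, 1)
//   row    : CSR row offsets, size dimM + 1
//   col    : CSR column indices of the positive labels
//   entropy: dimM accumulators, assumed zero-initialized
void multiBinaryCrossEntropyRef(const float* output, float* entropy,
                                const int* row, const int* col,
                                int dimM, int dimN) {
  for (int i = 0; i < dimM; ++i) {
    const float* o = output + i * dimN;
    // every column first contributes -log(1 - o) ...
    for (int j = 0; j < dimN; ++j) {
      entropy[i] -= std::log(1.0f - o[j]);
    }
    // ... then each positive column is corrected to -log(o)
    for (int k = row[i]; k < row[i + 1]; ++k) {
      entropy[i] -= std::log(o[col[k]] / (1.0f - o[col[k]]));
    }
  }
}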
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index b3ee4bc34995aa20d069a8a867658dfa498a031c..a41e21903f5604ec2cc255e23a807c2c6c8a7f4b 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -2208,7 +2208,6 @@ void testCollectSharedBias(int numSamples, int dim, int channel) {
   MatrixCheckErr(*cpuBias, *check);
 }
 
-
 TEST(Matrix, sharedBias) {
   for (auto numSamples : {1, 100, 520}) {
     for (auto dim : {100 * 16, 100 * 32}) {
@@ -2222,6 +2221,71 @@ TEST(Matrix, sharedBias) {
   }
 }
 
+void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
+  MatrixPtr output = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr gpuOutput = std::make_shared<GpuMatrix>(numSamples, dim);
+
+  MatrixPtr cpuEntropy = std::make_shared<CpuMatrix>(numSamples, 1);
+  MatrixPtr gpuEntropy = std::make_shared<GpuMatrix>(numSamples, 1);
+
+  MatrixPtr cpuGrad = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr gpuGrad = std::make_shared<GpuMatrix>(numSamples, dim);
+
+  auto cpuRows = IVector::create(numSamples + 1, false);
+  auto cpuCols = IVector::create(numSamples, false);
+  auto gpuRows = IVector::create(numSamples + 1, true);
+  auto gpuCols = IVector::create(numSamples, true);
+  cpuRows->setElement(0, 0);
+  gpuRows->setElement(0, 0);
+  for (int i = 0; i < numSamples; i++) {
+    int id = rand() % dim;  // NOLINT
+    cpuRows->setElement(i + 1, i + 1);
+    gpuRows->setElement(i + 1, i + 1);
+    cpuCols->setElement(i, id);
+    gpuCols->setElement(i, id);
+  }
+
+  MatrixPtr cpuLabel = std::make_shared<CpuSparseMatrix>
+      (nullptr, cpuRows->getData(), cpuCols->getData(),
+       numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+  MatrixPtr gpuLabel = std::make_shared<GpuSparseMatrix>
+      (nullptr, gpuRows->getData(), gpuCols->getData(),
+       numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+
+  output->randomizeUniform();
+  cpuOutput->zeroMem();
+  output->softmax(*cpuOutput);
+  gpuOutput->copyFrom(*cpuOutput);
+
+  cpuEntropy->zeroMem();
+  gpuEntropy->zeroMem();
+  cpuEntropy->multiBinaryLabelCrossEntropy(*cpuOutput, *cpuLabel);
+  gpuEntropy->multiBinaryLabelCrossEntropy(*gpuOutput, *gpuLabel);
+
+  MatrixPtr check1 = std::make_shared<CpuMatrix>(numSamples, 1);
+  check1->copyFrom(*gpuEntropy);
+  MatrixCheckErr(*cpuEntropy, *check1);
+
+  cpuGrad->zeroMem();
+  gpuGrad->zeroMem();
+  cpuGrad->multiBinaryLabelCrossEntropyBp(*cpuOutput, *cpuLabel);
+  gpuGrad->multiBinaryLabelCrossEntropyBp(*gpuOutput, *gpuLabel);
+
+  MatrixPtr check2 = std::make_shared<CpuMatrix>(numSamples, dim);
+  check2->copyFrom(*gpuGrad);
+  MatrixCheckErr(*cpuGrad, *check2);
+}
+
+TEST(Matrix, multiBinaryCrossEntropy) {
+  for (auto numSamples : {1, 100, 500}) {
+    for (auto dim : {1000, 10000, 100000}) {
+      VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
+      testMultiBinaryLabelCrossEntropy(numSamples, dim);
+    }
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);
diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp
index 42c74661d2b2cebe0c2f5f14d0970ab2f1fec866..a5a96742e4cadcbaab1dd73d6014f548b5c2efd3 100644
--- a/paddle/parameter/Argument.cpp
+++ b/paddle/parameter/Argument.cpp
@@ -572,4 +572,26 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
   }
 }
 
+void Argument::idsToSparseMatrix(int width, bool useGpu) {
+  if (ids) {
+    CHECK(!value);
+    int height = ids->getSize();
+    int nnz = height;
+    auto rows = IVector::create(height + 1, useGpu);
+    auto cols = IVector::create(nnz, useGpu);
+    rows->setElement(0, 0);
+    for (int i = 0; i < height; i++) {
+      int id = ids->getElement(i);
+      CHECK_LT(id, width);
+      rows->setElement(i + 1, i + 1);
+      cols->setElement(i, id);
+    }
+    value = Matrix::createSparseMatrix(
+        nullptr, rows->getData(), cols->getData(),
+        height, width, nnz, NO_VALUE, SPARSE_CSR, false, useGpu);
+  } else {
+    CHECK(value);
+  }
+}
+
 }  // namespace paddle
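Since every sample carries exactly one label id, the CSR matrix built by `Argument::idsToSparseMatrix` (and by the test above) has one nonzero per row, so the row offsets are simply 0, 1, ..., height. A hypothetical example of the resulting layout:

// For ids = {2, 0, 3} and width = 4 (NO_VALUE format, so there is no
// value array), idsToSparseMatrix produces:
//   rows = {0, 1, 2, 3}  // row i owns cols[rows[i] .. rows[i+1])
//   cols = {2, 0, 3}     // the positive column of each sample
// which is equivalent to the dense 0/1 label matrix
//   [0 0 1 0]
//   [1 0 0 0]
//   [0 0 0 1]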
diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h
index 81ff9029bc4c8fca7adbabd7ae65caf7ac2f3c2a..48e1551258fe738e4b7ac09bdc7e792ccba5ffa7 100644
--- a/paddle/parameter/Argument.h
+++ b/paddle/parameter/Argument.h
@@ -286,6 +286,14 @@ struct Argument {
      sequence has sub-sequence degrades to a sequence.
   */
   void degradeSequence(const Argument& input, bool useGpu);
+
+  /**
+   * @brief convert the ids vector to the value field as a CSR sparse
+   *        matrix; the ids vector itself remains valid
+   * @param[in] width   matrix width (the range of valid ids)
+   * @param[in] useGpu  whether to build the sparse matrix on GPU
+   */
+  void idsToSparseMatrix(int width, bool useGpu);
 };
 
 }  // namespace paddle