diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h
index 71e8f8e3a60c9ff340f36c5057a22cecc112fd48..6195e30b9974d3ad092b4cf604e6b74fa481835c 100644
--- a/paddle/cuda/include/hl_matrix.h
+++ b/paddle/cuda/include/hl_matrix.h
@@ -126,6 +126,36 @@ extern void hl_matrix_cross_entropy_bp(real* grad_d,
                                        int dimM,
                                        int dimN);
 
+/**
+ * @brief Matrix multi-binary label cross entropy
+ *
+ * @param[in]  output   input matrix (M x N).
+ * @param[out] entropy  output matrix (M x 1).
+ * @param[in]  mat      input sparse matrix.
+ * @param[in]  dimM     matrix height.
+ * @param[in]  dimN     matrix width.
+ */
+extern void hl_matrix_multi_binary_cross_entropy(real* output,
+                                                 real* entropy,
+                                                 hl_sparse_matrix_s mat,
+                                                 int dimM,
+                                                 int dimN);
+
+/**
+ * @brief Matrix multi-binary label cross entropy backprop
+ *
+ * @param[in]  output  input matrix (M x N).
+ * @param[out] grad    output matrix (M x N).
+ * @param[in]  mat     input sparse matrix.
+ * @param[in]  dimM    matrix height.
+ * @param[in]  dimN    matrix width.
+ */
+extern void hl_matrix_multi_binary_cross_entropy_bp(real* output,
+                                                    real* grad,
+                                                    hl_sparse_matrix_s mat,
+                                                    int dimM,
+                                                    int dimN);
+
 /**
  * @brief Matrix zero memory.
  *
diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h
index e37b1275432caae29b14e95658e3db291632a672..76cac2e57769301fee2e5979e2685976daf35441 100644
--- a/paddle/cuda/include/stub/hl_matrix_stub.h
+++ b/paddle/cuda/include/stub/hl_matrix_stub.h
@@ -57,6 +57,18 @@ inline void hl_matrix_cross_entropy_bp(real* grad_d,
                                        int dimM,
                                        int dimN) {}
 
+inline void hl_matrix_multi_binary_cross_entropy(real* output,
+                                                 real* entropy,
+                                                 hl_sparse_matrix_s mat,
+                                                 int dimM,
+                                                 int dimN) {}
+
+inline void hl_matrix_multi_binary_cross_entropy_bp(real* output,
+                                                    real* grad,
+                                                    hl_sparse_matrix_s mat,
+                                                    int dimM,
+                                                    int dimN) {}
+
 inline void hl_matrix_zero_mem(real* data, int num) {}
 
 inline void hl_param_relu_forward(real* output,
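For reference, the loss these two entry points compute (read off the kernels in hl_cuda_matrix.cu below; the patch itself does not spell it out) is the standard multi-label binary cross entropy over N independent classes, together with its gradient:

$$
E_i = -\sum_{j=1}^{N}\Big[\,y_{ij}\log o_{ij} + (1-y_{ij})\log(1-o_{ij})\,\Big],
\qquad
\frac{\partial E_i}{\partial o_{ij}} = -\frac{y_{ij}}{o_{ij}} + \frac{1-y_{ij}}{1-o_{ij}},
$$

where $o_{ij}$ is the activation of sample $i$ for class $j$, and $y_{ij}\in\{0,1\}$ is read from the non-zero columns of row $i$ of the CSR label matrix.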
*/ #include "hl_matrix_ops.cuh" #include "hl_matrix_apply.cuh" #include "hl_sequence.h" +#include "hl_sparse.ph" #include "paddle/utils/Logging.h" #include "hl_device_functions.cuh" #include "hl_gpu_matrix_kernel.cuh" @@ -317,6 +318,85 @@ void hl_matrix_classification_error(real* A_d, CHECK_SYNC("hl_matrix_classification_error"); } +__global__ void KeMatrixMultiBinaryCrossEntropy(real* output, + real* entropy, + int* row, + int* col, + int dimM, + int dimN) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < dimM) { + for (int i = 0; i < dimN; i ++) { + entropy[index] -= log(1 - output[index * dimN + i]); + } + int *row_col = col + row[index]; + int col_num = row[index + 1] - row[index]; + for (int i = 0; i < col_num; i ++) { + real o = output[index * dimN + row_col[i]]; + entropy[index] -= log(o / (1 - o)); + } + } +} + +void hl_matrix_multi_binary_cross_entropy(real* output, + real* entropy, + hl_sparse_matrix_s csr_mat, + int dimM, + int dimN) { + CHECK_NOTNULL(output); + CHECK_NOTNULL(entropy); + CHECK_NOTNULL(csr_mat); + CHECK_EQ(csr_mat->format, HL_SPARSE_CSR); + int n_threads = 1024; + int blocks = (dimM + n_threads - 1) / n_threads; + dim3 threads(n_threads); + dim3 grid(blocks); + hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix); + KeMatrixMultiBinaryCrossEntropy<<< grid, threads, 0, STREAM_DEFAULT >>> + (output, entropy, mat->csr_row, mat->csr_col, dimM, dimN); + CHECK_SYNC("hl_matrix_multi_binary_cross_entropy failed"); +} + +__global__ void KeMatrixMultiBinaryCrossEntropyBp(real* output, + real* grad, + int* row, + int* col, + int dimM, + int dimN) { + int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (row_idx < dimM) { + for (int i = 0; i < dimN; i ++) { + int index = row_idx * dimN + i; + grad[index] += 1.0 / (1 - output[index]); + } + int col_num = row[row_idx + 1] - row[row_idx]; + int *row_col = col + row[row_idx]; + for (int i = 0; i < col_num; i ++) { + int index = row_idx * dimN + row_col[i]; + grad[index] -= 1.0 / (output[index] * (1 - output[index])); + } + } +} + +void hl_matrix_multi_binary_cross_entropy_bp(real* output, + real* grad, + hl_sparse_matrix_s csr_mat, + int dimM, + int dimN) { + CHECK_NOTNULL(output); + CHECK_NOTNULL(grad); + CHECK_NOTNULL(csr_mat); + CHECK_EQ(csr_mat->format, HL_SPARSE_CSR); + int n_threads = 1024; + int blocks = (dimM + n_threads - 1) / n_threads; + dim3 threads(n_threads); + dim3 grid(blocks); + hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix); + KeMatrixMultiBinaryCrossEntropyBp<<< grid, threads, 0, STREAM_DEFAULT >>> + (output, grad, mat->csr_row, mat->csr_col, dimM, dimN); + CHECK_SYNC("hl_matrix_multi_binary_cross_entropy_bp failed"); +} + __global__ void KeMatrixCrossEntropy(real* O, real* E, int* label, @@ -685,7 +765,7 @@ __global__ void KeMatrixAddSharedBias(real* A, int dim = N / channel; if (index < M * N) { int i = index % N; - i = i / dim; + i = i / dim; A[index] += scale * B[i]; } } @@ -713,7 +793,7 @@ __global__ void KeMatrixCollectSharedBias(real *B, const int dim, const int limit, real scale) { - if (dim < limit) { + if (dim < limit) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < channel) { real sum = 0.0; diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp index 949788be497874a5bb34e49e11bdc8ba3205ba61..2f85dd3c3b69d21cffede49b001298c6629900a6 100644 --- a/paddle/gserver/layers/CostLayer.cpp +++ b/paddle/gserver/layers/CostLayer.cpp @@ -462,25 +462,43 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap, void 
diff --git a/paddle/gserver/layers/CostLayer.cpp b/paddle/gserver/layers/CostLayer.cpp
index 949788be497874a5bb34e49e11bdc8ba3205ba61..2f85dd3c3b69d21cffede49b001298c6629900a6 100644
--- a/paddle/gserver/layers/CostLayer.cpp
+++ b/paddle/gserver/layers/CostLayer.cpp
@@ -462,25 +462,43 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
 
 void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
                                               Matrix& target) {
-  if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
-      dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
-    target.multiBinaryLabelCrossEntropy(output, *label.value);
+  MatrixPtr value = nullptr;
+  if (label.ids) {
+    CHECK(!label.value);
+    value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
+  } else {
+    CHECK(label.value);
+    value = label.value;
+  }
+
+  if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
+      dynamic_cast<GpuSparseMatrix*>(value.get())) {
+    target.multiBinaryLabelCrossEntropy(output, *value);
   } else {
     Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(),
                            false, useGpu_);
-    targetPerDim_->binaryLabelCrossEntropy(output, *label.value);
+    targetPerDim_->binaryLabelCrossEntropy(output, *value);
     targetPerDim_->rowSum(target);
   }
 }
 
 void MultiBinaryLabelCrossEntropy::backwardImp(
     Matrix& output, Argument& label, Matrix& outputG) {
-  if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
-      dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
-    outputG.multiBinaryLabelCrossEntropyBp(output, *label.value);
+  MatrixPtr value = nullptr;
+  if (label.ids) {
+    CHECK(!label.value);
+    value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
+  } else {
+    CHECK(label.value);
+    value = label.value;
+  }
+
+  if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
+      dynamic_cast<GpuSparseMatrix*>(value.get())) {
+    outputG.multiBinaryLabelCrossEntropyBp(output, *value);
   } else {
-    outputG.binaryLabelCrossEntropyBp(output, *label.value);
+    outputG.binaryLabelCrossEntropyBp(output, *value);
   }
 }
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index e7e07e9e69dc7a1b51211364dad7043bdcbaf4c3..f3cd2b4faf0c173cbb4997aac1a00ebba3027c92 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -528,7 +528,7 @@ TEST(Layer, multi_cross) {
   }
 }
 
-TEST(Layer, multi_binary_label) {
+TEST(Layer, multi_binary_label_sparse_mat) {
   TestConfig config;
   config.layerConfig.set_type("multi_binary_label_cross_entropy");
   config.biasSize = 0;
@@ -538,9 +538,26 @@
   config.layerConfig.add_inputs();
   config.layerConfig.add_inputs();
 
-  // Not support GPU now
-  testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
-                /* trans */ false, /* useGpu */ false);
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
+                  /* trans */ false, useGpu);
+  }
+}
+
+TEST(Layer, multi_binary_label_id) {
+  TestConfig config;
+  config.layerConfig.set_type("multi_binary_label_cross_entropy");
+  config.biasSize = 0;
+
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 50, 0});
+  config.inputDefs.push_back({INPUT_LABEL, "layer_1", 10, 0});
+  config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
+                  /* trans */ false, useGpu);
+  }
 }
 
 TEST(Layer, multi_cross_with_selfnorm) {
diff --git a/paddle/math/CpuSparseMatrix.cpp b/paddle/math/CpuSparseMatrix.cpp
index 842efdbe3d77ec3443374f62df5c520252aa7ce4..64ee124a5613a99ac3d7ff36897e4f2d0489ad51 100644
--- a/paddle/math/CpuSparseMatrix.cpp
+++ b/paddle/math/CpuSparseMatrix.cpp
@@ -409,9 +409,6 @@ void CpuSparseMatrix::setRow(size_t row, size_t colNum,
   if (format_ == SPARSE_CSR) {
     CHECK_LT(row, height_);
     CHECK(NULL != cols);
-    for (size_t i = row; i < height_; i++) {
-      CHECK_EQ(rows_[i + 1], rows_[i]);
-    }
     if (0 == row) {
       rows_[row] = 0;
     }
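The loop deleted from setRow asserted that every row at or after `row` was still empty, i.e. that rows are filled strictly in ascending order. That made each setRow call O(height), so filling an M-row one-hot label matrix row by row cost O(M^2); dropping the scan makes the conversion linear, while the fill-in-order contract still holds. A sketch of the calling pattern (the helper name is mine; it uses only interfaces that appear elsewhere in this patch):

```cpp
#include <vector>

// Build a height x width one-hot CSR label matrix, one positive column
// per row, setting rows in ascending order as setRow requires.
MatrixPtr makeOneHotLabels(const std::vector<unsigned int>& ids, size_t width) {
  size_t height = ids.size();
  MatrixPtr mat = Matrix::createSparseMatrix(
      height, width, /*nnz=*/height, NO_VALUE, SPARSE_CSR,
      /*trans=*/false, /*useGpu=*/false);
  for (size_t i = 0; i < height; ++i) {
    mat->setRow(i, 1, &ids[i], nullptr);  // NO_VALUE format: no values array
  }
  return mat;
}
```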
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 950c3bb6cca28ad4e9c10bc984898c9d643478c4..5ee8fbebfcfbe9696f7836b6b1c88e724551da8e 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -1268,6 +1268,42 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
   }
 }
 
+void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
+  GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
+  auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
+  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == outputPtr->height_ && width_ == 1 &&
+        outputPtr->width_ == labelPtr->getWidth() &&
+        outputPtr->height_ == labelPtr->getHeight())
+      << "Matrix dimensions are not equal";
+
+  real* output_d = outputPtr->data_;
+  real* entropy_d = data_;
+  hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
+  hl_matrix_multi_binary_cross_entropy(
+      output_d, entropy_d, mat_d, height_, outputPtr->width_);
+}
+
+void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
+  GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
+  auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
+
+  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
+  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == outputPtr->height_ && width_ == outputPtr->width_ &&
+        outputPtr->width_ == labelPtr->getWidth() &&
+        outputPtr->height_ == labelPtr->getHeight())
+      << "Matrix dimensions are not equal";
+
+  real* output_d = outputPtr->data_;
+  real* grad_d = data_;
+  hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
+  hl_matrix_multi_binary_cross_entropy_bp(
+      output_d, grad_d, mat_d, height_, width_);
+}
+
 /**
  * CpuMatrix
  */
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 700be7590240c357fffcd90992fc50e4e80d9137..6c3c4804d2fc67a3378c61e8b8ff1e3c0087dd83 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -1303,6 +1303,10 @@ public:
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW);
+
+  void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
+
+  void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
 };
 
 class CpuMatrix : public Matrix {
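Putting the two new GpuMatrix methods together: the receiver of the forward call is the per-sample loss (M x 1) and the receiver of the backward call is the gradient (M x N). Both kernels accumulate into their receiver, so it must be zeroed first, as the tests below do. A minimal usage sketch, with function and variable names assumed for illustration:

```cpp
// output: M x N GPU activations; label: M x N CSR GpuSparseMatrix.
void lossAndGradSketch(GpuMatrix& output, GpuSparseMatrix& label,
                       size_t numSamples, size_t numClasses) {
  GpuMatrix loss(numSamples, 1);
  loss.zeroMem();  // forward kernel accumulates: entropy[i] -= ...
  loss.multiBinaryLabelCrossEntropy(output, label);

  GpuMatrix grad(numSamples, numClasses);
  grad.zeroMem();  // backward kernel accumulates: grad[idx] += ...
  grad.multiBinaryLabelCrossEntropyBp(output, label);
}
```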
*/ #include "paddle/utils/ThreadLocal.h" #include "paddle/utils/Thread.h" #include "paddle/utils/Flags.h" +#include "Matrix.h" #include "hl_gpu.h" #include "hl_table_apply.h" @@ -73,6 +74,31 @@ std::shared_ptr> VectorT::create(size_t size, } } +template <> +MatrixPtr VectorT::toOneHotSparseMatrix(size_t idRange, bool useGpu) { + LOG(FATAL) << "Wrong for real vector"; + return nullptr; +} + +template <> +MatrixPtr VectorT::toOneHotSparseMatrix(size_t idRange, bool useGpu) { + int height = getSize(); + int width = idRange; + MatrixPtr mat = Matrix::createSparseMatrix( + height, idRange, height, NO_VALUE, SPARSE_CSR, false, useGpu); + + CpuIVector cpuIds(height); + cpuIds.copyFrom(*this); + int *idData = cpuIds.getData(); + + for (int i = 0; i < height; i ++) { + const unsigned int id = idData[i]; + CHECK_LT(id, width); + mat->setRow(i, 1, &id, nullptr); + } + return mat; +} + template GpuVectorT::GpuVectorT(size_t size) : VectorT(size, std::make_shared(sizeof(T) * size), diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h index ee0a83bf038f04ee9f7b3561639aa90da68a6e29..faf8186b6d10d7cbc14376ff3b6543d1303b2ab1 100644 --- a/paddle/math/Vector.h +++ b/paddle/math/Vector.h @@ -37,6 +37,8 @@ class BaseVector; class SyncThreadPool; +class Matrix; + template class BaseVector : public BaseMatrixT { public: @@ -155,6 +157,12 @@ public: subVecFrom(src, interval.first, interval.second - interval.first); } + /** + * convert the vector to a sparse one_hot matrix of width idRange + * only applies to IVector + */ + std::shared_ptr toOneHotSparseMatrix(size_t idRange, bool useGpu); + /** * This function will crash if the size of src and dest is different. */ diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index b3ee4bc34995aa20d069a8a867658dfa498a031c..9c03695ba5055c4bdb3e7c578d3e352fbd6fae6f 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -2208,7 +2208,6 @@ void testCollectSharedBias(int numSamples, int dim, int channel) { MatrixCheckErr(*cpuBias, *check); } - TEST(Matrix, sharedBias) { for (auto numSamples : {1, 100, 520}) { for (auto dim : {100 * 16, 100 * 32}) { @@ -2222,6 +2221,60 @@ TEST(Matrix, sharedBias) { } } +void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) { + MatrixPtr output = std::make_shared(numSamples, dim); + MatrixPtr cpuOutput = std::make_shared(numSamples, dim); + MatrixPtr gpuOutput = std::make_shared(numSamples, dim); + + MatrixPtr cpuEntropy = std::make_shared(numSamples, 1); + MatrixPtr gpuEntropy = std::make_shared(numSamples, 1); + + MatrixPtr cpuGrad = std::make_shared(numSamples, dim); + MatrixPtr gpuGrad = std::make_shared(numSamples, dim); + + MatrixPtr cpuLabel = std::make_shared + (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false); + MatrixPtr gpuLabel = std::make_shared + (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false); + for (int i = 0; i < numSamples; i ++) { + const unsigned int id = rand() % dim; // NOLINT + cpuLabel->setRow(i, 1, &id, nullptr); + gpuLabel->setRow(i, 1, &id, nullptr); + } + + output->randomizeUniform(); + cpuOutput->zeroMem(); + output->softmax(*cpuOutput); + gpuOutput->copyFrom(*cpuOutput); + + cpuEntropy->zeroMem(); + gpuEntropy->zeroMem(); + cpuEntropy->multiBinaryLabelCrossEntropy(*cpuOutput, *cpuLabel); + gpuEntropy->multiBinaryLabelCrossEntropy(*gpuOutput, *gpuLabel); + + MatrixPtr check1 = std::make_shared(numSamples, 1); + check1->copyFrom(*gpuEntropy); + MatrixCheckErr(*cpuEntropy, *check1); + + 
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index b3ee4bc34995aa20d069a8a867658dfa498a031c..9c03695ba5055c4bdb3e7c578d3e352fbd6fae6f 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -2208,7 +2208,6 @@ void testCollectSharedBias(int numSamples, int dim, int channel) {
   MatrixCheckErr(*cpuBias, *check);
 }
 
-
 TEST(Matrix, sharedBias) {
   for (auto numSamples : {1, 100, 520}) {
     for (auto dim : {100 * 16, 100 * 32}) {
@@ -2222,6 +2221,60 @@
   }
 }
 
+void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
+  MatrixPtr output = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr gpuOutput = std::make_shared<GpuMatrix>(numSamples, dim);
+
+  MatrixPtr cpuEntropy = std::make_shared<CpuMatrix>(numSamples, 1);
+  MatrixPtr gpuEntropy = std::make_shared<GpuMatrix>(numSamples, 1);
+
+  MatrixPtr cpuGrad = std::make_shared<CpuMatrix>(numSamples, dim);
+  MatrixPtr gpuGrad = std::make_shared<GpuMatrix>(numSamples, dim);
+
+  MatrixPtr cpuLabel = std::make_shared<CpuSparseMatrix>
+      (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+  MatrixPtr gpuLabel = std::make_shared<GpuSparseMatrix>
+      (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+  for (int i = 0; i < numSamples; i++) {
+    const unsigned int id = rand() % dim;  // NOLINT
+    cpuLabel->setRow(i, 1, &id, nullptr);
+    gpuLabel->setRow(i, 1, &id, nullptr);
+  }
+
+  output->randomizeUniform();
+  cpuOutput->zeroMem();
+  output->softmax(*cpuOutput);
+  gpuOutput->copyFrom(*cpuOutput);
+
+  cpuEntropy->zeroMem();
+  gpuEntropy->zeroMem();
+  cpuEntropy->multiBinaryLabelCrossEntropy(*cpuOutput, *cpuLabel);
+  gpuEntropy->multiBinaryLabelCrossEntropy(*gpuOutput, *gpuLabel);
+
+  MatrixPtr check1 = std::make_shared<CpuMatrix>(numSamples, 1);
+  check1->copyFrom(*gpuEntropy);
+  MatrixCheckErr(*cpuEntropy, *check1);
+
+  cpuGrad->zeroMem();
+  gpuGrad->zeroMem();
+  cpuGrad->multiBinaryLabelCrossEntropyBp(*cpuOutput, *cpuLabel);
+  gpuGrad->multiBinaryLabelCrossEntropyBp(*gpuOutput, *gpuLabel);
+
+  MatrixPtr check2 = std::make_shared<CpuMatrix>(numSamples, dim);
+  check2->copyFrom(*gpuGrad);
+  MatrixCheckErr(*cpuGrad, *check2);
+}
+
+TEST(Matrix, multiBinaryCrossEntropy) {
+  for (auto numSamples : {100, 1000, 10000}) {
+    for (auto dim : {100, 1000, 10000}) {
+      VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
+      testMultiBinaryLabelCrossEntropy(numSamples, dim);
+    }
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);