Commit 19d81466 authored by emailweixu, committed by GitHub

Merge pull request #358 from yu239/multi_binary_cross_entropy

multi_binary_cross_entropy when an ids vector is provided
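For orientation (this summary is ours, inferred from the kernels added below): with an M x N prediction matrix o and a sparse multi-hot label set L_i per row i, the new forward and backward passes compute

```latex
% Loss for row i (L_i = set of positive label columns, o = predictions):
E_i \;=\; -\sum_{j \in L_i} \log o_{ij} \;-\; \sum_{j \notin L_i} \log\bigl(1 - o_{ij}\bigr)

% Gradient accumulated into grad (M x N):
\frac{\partial E_i}{\partial o_{ij}} \;=\;
\begin{cases}
  -\,1/o_{ij} & j \in L_i \\
  \phantom{-}1/(1 - o_{ij}) & j \notin L_i
\end{cases}
```

The kernels first evaluate the dense row sum as if every label were 0, then correct only the CSR nonzero columns (via log(o/(1-o)) in the forward pass and 1/(o(1-o)) in the backward pass), so per-row work is O(N + nnz) and no dense label matrix is needed.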
......@@ -126,6 +126,36 @@ extern void hl_matrix_cross_entropy_bp(real* grad_d,
int dimM,
int dimN);
/**
* @brief Matrix multi-binary label cross entropy
*
* @param[in] output input matrix (M x N).
* @param[out] entropy output matrix (M x 1).
* @param[in] mat input sparse matrix.
* @param[in] dimM matrix height.
* @param[in] dimN matrix width.
*/
extern void hl_matrix_multi_binary_cross_entropy(real* output,
real* entropy,
hl_sparse_matrix_s mat,
int dimM,
int dimN);
/**
* @brief Matrix multi-binary label cross entropy backprop
*
* @param[in] output input matrix (M x N).
* @param[out] grad output matrix (M x N).
* @param[in] mat input sparse matrix.
* @param[in] dimM matrix height.
* @param[in] dimN matrix width.
*/
extern void hl_matrix_multi_binary_cross_entropy_bp(real* output,
real* grad,
hl_sparse_matrix_s mat,
int dimM,
int dimN);
/**
* @brief Matrix zero memory.
*
......
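A minimal call-site sketch for the two new declarations in this header (the wrapper function and its argument names are ours; allocation and CSR construction are elided):

```cpp
// Sketch only: shows the calling convention, assuming hl_matrix.h is included.
// output_d  - device pointer, (M x N) predictions in (0, 1)
// entropy_d - device pointer, (M x 1) per-sample loss, zero-initialized
// grad_d    - device pointer, (M x N) gradient accumulator
// labels    - CSR sparse multi-hot label matrix
void runMultiBinaryXent(real* output_d, real* entropy_d, real* grad_d,
                        hl_sparse_matrix_s labels, int M, int N) {
  hl_matrix_multi_binary_cross_entropy(output_d, entropy_d, labels, M, N);
  hl_matrix_multi_binary_cross_entropy_bp(output_d, grad_d, labels, M, N);
}
```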
......@@ -57,6 +57,18 @@ inline void hl_matrix_cross_entropy_bp(real* grad_d,
int dimM,
int dimN) {}
inline void hl_matrix_multi_binary_cross_entropy(real* output,
real* entropy,
hl_sparse_matrix_s mat,
int dimM,
int dimN) {}
inline void hl_matrix_multi_binary_cross_entropy_bp(real* output,
real* grad,
hl_sparse_matrix_s mat,
int dimM,
int dimN) {}
inline void hl_matrix_zero_mem(real* data, int num) {}
inline void hl_param_relu_forward(real* output,
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "hl_matrix_ops.cuh"
#include "hl_matrix_apply.cuh"
#include "hl_sequence.h"
#include "hl_sparse.ph"
#include "paddle/utils/Logging.h"
#include "hl_device_functions.cuh"
#include "hl_gpu_matrix_kernel.cuh"
......@@ -317,6 +318,85 @@ void hl_matrix_classification_error(real* A_d,
CHECK_SYNC("hl_matrix_classification_error");
}
__global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
real* entropy,
int* row,
int* col,
int dimM,
int dimN) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < dimM) {
for (int i = 0; i < dimN; i ++) {
entropy[index] -= log(1 - output[index * dimN + i]);
}
int *row_col = col + row[index];
int col_num = row[index + 1] - row[index];
for (int i = 0; i < col_num; i ++) {
real o = output[index * dimN + row_col[i]];
entropy[index] -= log(o / (1 - o));
}
}
}
void hl_matrix_multi_binary_cross_entropy(real* output,
real* entropy,
hl_sparse_matrix_s csr_mat,
int dimM,
int dimN) {
CHECK_NOTNULL(output);
CHECK_NOTNULL(entropy);
CHECK_NOTNULL(csr_mat);
CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);
int n_threads = 1024;
int blocks = (dimM + n_threads - 1) / n_threads;
dim3 threads(n_threads);
dim3 grid(blocks);
hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix);
KeMatrixMultiBinaryCrossEntropy<<< grid, threads, 0, STREAM_DEFAULT >>>
(output, entropy, mat->csr_row, mat->csr_col, dimM, dimN);
CHECK_SYNC("hl_matrix_multi_binary_cross_entropy failed");
}
__global__ void KeMatrixMultiBinaryCrossEntropyBp(real* output,
real* grad,
int* row,
int* col,
int dimM,
int dimN) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx < dimM) {
for (int i = 0; i < dimN; i ++) {
int index = row_idx * dimN + i;
grad[index] += 1.0 / (1 - output[index]);
}
int col_num = row[row_idx + 1] - row[row_idx];
int *row_col = col + row[row_idx];
for (int i = 0; i < col_num; i ++) {
int index = row_idx * dimN + row_col[i];
grad[index] -= 1.0 / (output[index] * (1 - output[index]));
}
}
}
void hl_matrix_multi_binary_cross_entropy_bp(real* output,
real* grad,
hl_sparse_matrix_s csr_mat,
int dimM,
int dimN) {
CHECK_NOTNULL(output);
CHECK_NOTNULL(grad);
CHECK_NOTNULL(csr_mat);
CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);
int n_threads = 1024;
int blocks = (dimM + n_threads - 1) / n_threads;
dim3 threads(n_threads);
dim3 grid(blocks);
hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix);
KeMatrixMultiBinaryCrossEntropyBp<<< grid, threads, 0, STREAM_DEFAULT >>>
(output, grad, mat->csr_row, mat->csr_col, dimM, dimN);
CHECK_SYNC("hl_matrix_multi_binary_cross_entropy_bp failed");
}
__global__ void KeMatrixCrossEntropy(real* O,
real* E,
int* label,
......@@ -685,7 +765,7 @@ __global__ void KeMatrixAddSharedBias(real* A,
int dim = N / channel;
if (index < M * N) {
int i = index % N;
i = i / dim;
A[index] += scale * B[i];
}
}
......@@ -713,7 +793,7 @@ __global__ void KeMatrixCollectSharedBias(real *B,
const int dim,
const int limit,
real scale) {
if (dim < limit) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < channel) {
real sum = 0.0;
......
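As a cross-check on the two kernels added above, a host-side reference in plain C++ (this helper is ours, not part of the patch; `row`/`col` are the CSR row offsets and column indices of the label matrix):

```cpp
#include <cmath>

typedef float real;  // matches the single-precision build assumed here

void multiBinaryCrossEntropyRef(const real* output, real* entropy,
                                const int* row, const int* col,
                                int dimM, int dimN) {
  for (int r = 0; r < dimM; ++r) {
    // Pass 1: treat every label as 0, accumulating -log(1 - o) over all columns.
    for (int j = 0; j < dimN; ++j) {
      entropy[r] -= std::log(1 - output[r * dimN + j]);
    }
    // Pass 2: for columns whose label is actually 1, subtract log(o / (1 - o)),
    // which replaces the -log(1 - o) term above with -log(o).
    for (int k = row[r]; k < row[r + 1]; ++k) {
      real o = output[r * dimN + col[k]];
      entropy[r] -= std::log(o / (1 - o));
    }
  }
}
```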
......@@ -462,25 +462,43 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
Matrix& target) {
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
target.multiBinaryLabelCrossEntropy(output, *label.value);
MatrixPtr value = nullptr;
if (label.ids) {
CHECK(!label.value);
value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
} else {
CHECK(label.value);
value = label.value;
}
if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
dynamic_cast<GpuSparseMatrix*>(value.get())) {
target.multiBinaryLabelCrossEntropy(output, *value);
} else {
Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(),
false, useGpu_);
targetPerDim_->binaryLabelCrossEntropy(output, *label.value);
targetPerDim_->binaryLabelCrossEntropy(output, *value);
targetPerDim_->rowSum(target);
}
}
void MultiBinaryLabelCrossEntropy::backwardImp(
Matrix& output, Argument& label, Matrix& outputG) {
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
outputG.multiBinaryLabelCrossEntropyBp(output, *label.value);
MatrixPtr value = nullptr;
if (label.ids) {
CHECK(!label.value);
value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
} else {
CHECK(label.value);
value = label.value;
}
if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
dynamic_cast<GpuSparseMatrix*>(value.get())) {
outputG.multiBinaryLabelCrossEntropyBp(output, *value);
} else {
outputG.binaryLabelCrossEntropyBp(output, *label.value);
outputG.binaryLabelCrossEntropyBp(output, *value);
}
}
......
......@@ -528,7 +528,7 @@ TEST(Layer, multi_cross) {
}
}
TEST(Layer, multi_binary_label) {
TEST(Layer, multi_binary_label_sparse_mat) {
TestConfig config;
config.layerConfig.set_type("multi_binary_label_cross_entropy");
config.biasSize = 0;
......@@ -538,9 +538,26 @@ TEST(Layer, multi_binary_label) {
config.layerConfig.add_inputs();
config.layerConfig.add_inputs();
// Not support GPU now
testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
/* trans */ false, /* useGpu */ false);
for (auto useGpu : {false, true}) {
testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
/* trans */ false, useGpu);
}
}
TEST(Layer, multi_binary_label_id) {
TestConfig config;
config.layerConfig.set_type("multi_binary_label_cross_entropy");
config.biasSize = 0;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 50, 0});
config.inputDefs.push_back({INPUT_LABEL, "layer_1", 10, 0});
config.layerConfig.add_inputs();
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
/* trans */ false, useGpu);
}
}
TEST(Layer, multi_cross_with_selfnorm) {
......
......@@ -409,9 +409,6 @@ void CpuSparseMatrix::setRow(size_t row, size_t colNum,
if (format_ == SPARSE_CSR) {
CHECK_LT(row, height_);
CHECK(NULL != cols);
for (size_t i = row; i < height_; i++) {
CHECK_EQ(rows_[i + 1], rows_[i]);
}
if (0 == row) {
rows_[row] = 0;
}
......
......@@ -1268,6 +1268,42 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
}
}
void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
CHECK(height_ == outputPtr->height_ && width_ == 1
&& outputPtr->width_ == labelPtr->getWidth()
&& outputPtr->height_ == labelPtr->getHeight())
<< "Matrix dimensions are not equal";
real* output_d = outputPtr->data_;
real* entropy_d = data_;
hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
hl_matrix_multi_binary_cross_entropy(
output_d, entropy_d, mat_d, height_, outputPtr->width_);
}
void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix &output, Matrix &label) {
GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
CHECK(height_ == outputPtr->height_ && width_ == outputPtr->width_
&& outputPtr->width_ == labelPtr->getWidth()
&& outputPtr->height_ == labelPtr->getHeight())
<< "Matrix dimensions are not equal";
real* output_d = outputPtr->data_;
real* grad_d = data_;
hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
hl_matrix_multi_binary_cross_entropy_bp(
output_d, grad_d, mat_d, height_, width_);
}
/**
* CpuMatrix
*/
......
......@@ -1303,6 +1303,10 @@ public:
const size_t numChannels,
const real ratioH,
const real ratioW);
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
};
class CpuMatrix : public Matrix {
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/Thread.h"
#include "paddle/utils/Flags.h"
#include "Matrix.h"
#include "hl_gpu.h"
#include "hl_table_apply.h"
......@@ -73,6 +74,31 @@ std::shared_ptr<VectorT<T>> VectorT<T>::create(size_t size,
}
}
template <>
MatrixPtr VectorT<real>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
LOG(FATAL) << "toOneHotSparseMatrix is not supported for real-valued vectors";
return nullptr;
}
template <>
MatrixPtr VectorT<int>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
int height = getSize();
int width = idRange;
MatrixPtr mat = Matrix::createSparseMatrix(
height, idRange, height, NO_VALUE, SPARSE_CSR, false, useGpu);
CpuIVector cpuIds(height);
cpuIds.copyFrom(*this);
int *idData = cpuIds.getData();
for (int i = 0; i < height; i ++) {
const unsigned int id = idData[i];
CHECK_LT(id, width);
mat->setRow(i, 1, &id, nullptr);
}
return mat;
}
template <class T>
GpuVectorT<T>::GpuVectorT(size_t size)
: VectorT<T>(size, std::make_shared<GpuMemoryHandle>(sizeof(T) * size),
......
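A minimal usage sketch of the new helper (the values and surrounding code are ours, for illustration only):

```cpp
// Expand a 4-element id vector into a 4 x 10 one-hot CSR matrix.
CpuIVector ids(4);
int* d = ids.getData();
d[0] = 3; d[1] = 0; d[2] = 7; d[3] = 9;

MatrixPtr oneHot = ids.toOneHotSparseMatrix(/* idRange */ 10,
                                            /* useGpu */ false);
// oneHot is a NO_VALUE SPARSE_CSR matrix with one nonzero per row,
// directly usable as the label argument of multiBinaryLabelCrossEntropy.
```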
......@@ -37,6 +37,8 @@ class BaseVector;
class SyncThreadPool;
class Matrix;
template<class T>
class BaseVector : public BaseMatrixT<T> {
public:
......@@ -155,6 +157,12 @@ public:
subVecFrom(src, interval.first, interval.second - interval.first);
}
/**
* convert the vector to a sparse one_hot matrix of width idRange
* only applies to IVector
*/
std::shared_ptr<Matrix> toOneHotSparseMatrix(size_t idRange, bool useGpu);
/**
* This function will crash if the size of src and dest is different.
*/
......
......@@ -2208,7 +2208,6 @@ void testCollectSharedBias(int numSamples, int dim, int channel) {
MatrixCheckErr(*cpuBias, *check);
}
TEST(Matrix, sharedBias) {
for (auto numSamples : {1, 100, 520}) {
for (auto dim : {100 * 16, 100 * 32}) {
......@@ -2222,6 +2221,60 @@ TEST(Matrix, sharedBias) {
}
}
void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
MatrixPtr output = std::make_shared<CpuMatrix>(numSamples, dim);
MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
MatrixPtr gpuOutput = std::make_shared<GpuMatrix>(numSamples, dim);
MatrixPtr cpuEntropy = std::make_shared<CpuMatrix>(numSamples, 1);
MatrixPtr gpuEntropy = std::make_shared<GpuMatrix>(numSamples, 1);
MatrixPtr cpuGrad = std::make_shared<CpuMatrix>(numSamples, dim);
MatrixPtr gpuGrad = std::make_shared<GpuMatrix>(numSamples, dim);
MatrixPtr cpuLabel = std::make_shared<CpuSparseMatrix>
(numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
MatrixPtr gpuLabel = std::make_shared<GpuSparseMatrix>
(numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
for (int i = 0; i < numSamples; i ++) {
const unsigned int id = rand() % dim; // NOLINT
cpuLabel->setRow(i, 1, &id, nullptr);
gpuLabel->setRow(i, 1, &id, nullptr);
}
output->randomizeUniform();
cpuOutput->zeroMem();
output->softmax(*cpuOutput);
gpuOutput->copyFrom(*cpuOutput);
cpuEntropy->zeroMem();
gpuEntropy->zeroMem();
cpuEntropy->multiBinaryLabelCrossEntropy(*cpuOutput, *cpuLabel);
gpuEntropy->multiBinaryLabelCrossEntropy(*gpuOutput, *gpuLabel);
MatrixPtr check1 = std::make_shared<CpuMatrix>(numSamples, 1);
check1->copyFrom(*gpuEntropy);
MatrixCheckErr(*cpuEntropy, *check1);
cpuGrad->zeroMem();
gpuGrad->zeroMem();
cpuGrad->multiBinaryLabelCrossEntropyBp(*cpuOutput, *cpuLabel);
gpuGrad->multiBinaryLabelCrossEntropyBp(*gpuOutput, *gpuLabel);
MatrixPtr check2 = std::make_shared<CpuMatrix>(numSamples, dim);
check2->copyFrom(*gpuGrad);
MatrixCheckErr(*cpuGrad, *check2);
}
TEST(Matrix, multiBinaryCrossEntropy) {
for (auto numSamples : {100, 1000, 10000}) {
for (auto dim : {100, 1000, 10000}) {
VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
testMultiBinaryLabelCrossEntropy(numSamples, dim);
}
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......