Commit 5591292b authored by H Haonan

Modifications according to review comments

Parent 728defbe
@@ -346,6 +346,7 @@ void hl_matrix_multi_binary_cross_entropy(real* output,
   CHECK_NOTNULL(output);
   CHECK_NOTNULL(entropy);
   CHECK_NOTNULL(csr_mat);
+  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);
   int n_threads = 1024;
   int blocks = (dimM + n_threads - 1) / n_threads;
   dim3 threads(n_threads);
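Review note: both launches in this file size the grid with a ceiling division so every one of the dimM rows gets a thread even when dimM is not a multiple of n_threads. A minimal standalone sketch of that arithmetic (ceilDiv is an illustrative helper, not part of the patch):

```cpp
#include <cassert>

// Ceiling division: the smallest block count whose threads cover all rows.
static inline int ceilDiv(int rows, int threadsPerBlock) {
  return (rows + threadsPerBlock - 1) / threadsPerBlock;
}

int main() {
  assert(ceilDiv(1024, 1024) == 1);  // exact fit: one block
  assert(ceilDiv(1025, 1024) == 2);  // one leftover row forces a second block
  assert(ceilDiv(1, 1024) == 1);     // never zero blocks for nonempty input
  return 0;
}
```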
@@ -385,6 +386,7 @@ void hl_matrix_multi_binary_cross_entropy_bp(real* output,
   CHECK_NOTNULL(output);
   CHECK_NOTNULL(grad);
   CHECK_NOTNULL(csr_mat);
+  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);
   int n_threads = 1024;
   int blocks = (dimM + n_threads - 1) / n_threads;
   dim3 threads(n_threads);
@@ -763,7 +765,7 @@ __global__ void KeMatrixAddSharedBias(real* A,
   int dim = N / channel;
   if (index < M * N) {
     int i = index % N;
-      i = i / dim;
+    i = i / dim;
     A[index] += scale * B[i];
   }
 }
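Review note: the whitespace fix above sits in the index arithmetic that maps a flat element index to a shared-bias channel. With N columns split into `channel` groups of width dim = N / channel, element `index` of a row-major M x N matrix reads its bias from channel (index % N) / dim. A small sketch of just that mapping (assumes channel divides N evenly, as the kernel does):

```cpp
#include <cstdio>

int main() {
  const int N = 8;              // columns per row
  const int channel = 4;        // number of shared biases
  const int dim = N / channel;  // columns covered by each bias
  for (int index = 0; index < N; ++index) {
    int i = index % N;  // column within the row
    i = i / dim;        // channel that owns this column
    printf("column %d -> bias channel %d\n", index, i);
  }
  return 0;
}
```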
@@ -791,7 +793,7 @@ __global__ void KeMatrixCollectSharedBias(real *B,
                                           const int dim,
                                           const int limit,
                                           real scale) {
-    if (dim < limit) {
+  if (dim < limit) {
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     if (index < channel) {
       real sum = 0.0;
......
@@ -465,10 +465,7 @@ void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
   MatrixPtr value = nullptr;
   if (label.ids) {
     CHECK(!label.value);
-    value = Matrix::createSparseMatrix(
-        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
-        NO_VALUE, SPARSE_CSR, false, useGpu_);
-    label.idsToSparseMatrix(value);
+    value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
   } else {
     CHECK(label.value);
     value = label.value;
@@ -491,10 +488,7 @@ void MultiBinaryLabelCrossEntropy::backwardImp(
   MatrixPtr value = nullptr;
   if (label.ids) {
     CHECK(!value);
-    value = Matrix::createSparseMatrix(
-        label.ids->getSize(), output.getWidth(), label.ids->getSize(),
-        NO_VALUE, SPARSE_CSR, false, useGpu_);
-    label.idsToSparseMatrix(value);
+    value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
   } else {
     CHECK(label.value);
     value = label.value;
......
@@ -528,7 +528,7 @@ TEST(Layer, multi_cross) {
   }
 }
 
-TEST(Layer, multi_binary_label) {
+TEST(Layer, multi_binary_label_sparse_mat) {
   TestConfig config;
   config.layerConfig.set_type("multi_binary_label_cross_entropy");
   config.biasSize = 0;
@@ -544,6 +544,22 @@ TEST(Layer, multi_binary_label) {
   }
 }
 
+TEST(layer, multi_binary_label_id) {
+  TestConfig config;
+  config.layerConfig.set_type("multi_binary_label_cross_entropy");
+  config.biasSize = 0;
+
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 50, 0});
+  config.inputDefs.push_back({INPUT_LABEL, "layer_1", 10, 0});
+  config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
+                  /* trans */ false, useGpu);
+  }
+}
+
 TEST(Layer, multi_cross_with_selfnorm) {
   TestConfig config;
   config.layerConfig.set_type("multi_class_cross_entropy_with_selfnorm");
......
@@ -409,9 +409,6 @@ void CpuSparseMatrix::setRow(size_t row, size_t colNum,
   if (format_ == SPARSE_CSR) {
     CHECK_LT(row, height_);
     CHECK(NULL != cols);
-    for (size_t i = row; i < height_; i++) {
-      CHECK_EQ(rows_[i + 1], rows_[i]);
-    }
     if (0 == row) {
       rows_[row] = 0;
     }
......
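Review note: the deleted loop re-verified, on every setRow call, that all row offsets after `row` were still untouched, which turns filling `height` rows into an O(height^2) scan. The CSR contract that remains is simply that rows are written in increasing order, with rows_[row + 1] holding the running entry count. A standalone sketch of that contract (CsrBuilder is illustrative, not PaddlePaddle code):

```cpp
#include <cassert>
#include <vector>

// Minimal CSR row filler mirroring the setRow contract: rows must be
// written in increasing order; rowOffsets[r + 1] is the running total.
struct CsrBuilder {
  std::vector<int> rowOffsets;     // size height + 1, rowOffsets[0] == 0
  std::vector<unsigned int> cols;  // column ids of stored entries
  explicit CsrBuilder(size_t height) : rowOffsets(height + 1, 0) {}

  void setRow(size_t row, size_t colNum, const unsigned int* colIds) {
    // In-order fill: everything before this row is already written.
    assert(rowOffsets[row] == static_cast<int>(cols.size()));
    cols.insert(cols.end(), colIds, colIds + colNum);
    rowOffsets[row + 1] = static_cast<int>(cols.size());
  }
};

int main() {
  CsrBuilder b(3);
  const unsigned int c0 = 2, c1 = 0, c2 = 1;
  b.setRow(0, 1, &c0);
  b.setRow(1, 1, &c1);
  b.setRow(2, 1, &c2);
  assert(b.rowOffsets[3] == 3);  // three one-hot rows, three entries
  return 0;
}
```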
@@ -1269,37 +1269,37 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
 }
 
 void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
-  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
-  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
-  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
-  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
-  CHECK(height_ == output_ptr->height_ && width_ == 1
-        && output_ptr->width_ == label_ptr->getWidth()
-        && output_ptr->height_ == label_ptr->getHeight())
+  GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
+  auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
+  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
+  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == outputPtr->height_ && width_ == 1
+        && outputPtr->width_ == labelPtr->getWidth()
+        && outputPtr->height_ == labelPtr->getHeight())
         << "Matrix dimensions are not equal";
 
-  real* output_d = output_ptr->data_;
+  real* output_d = outputPtr->data_;
   real* entropy_d = data_;
-  hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get();
+  hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
   hl_matrix_multi_binary_cross_entropy(
-      output_d, entropy_d, mat_d, height_, output_ptr->width_);
+      output_d, entropy_d, mat_d, height_, outputPtr->width_);
 }
 
 void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix &output, Matrix &label) {
-  GpuMatrix* output_ptr = dynamic_cast<GpuMatrix*>(&output);
-  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);
-  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
-  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
-  CHECK(height_ == output_ptr->height_ && width_ == output_ptr->width_
-        && output_ptr->width_ == label_ptr->getWidth()
-        && output_ptr->height_ == label_ptr->getHeight())
+  GpuMatrix* outputPtr = dynamic_cast<GpuMatrix*>(&output);
+  auto labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);
+  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
+  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
+  CHECK(height_ == outputPtr->height_ && width_ == outputPtr->width_
+        && outputPtr->width_ == labelPtr->getWidth()
+        && outputPtr->height_ == labelPtr->getHeight())
        << "Matrix dimensions are not equal";
 
-  real* output_d = output_ptr->data_;
+  real* output_d = outputPtr->data_;
   real* grad_d = data_;
-  hl_sparse_matrix_s mat_d = label_ptr->sMatrix_.get();
+  hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get();
   hl_matrix_multi_binary_cross_entropy_bp(
       output_d, grad_d, mat_d, height_, width_);
 }
......
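Review note: for reference, the quantity these renamed wrappers hand to the CUDA kernels is the per-sample multi-binary-label cross entropy, entropy_i = -sum_j [ y_ij * log(o_ij) + (1 - y_ij) * log(1 - o_ij) ], where row i of the one-hot CSR label supplies y. A CPU sketch of one row under that standard definition (illustrative only; the actual implementation lives in hl_matrix_multi_binary_cross_entropy):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const int width = 3;
  const double output[width] = {0.7, 0.2, 0.1};  // probabilities for one row
  const unsigned int labelCol = 0;               // the single positive class

  double entropy = 0.0;
  for (int j = 0; j < width; ++j) {
    const double y = (j == static_cast<int>(labelCol)) ? 1.0 : 0.0;
    // Binary cross entropy summed over every class of this row.
    entropy -= y * std::log(output[j]) + (1.0 - y) * std::log(1.0 - output[j]);
  }
  printf("entropy = %f\n", entropy);  // -log(0.7) - log(0.8) - log(0.9)
  return 0;
}
```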
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/utils/ThreadLocal.h"
 #include "paddle/utils/Thread.h"
 #include "paddle/utils/Flags.h"
+#include "Matrix.h"
 #include "hl_gpu.h"
 #include "hl_table_apply.h"
@@ -73,6 +74,31 @@ std::shared_ptr<VectorT<T>> VectorT<T>::create(size_t size,
   }
 }
 
+template <>
+MatrixPtr VectorT<real>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
+  LOG(FATAL) << "Wrong for real vector";
+  return nullptr;
+}
+
+template <>
+MatrixPtr VectorT<int>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
+  int height = getSize();
+  int width = idRange;
+  MatrixPtr mat = Matrix::createSparseMatrix(
+      height, idRange, height, NO_VALUE, SPARSE_CSR, false, useGpu);
+
+  CpuIVector cpuIds(height);
+  cpuIds.copyFrom(*this);
+  int *idData = cpuIds.getData();
+
+  for (int i = 0; i < height; i ++) {
+    const unsigned int id = idData[i];
+    CHECK_LT(id, width);
+    mat->setRow(i, 1, &id, nullptr);
+  }
+  return mat;
+}
+
 template <class T>
 GpuVectorT<T>::GpuVectorT(size_t size)
   : VectorT<T>(size, std::make_shared<GpuMemoryHandle>(sizeof(T) * size),
......
@@ -37,6 +37,8 @@ class BaseVector;
 class SyncThreadPool;
 
+class Matrix;
+
 template<class T>
 class BaseVector : public BaseMatrixT<T> {
 public:
@@ -155,6 +157,12 @@ public:
     subVecFrom(src, interval.first, interval.second - interval.first);
   }
 
+  /**
+   * convert the vector to a sparse one-hot matrix of width idRange
+   * (only applies to IVector)
+   */
+  std::shared_ptr<Matrix> toOneHotSparseMatrix(size_t idRange, bool useGpu);
+
   /**
    * This function will crash if the size of src and dest is different.
   */
......
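Review note: a hedged usage sketch of the new declaration, matching the specializations in Vector.cpp above — an IVector of class ids becomes a getSize() x idRange one-hot matrix in NO_VALUE SPARSE_CSR format, with row i holding a single entry at column ids[i]. The include path and IVectorPtr alias are assumptions about the surrounding codebase:

```cpp
#include "paddle/math/Vector.h"  // assumed path for IVector

using namespace paddle;

int main() {
  // Four samples, each labeled with one class id in [0, 10).
  IVectorPtr ids = IVector::create(/* size */ 4, /* useGpu */ false);
  for (size_t i = 0; i < 4; ++i) {
    ids->setElement(i, static_cast<int>(i) % 10);
  }
  // 4 x 10 one-hot CSR matrix: one NO_VALUE entry per row at column ids[i].
  auto oneHot = ids->toOneHotSparseMatrix(/* idRange */ 10, /* useGpu */ false);
  (void)oneHot;
  return 0;
}
```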
@@ -2232,26 +2232,15 @@ void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
   MatrixPtr cpuGrad = std::make_shared<CpuMatrix>(numSamples, dim);
   MatrixPtr gpuGrad = std::make_shared<GpuMatrix>(numSamples, dim);
 
-  auto cpuRows = IVector::create(numSamples + 1, false);
-  auto cpuCols = IVector::create(numSamples, false);
-  auto gpuRows = IVector::create(numSamples + 1, true);
-  auto gpuCols = IVector::create(numSamples, true);
-  cpuRows->setElement(0, 0);
-  gpuRows->setElement(0, 0);
-  for (int i = 0; i < numSamples; i ++) {
-    int id = rand() % dim;  // NOLINT
-    cpuRows->setElement(i + 1, i + 1);
-    gpuRows->setElement(i + 1, i + 1);
-    cpuCols->setElement(i, id);
-    gpuCols->setElement(i, id);
-  }
-
   MatrixPtr cpuLabel = std::make_shared<CpuSparseMatrix>
-          (nullptr, cpuRows->getData(), cpuCols->getData(),
-           numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+          (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
   MatrixPtr gpuLabel = std::make_shared<GpuSparseMatrix>
-          (nullptr, gpuRows->getData(), gpuCols->getData(),
-           numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+          (numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
+
+  for (int i = 0; i < numSamples; i ++) {
+    const unsigned int id = rand() % dim;  // NOLINT
+    cpuLabel->setRow(i, 1, &id, nullptr);
+    gpuLabel->setRow(i, 1, &id, nullptr);
+  }
 
   output->randomizeUniform();
   cpuOutput->zeroMem();
@@ -2278,8 +2267,8 @@ void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
 }
 
 TEST(Matrix, multiBinaryCrossEntropy) {
-  for (auto numSamples : {1, 100, 500}) {
-    for (auto dim : {1000, 10000, 100000}) {
+  for (auto numSamples : {100, 1000, 10000}) {
+    for (auto dim : {100, 1000, 10000}) {
       VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
       testMultiBinaryLabelCrossEntropy(numSamples, dim);
     }
......
@@ -572,42 +572,4 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
   }
 }
 
-void Argument::idsToSparseMatrix(MatrixPtr sparse_mat) {
-  int height = ids->getSize();
-  int width = sparse_mat->getWidth();
-
-  CpuIVector cpu_ids(height);
-  cpu_ids.copyFrom(*ids);
-  int *id_data = cpu_ids.getData();
-
-  int *rows = nullptr;
-  int *cols = nullptr;
-  if (sparse_mat->useGpu()) {
-    auto gpu_sparse_mat =
-            dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
-    rows = gpu_sparse_mat->rows_;
-    cols = gpu_sparse_mat->cols_;
-  } else {
-    rows = sparse_mat->getRows();
-    cols = sparse_mat->getCols();
-  }
-
-  rows[0] = 0;
-  for (int i = 0; i < height; i ++) {
-    int id = id_data[i];
-    CHECK_LT(id, width);
-    rows[i + 1] = i + 1;
-    cols[i] = id;
-  }
-
-  if (sparse_mat->useGpu()) {
-    auto gpu_sparse_mat =
-            dynamic_cast<GpuSparseMatrix*>(sparse_mat.get());
-    hl_memcpy_csr_matrix(gpu_sparse_mat->sMatrix_.get(),
-                         nullptr, rows, cols,
-                         HPPL_STREAM_DEFAULT);
-    hl_stream_synchronize(HPPL_STREAM_DEFAULT);
-  }
-}
-
 }  // namespace paddle
@@ -286,12 +286,6 @@ struct Argument {
     sequence has sub-sequence degrades to a sequence.
   */
   void degradeSequence(const Argument& input, bool useGpu);
-
-  /*
-    @brief convert the ids vector to value as a sparse matrix
-    @param[out] the output sparse_mat (already allocated)
-  */
-  void idsToSparseMatrix(MatrixPtr sparse_mat);
 };
 
 }  // namespace paddle