提交 75185d82 编写于 作者: G gangliao 提交者: GitHub

Merge pull request #3228 from gangliao/clang-format

ClangFormat for proto and cuda
......@@ -24,7 +24,7 @@
description: Format files with ClangFormat.
entry: clang-format -i
language: system
files: \.(c|cc|cxx|cpp|h|hpp|hxx)$
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks:
......
......@@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_batch_transpose.h"
#include "hl_base.h"
#include "hl_batch_transpose.h"
const int TILE_DIM = 64;
const int BLOCK_ROWS = 16;
// No bank-conflict transpose for a batch of data.
__global__ void batchTransposeNoBankConflicts(real* odata,
const real* idata,
int numSamples, int width,
int height) {
__global__ void batchTransposeNoBankConflicts(
real* odata, const real* idata, int numSamples, int width, int height) {
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
const int x = blockIdx.x * TILE_DIM + threadIdx.x;
......@@ -50,12 +48,12 @@ __global__ void batchTransposeNoBankConflicts(real* odata,
newX] = tile[threadIdx.x][j];
}
void batchTranspose(const real* input, real* output, int width, int height,
int batchSize) {
void batchTranspose(
const real* input, real* output, int width, int height, int batchSize) {
dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
dim3 dimGrid(DIVUP(width, TILE_DIM), DIVUP(height, TILE_DIM), batchSize);
batchTransposeNoBankConflicts<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
(output, input, batchSize, width, height);
batchTransposeNoBankConflicts<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>(
output, input, batchSize, width, height);
CHECK_SYNC("batchTranspose failed!");
}
......@@ -12,27 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_aggregate.h"
#include "hl_base.h"
#include "hl_cuda.h"
#include "hl_cuda.ph"
#include "hl_aggregate.h"
#include "hl_thread.ph"
#include "hl_matrix_base.cuh"
#include "hl_thread.ph"
#include "paddle/utils/Logging.h"
/**
* @brief matrix row operator.
*/
template<class Agg, int blockSize>
__global__ void KeMatrixRowOp(Agg agg,
real *E,
real *Sum,
int dimN) {
template <class Agg, int blockSize>
__global__ void KeMatrixRowOp(Agg agg, real *E, real *Sum, int dimN) {
__shared__ real sum_s[blockSize];
int cnt = (dimN + blockSize -1) / blockSize;
int rowId = blockIdx.x + blockIdx.y*gridDim.x;
int index = rowId*dimN;
int cnt = (dimN + blockSize - 1) / blockSize;
int rowId = blockIdx.x + blockIdx.y * gridDim.x;
int index = rowId * dimN;
int tid = threadIdx.x;
int lmt = tid;
......@@ -44,7 +40,7 @@ __global__ void KeMatrixRowOp(Agg agg,
sum_s[tid] = tmp;
__syncthreads();
for (int stride = blockSize/2; stride > 0; stride = stride/2) {
for (int stride = blockSize / 2; stride > 0; stride = stride / 2) {
if (tid < stride) {
sum_s[tid] = agg(sum_s[tid], sum_s[tid + stride]);
}
......@@ -58,29 +54,21 @@ __global__ void KeMatrixRowOp(Agg agg,
}
template <class Agg>
void hl_matrix_row_op(Agg agg,
real *A_d,
real *C_d,
int dimM,
int dimN) {
void hl_matrix_row_op(Agg agg, real *A_d, real *C_d, int dimM, int dimN) {
int blocksX = dimM;
int blocksY = 1;
dim3 threads(128, 1);
dim3 grid(blocksX, blocksY);
KeMatrixRowOp<Agg, 128><<< grid, threads, 0, STREAM_DEFAULT >>>
(agg, A_d, C_d, dimN);
KeMatrixRowOp<Agg, 128><<<grid, threads, 0, STREAM_DEFAULT>>>(
agg, A_d, C_d, dimN);
}
void hl_matrix_row_sum(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
hl_matrix_row_op(aggregate::sum(),
A_d,
C_d,
dimM,
dimN);
hl_matrix_row_op(aggregate::sum(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_row_sum failed");
}
......@@ -88,11 +76,7 @@ void hl_matrix_row_max(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
hl_matrix_row_op(aggregate::max(),
A_d,
C_d,
dimM,
dimN);
hl_matrix_row_op(aggregate::max(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_row_max failed");
}
......@@ -100,23 +84,16 @@ void hl_matrix_row_min(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
hl_matrix_row_op(aggregate::min(),
A_d,
C_d,
dimM,
dimN);
hl_matrix_row_op(aggregate::min(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_row_min failed");
}
/**
* @brief matrix column operator.
*/
template<class Agg>
__global__ void KeMatrixColumnOp(Agg agg,
real *E,
real *Sum,
int dimM,
int dimN) {
template <class Agg>
__global__ void KeMatrixColumnOp(
Agg agg, real *E, real *Sum, int dimM, int dimN) {
int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
real tmp = agg.init();
if (rowIdx < dimN) {
......@@ -127,15 +104,12 @@ __global__ void KeMatrixColumnOp(Agg agg,
}
}
template<class Agg, int blockDimX, int blockDimY>
__global__ void KeMatrixColumnOp_S(Agg agg,
real *E,
real *Sum,
int dimM,
int dimN) {
__shared__ real _sum[blockDimX*blockDimY];
int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
int index = threadIdx.y;
template <class Agg, int blockDimX, int blockDimY>
__global__ void KeMatrixColumnOp_S(
Agg agg, real *E, real *Sum, int dimM, int dimN) {
__shared__ real _sum[blockDimX * blockDimY];
int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
int index = threadIdx.y;
real tmp = agg.init();
if (rowIdx < dimN) {
......@@ -144,14 +118,14 @@ __global__ void KeMatrixColumnOp_S(Agg agg,
index += blockDimY;
}
}
_sum[threadIdx.x + threadIdx.y*blockDimX] = tmp;
_sum[threadIdx.x + threadIdx.y * blockDimX] = tmp;
__syncthreads();
if (rowIdx < dimN) {
if (threadIdx.y ==0) {
if (threadIdx.y == 0) {
real tmp = agg.init();
for (int i=0; i < blockDimY; i++) {
tmp = agg(tmp, _sum[threadIdx.x + i*blockDimX]);
for (int i = 0; i < blockDimY; i++) {
tmp = agg(tmp, _sum[threadIdx.x + i * blockDimX]);
}
Sum[rowIdx] = tmp;
}
......@@ -159,25 +133,21 @@ __global__ void KeMatrixColumnOp_S(Agg agg,
}
template <class Agg>
void hl_matrix_column_op(Agg agg,
real *A_d,
real *C_d,
int dimM,
int dimN) {
void hl_matrix_column_op(Agg agg, real *A_d, real *C_d, int dimM, int dimN) {
if (dimN >= 8192) {
int blocksX = (dimN + 128 -1) / 128;
int blocksX = (dimN + 128 - 1) / 128;
int blocksY = 1;
dim3 threads(128, 1);
dim3 grid(blocksX, blocksY);
KeMatrixColumnOp<Agg><<< grid, threads, 0, STREAM_DEFAULT >>>
(agg, A_d, C_d, dimM, dimN);
KeMatrixColumnOp<Agg><<<grid, threads, 0, STREAM_DEFAULT>>>(
agg, A_d, C_d, dimM, dimN);
} else {
int blocksX = (dimN + 32 -1) / 32;
int blocksX = (dimN + 32 - 1) / 32;
int blocksY = 1;
dim3 threads(32, 32);
dim3 grid(blocksX, blocksY);
KeMatrixColumnOp_S<Agg, 32, 32><<< grid, threads, 0, STREAM_DEFAULT>>>
(agg, A_d, C_d, dimM, dimN);
KeMatrixColumnOp_S<Agg, 32, 32><<<grid, threads, 0, STREAM_DEFAULT>>>(
agg, A_d, C_d, dimM, dimN);
}
return;
......@@ -187,11 +157,7 @@ void hl_matrix_column_sum(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
hl_matrix_column_op(aggregate::sum(),
A_d,
C_d,
dimM,
dimN);
hl_matrix_column_op(aggregate::sum(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_column_sum failed");
}
......@@ -200,11 +166,7 @@ void hl_matrix_column_max(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
hl_matrix_column_op(aggregate::max(),
A_d,
C_d,
dimM,
dimN);
hl_matrix_column_op(aggregate::max(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_column_max failed");
}
......@@ -213,11 +175,7 @@ void hl_matrix_column_min(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL(A_d);
CHECK_NOTNULL(C_d);
hl_matrix_column_op(aggregate::min(),
A_d,
C_d,
dimM,
dimN);
hl_matrix_column_op(aggregate::min(), A_d, C_d, dimM, dimN);
CHECK_SYNC("hl_matrix_column_min failed");
}
......@@ -226,16 +184,16 @@ template <int blockSize>
__global__ void KeVectorSum(real *E, real *Sum, int dimM) {
__shared__ double sum_s[blockSize];
int tid = threadIdx.x;
int index = blockIdx.y*blockDim.x+threadIdx.x;
int index = blockIdx.y * blockDim.x + threadIdx.x;
sum_s[tid] = 0.0f;
while (index < dimM) {
sum_s[tid] += E[index];
index += blockDim.x*gridDim.y;
index += blockDim.x * gridDim.y;
}
__syncthreads();
for (int stride = blockSize/2; stride > 0; stride = stride/2) {
for (int stride = blockSize / 2; stride > 0; stride = stride / 2) {
if (tid < stride) {
sum_s[tid] += sum_s[tid + stride];
}
......@@ -259,38 +217,39 @@ void hl_vector_sum(real *A_d, real *C_h, int dimM) {
dim3 threads(blockSize, 1);
dim3 grid(blocksX, blocksY);
struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
hl_event_t hl_event = &hl_event_st;
while (!hl_cuda_event_is_ready(hl_event)) {}
while (!hl_cuda_event_is_ready(hl_event)) {
}
KeVectorSum<128><<< grid, threads, 0, STREAM_DEFAULT >>>
(A_d, t_resource.gpu_mem, dimM);
KeVectorSum<128><<< 1, threads, 0, STREAM_DEFAULT >>>
(t_resource.gpu_mem, t_resource.cpu_mem, 128);
KeVectorSum<128><<<grid, threads, 0, STREAM_DEFAULT>>>(
A_d, t_resource.gpu_mem, dimM);
KeVectorSum<128><<<1, threads, 0, STREAM_DEFAULT>>>(
t_resource.gpu_mem, t_resource.cpu_mem, 128);
hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT);
hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event);
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
cudaError_t err = (cudaError_t)hl_get_device_last_error();
CHECK_EQ(cudaSuccess, err)
<< "CUDA error: " << hl_get_device_error_string((size_t)err);
CHECK_EQ(cudaSuccess, err) << "CUDA error: "
<< hl_get_device_error_string((size_t)err);
}
template <int blockSize>
__global__ void KeVectorAbsSum(real *E, real *Sum, int dimM) {
__shared__ double sum_s[blockSize];
int tid = threadIdx.x;
int index = blockIdx.y*blockDim.x+threadIdx.x;
int index = blockIdx.y * blockDim.x + threadIdx.x;
sum_s[tid] = 0.0f;
while (index < dimM) {
sum_s[tid] += abs(E[index]);
index += blockDim.x*gridDim.y;
index += blockDim.x * gridDim.y;
}
__syncthreads();
for (int stride = blockSize/2; stride > 0; stride = stride/2) {
for (int stride = blockSize / 2; stride > 0; stride = stride / 2) {
if (tid < stride) {
sum_s[tid] += sum_s[tid + stride];
}
......@@ -314,20 +273,21 @@ void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) {
dim3 threads(blockSize, 1);
dim3 grid(blocksX, blocksY);
struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
hl_event_t hl_event = &hl_event_st;
while (!hl_cuda_event_is_ready(hl_event)) {}
while (!hl_cuda_event_is_ready(hl_event)) {
}
KeVectorAbsSum<128><<< grid, threads, 0, STREAM_DEFAULT >>>
(A_d, t_resource.gpu_mem, dimM);
KeVectorAbsSum<128><<< 1, threads, 0, STREAM_DEFAULT >>>
(t_resource.gpu_mem, t_resource.cpu_mem, 128);
KeVectorAbsSum<128><<<grid, threads, 0, STREAM_DEFAULT>>>(
A_d, t_resource.gpu_mem, dimM);
KeVectorAbsSum<128><<<1, threads, 0, STREAM_DEFAULT>>>(
t_resource.gpu_mem, t_resource.cpu_mem, 128);
hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT);
hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event);
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
cudaError_t err = (cudaError_t)hl_get_device_last_error();
CHECK_EQ(cudaSuccess, err)
<< "CUDA error: " << hl_get_device_error_string((size_t)err);
CHECK_EQ(cudaSuccess, err) << "CUDA error: "
<< hl_get_device_error_string((size_t)err);
}
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -16,36 +16,36 @@ limitations under the License. */
#include "hl_device_functions.cuh"
#include "paddle/utils/Logging.h"
__global__ void KeMaxSequenceForward(real *input,
const int *sequence,
__global__ void KeMaxSequenceForward(real* input,
const int* sequence,
real* output,
int *index,
int* index,
int numSequences,
int dim) {
int dimIdx = threadIdx.x;
int sequenceId = blockIdx.x;
if (sequenceId >= numSequences) return;
int start = sequence[sequenceId];
int end = sequence[sequenceId+1];
int end = sequence[sequenceId + 1];
for (int i = dimIdx; i < dim; i += blockDim.x) {
real tmp = -HL_FLOAT_MAX;
int tmpId = -1;
for (int insId = start; insId < end; insId++) {
if (tmp < input[insId*dim + i]) {
tmp = input[insId*dim + i];
if (tmp < input[insId * dim + i]) {
tmp = input[insId * dim + i];
tmpId = insId;
}
}
output[sequenceId*dim + i] = tmp;
index[sequenceId*dim + i] = tmpId;
output[sequenceId * dim + i] = tmp;
index[sequenceId * dim + i] = tmpId;
}
}
void hl_max_sequence_forward(real* input,
const int* sequence,
real* output,
int *index,
int* index,
int numSequences,
int dim) {
CHECK_NOTNULL(input);
......@@ -55,29 +55,23 @@ void hl_max_sequence_forward(real* input,
dim3 threads(256, 1);
dim3 grid(numSequences, 1);
KeMaxSequenceForward<<< grid, threads, 0, STREAM_DEFAULT >>>
(input, sequence, output, index, numSequences, dim);
KeMaxSequenceForward<<<grid, threads, 0, STREAM_DEFAULT>>>(
input, sequence, output, index, numSequences, dim);
CHECK_SYNC("hl_max_sequence_forward failed");
}
__global__ void KeMaxSequenceBackward(real *outputGrad,
int *index,
real* inputGrad,
int numSequences,
int dim) {
__global__ void KeMaxSequenceBackward(
real* outputGrad, int* index, real* inputGrad, int numSequences, int dim) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int colIdx = idx % dim;
if (idx < numSequences*dim) {
if (idx < numSequences * dim) {
int insId = index[idx];
inputGrad[insId * dim + colIdx] += outputGrad[idx];
}
}
void hl_max_sequence_backward(real* outputGrad,
int *index,
real* inputGrad,
int numSequences,
int dim) {
void hl_max_sequence_backward(
real* outputGrad, int* index, real* inputGrad, int numSequences, int dim) {
CHECK_NOTNULL(outputGrad);
CHECK_NOTNULL(index);
CHECK_NOTNULL(inputGrad);
......@@ -85,12 +79,12 @@ void hl_max_sequence_backward(real* outputGrad,
unsigned int blocks = (numSequences * dim + 128 - 1) / 128;
dim3 threads(128, 1);
dim3 grid(blocks, 1);
KeMaxSequenceBackward<<< grid, threads, 0, STREAM_DEFAULT >>>
(outputGrad, index, inputGrad, numSequences, dim);
KeMaxSequenceBackward<<<grid, threads, 0, STREAM_DEFAULT>>>(
outputGrad, index, inputGrad, numSequences, dim);
CHECK_SYNC("hl_max_sequence_backward failed");
}
template<int blockDimX, int blockDimY, int gridDimX, bool AddRow>
template <int blockDimX, int blockDimY, int gridDimX, bool AddRow>
__global__ void KeMatrixAddRows(real* output,
real* table,
int* ids,
......@@ -104,8 +98,8 @@ __global__ void KeMatrixAddRows(real* output,
while (sampleId < numSamples) {
int tableId = ids[sampleId];
if ((0 <= tableId) && (tableId < tableSize)) {
real *outputData = output + sampleId * dim;
real *tableData = table + tableId * dim;
real* outputData = output + sampleId * dim;
real* tableData = table + tableId * dim;
for (int i = idx; i < dim; i += blockDimX) {
if (AddRow == 0) {
outputData[i] += tableData[i];
......@@ -114,24 +108,27 @@ __global__ void KeMatrixAddRows(real* output,
}
}
}
sampleId += blockDimY*gridDimX;
sampleId += blockDimY * gridDimX;
}
}
template<int blockDimX, int blockDimY, int gridDimX, bool seq2batch, bool isAdd>
__global__
void KeSequence2Batch(real *batch,
real *sequence,
const int *batchIndex,
int seqWidth,
int batchCount) {
template <int blockDimX,
int blockDimY,
int gridDimX,
bool seq2batch,
bool isAdd>
__global__ void KeSequence2Batch(real* batch,
real* sequence,
const int* batchIndex,
int seqWidth,
int batchCount) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int id = blockIdx.x + idy * gridDimX;
while (id < batchCount) {
int seqId = batchIndex[id];
real* batchData = batch + id*seqWidth;
real* seqData = sequence + seqId*seqWidth;
real* batchData = batch + id * seqWidth;
real* seqData = sequence + seqId * seqWidth;
for (int i = idx; i < seqWidth; i += blockDimX) {
if (seq2batch) {
if (isAdd) {
......@@ -147,13 +144,13 @@ void KeSequence2Batch(real *batch,
}
}
}
id += blockDimY*gridDimX;
id += blockDimY * gridDimX;
}
}
void hl_sequence2batch_copy(real *batch,
real *sequence,
const int *batchIndex,
void hl_sequence2batch_copy(real* batch,
real* sequence,
const int* batchIndex,
int seqWidth,
int batchCount,
bool seq2batch) {
......@@ -164,18 +161,18 @@ void hl_sequence2batch_copy(real *batch,
dim3 threads(128, 8);
dim3 grid(8, 1);
if (seq2batch) {
KeSequence2Batch<128, 8, 8, 1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
(batch, sequence, batchIndex, seqWidth, batchCount);
KeSequence2Batch<128, 8, 8, 1, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
batch, sequence, batchIndex, seqWidth, batchCount);
} else {
KeSequence2Batch<128, 8, 8, 0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
(batch, sequence, batchIndex, seqWidth, batchCount);
KeSequence2Batch<128, 8, 8, 0, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
batch, sequence, batchIndex, seqWidth, batchCount);
}
CHECK_SYNC("hl_sequence2batch_copy failed");
}
void hl_sequence2batch_add(real *batch,
real *sequence,
int *batchIndex,
void hl_sequence2batch_add(real* batch,
real* sequence,
int* batchIndex,
int seqWidth,
int batchCount,
bool seq2batch) {
......@@ -186,23 +183,22 @@ void hl_sequence2batch_add(real *batch,
dim3 threads(128, 8);
dim3 grid(8, 1);
if (seq2batch) {
KeSequence2Batch<128, 8, 8, 1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>
(batch, sequence, batchIndex, seqWidth, batchCount);
KeSequence2Batch<128, 8, 8, 1, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
batch, sequence, batchIndex, seqWidth, batchCount);
} else {
KeSequence2Batch<128, 8, 8, 0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>
(batch, sequence, batchIndex, seqWidth, batchCount);
KeSequence2Batch<128, 8, 8, 0, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
batch, sequence, batchIndex, seqWidth, batchCount);
}
CHECK_SYNC("hl_sequence2batch_add failed");
}
template<bool normByTimes, bool seq2batch>
__global__
void KeSequence2BatchPadding(real* batch,
real* sequence,
const int* sequenceStartPositions,
const size_t sequenceWidth,
const size_t maxSequenceLength,
const size_t numSequences) {
template <bool normByTimes, bool seq2batch>
__global__ void KeSequence2BatchPadding(real* batch,
real* sequence,
const int* sequenceStartPositions,
const size_t sequenceWidth,
const size_t maxSequenceLength,
const size_t numSequences) {
int batchIdx = blockIdx.y;
int sequenceStart = sequenceStartPositions[batchIdx];
int sequenceLength = sequenceStartPositions[batchIdx + 1] - sequenceStart;
......@@ -276,37 +272,49 @@ void hl_sequence2batch_copy_padding(real* batch,
if (seq2batch) {
/* sequence -> batch */
if (normByTimes) {
KeSequence2BatchPadding<1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
batch, sequence, sequenceStartPositions,
sequenceWidth, maxSequenceLength, numSequences);
KeSequence2BatchPadding<1, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
batch,
sequence,
sequenceStartPositions,
sequenceWidth,
maxSequenceLength,
numSequences);
} else {
KeSequence2BatchPadding<0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
batch, sequence, sequenceStartPositions,
sequenceWidth, maxSequenceLength, numSequences);
KeSequence2BatchPadding<0, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
batch,
sequence,
sequenceStartPositions,
sequenceWidth,
maxSequenceLength,
numSequences);
}
} else {
/* batch -> sequence */
if (normByTimes) {
KeSequence2BatchPadding<1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
batch, sequence, sequenceStartPositions,
sequenceWidth, maxSequenceLength, numSequences);
KeSequence2BatchPadding<1, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
batch,
sequence,
sequenceStartPositions,
sequenceWidth,
maxSequenceLength,
numSequences);
} else {
KeSequence2BatchPadding<0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
batch, sequence, sequenceStartPositions,
sequenceWidth, maxSequenceLength, numSequences);
KeSequence2BatchPadding<0, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
batch,
sequence,
sequenceStartPositions,
sequenceWidth,
maxSequenceLength,
numSequences);
}
}
CHECK_SYNC("hl_sequence2batch_copy_padding failed");
}
__device__ inline float my_rsqrt(float x) {
return rsqrtf(x);
}
__device__ inline float my_rsqrt(float x) { return rsqrtf(x); }
__device__ inline double my_rsqrt(double x) {
return rsqrt(x);
}
__device__ inline double my_rsqrt(double x) { return rsqrt(x); }
__global__ void KeSequenceAvgForward(real* dst,
real* src,
......@@ -327,8 +335,8 @@ __global__ void KeSequenceAvgForward(real* dst,
for (int i = start; i < end; i++) {
sum += src[i * width + col];
}
sum = mode == 1 ? sum :
(mode == 0 ? sum / seqLength : sum * my_rsqrt((real)seqLength));
sum = mode == 1 ? sum : (mode == 0 ? sum / seqLength
: sum * my_rsqrt((real)seqLength));
dst[gid] += sum;
}
}
......@@ -347,10 +355,10 @@ void hl_sequence_avg_forward(real* dst,
int grid = DIVUP(width * height, 512);
CHECK(mode == 0 || mode == 1 || mode == 2)
<< "mode error in hl_sequence_avg_forward!";
<< "mode error in hl_sequence_avg_forward!";
KeSequenceAvgForward<<< grid, block, 0, STREAM_DEFAULT >>>
(dst, src, starts, height, width, mode);
KeSequenceAvgForward<<<grid, block, 0, STREAM_DEFAULT>>>(
dst, src, starts, height, width, mode);
CHECK_SYNC("hl_sequence_avg_forward failed");
}
......@@ -370,8 +378,8 @@ __global__ void KeSequenceAvgBackward(real* dst,
int seqLength = end - start;
if (seqLength == 0) return;
real grad = src[gid];
grad = mode == 1 ? grad :
(mode == 0 ? grad / seqLength : grad * my_rsqrt((real)seqLength));
grad = mode == 1 ? grad : (mode == 0 ? grad / seqLength
: grad * my_rsqrt((real)seqLength));
for (int i = start; i < end; i++) {
dst[i * width + col] += grad;
}
......@@ -392,9 +400,9 @@ void hl_sequence_avg_backward(real* dst,
int grid = DIVUP(width * height, 512);
CHECK(mode == 0 || mode == 1 || mode == 2)
<< "mode error in hl_sequence_avg_backward!";
<< "mode error in hl_sequence_avg_backward!";
KeSequenceAvgBackward<<< grid, block, 0, STREAM_DEFAULT >>>
(dst, src, starts, height, width, mode);
KeSequenceAvgBackward<<<grid, block, 0, STREAM_DEFAULT>>>(
dst, src, starts, height, width, mode);
CHECK_SYNC("hl_sequence_avg_backward failed");
}
此差异已折叠。
......@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <stdlib.h>
#include "hl_cuda.h"
#include "hl_time.h"
#include <cmath>
#include "hl_base.h"
#include "hl_cuda.h"
#include "hl_perturbation_util.cuh"
#include "hl_time.h"
#define _USE_MATH_DEFINES
......@@ -30,10 +29,16 @@ limitations under the License. */
* centerX, centerY: translation.
* sourceX, sourceY: output coordinates in the original image.
*/
__device__ void getTranformCoord(int x, int y, real theta, real scale,
real tgtCenter, real imgCenter,
real centerR, real centerC,
int* sourceX, int* sourceY) {
__device__ void getTranformCoord(int x,
int y,
real theta,
real scale,
real tgtCenter,
real imgCenter,
real centerR,
real centerC,
int* sourceX,
int* sourceY) {
real H[4] = {cosf(-theta), -sinf(-theta), sinf(-theta), cosf(-theta)};
// compute coornidates in the rotated and scaled image
......@@ -57,11 +62,17 @@ __device__ void getTranformCoord(int x, int y, real theta, real scale,
* created by Wei Xu (genome), converted by Jiang Wang
*/
__global__ void kSamplingPatches(const real* imgs, real* targets,
int imgSize, int tgtSize, const int channels,
int samplingRate, const real* thetas,
const real* scales, const int* centerRs,
const int* centerCs, const real padValue,
__global__ void kSamplingPatches(const real* imgs,
real* targets,
int imgSize,
int tgtSize,
const int channels,
int samplingRate,
const real* thetas,
const real* scales,
const int* centerRs,
const int* centerCs,
const real padValue,
const int numImages) {
const int caseIdx = blockIdx.x * 4 + threadIdx.x;
const int pxIdx = blockIdx.y * 128 + threadIdx.y;
......@@ -80,8 +91,15 @@ __global__ void kSamplingPatches(const real* imgs, real* targets,
const int pxY = pxIdx / tgtSize;
int srcPxX, srcPxY;
getTranformCoord(pxX, pxY, thetas[imgIdx], scales[imgIdx], tgtCenter,
imgCenter, centerCs[caseIdx], centerRs[caseIdx], &srcPxX,
getTranformCoord(pxX,
pxY,
thetas[imgIdx],
scales[imgIdx],
tgtCenter,
imgCenter,
centerCs[caseIdx],
centerRs[caseIdx],
&srcPxX,
&srcPxY);
imgs += (imgIdx * imgPixels + srcPxY * imgSize + srcPxX) * channels;
......@@ -100,10 +118,15 @@ __global__ void kSamplingPatches(const real* imgs, real* targets,
*
* created by Wei Xu
*/
void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
int*& gpuCenterR, int*& gpuCenterC,
int numImages, int imgSize, real rotateAngle,
real scaleRatio, int samplingRate,
void hl_generate_disturb_params(real*& gpuAngle,
real*& gpuScaleRatio,
int*& gpuCenterR,
int*& gpuCenterC,
int numImages,
int imgSize,
real rotateAngle,
real scaleRatio,
int samplingRate,
bool isTrain) {
// The number of output samples.
int numPatches = numImages * samplingRate;
......@@ -123,7 +146,8 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
for (int i = 0; i < numImages; i++) {
r_angle[i] =
(rotateAngle * M_PI / 180.0) * (rand() / (RAND_MAX + 1.0) // NOLINT
- 0.5);
-
0.5);
s_ratio[i] =
1 + (rand() / (RAND_MAX + 1.0) - 0.5) * scaleRatio; // NOLINT
}
......@@ -140,8 +164,10 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
int pxY =
(int)(real(imgSize - 1) * rand() / (RAND_MAX + 1.0)); // NOLINT
const real H[4] = {cos(-r_angle[i]), -sin(-r_angle[i]),
sin(-r_angle[i]), cos(-r_angle[i])};
const real H[4] = {cos(-r_angle[i]),
-sin(-r_angle[i]),
sin(-r_angle[i]),
cos(-r_angle[i])};
real x = pxX - imgCenter;
real y = pxY - imgCenter;
real xx = H[0] * x + H[1] * y;
......@@ -185,9 +211,12 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
delete[] center_c;
}
void hl_conv_random_disturb_with_params(const real* images, int imgSize,
int tgtSize, int channels,
int numImages, int samplingRate,
void hl_conv_random_disturb_with_params(const real* images,
int imgSize,
int tgtSize,
int channels,
int numImages,
int samplingRate,
const real* gpuRotationAngle,
const real* gpuScaleRatio,
const int* gpuCenterR,
......@@ -202,29 +231,59 @@ void hl_conv_random_disturb_with_params(const real* images, int imgSize,
dim3 threadsPerBlock(4, 128);
dim3 numBlocks(DIVUP(numPatches, 4), DIVUP(targetSize, 128));
kSamplingPatches <<<numBlocks, threadsPerBlock>>>
(images, target, imgSize, tgtSize, channels, samplingRate,
gpuRotationAngle, gpuScaleRatio, gpuCenterR, gpuCenterC,
paddingValue, numImages);
kSamplingPatches<<<numBlocks, threadsPerBlock>>>(images,
target,
imgSize,
tgtSize,
channels,
samplingRate,
gpuRotationAngle,
gpuScaleRatio,
gpuCenterR,
gpuCenterC,
paddingValue,
numImages);
hl_device_synchronize();
}
void hl_conv_random_disturb(const real* images, int imgSize,
int tgtSize, int channels, int numImages,
real scaleRatio, real rotateAngle,
int samplingRate, real* gpu_r_angle,
real* gpu_s_ratio, int* gpu_center_r,
int* gpu_center_c, int paddingValue,
bool isTrain, real* targets) {
void hl_conv_random_disturb(const real* images,
int imgSize,
int tgtSize,
int channels,
int numImages,
real scaleRatio,
real rotateAngle,
int samplingRate,
real* gpu_r_angle,
real* gpu_s_ratio,
int* gpu_center_r,
int* gpu_center_c,
int paddingValue,
bool isTrain,
real* targets) {
// generate the random disturbance sequence and the sampling locations
hl_generate_disturb_params(gpu_r_angle, gpu_s_ratio, gpu_center_r,
gpu_center_c, numImages, imgSize, rotateAngle,
scaleRatio, samplingRate, isTrain);
hl_conv_random_disturb_with_params(
images, imgSize, tgtSize, channels, numImages,
samplingRate, gpu_r_angle, gpu_s_ratio,
gpu_center_r, gpu_center_r, paddingValue,
targets);
hl_generate_disturb_params(gpu_r_angle,
gpu_s_ratio,
gpu_center_r,
gpu_center_c,
numImages,
imgSize,
rotateAngle,
scaleRatio,
samplingRate,
isTrain);
hl_conv_random_disturb_with_params(images,
imgSize,
tgtSize,
channels,
numImages,
samplingRate,
gpu_r_angle,
gpu_s_ratio,
gpu_center_r,
gpu_center_r,
paddingValue,
targets);
}
......@@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "hl_cuda.h"
#include "hl_device_functions.cuh"
#include "paddle/utils/Logging.h"
template<int blockDimX, int blockDimY, int gridDimX, bool AddRow>
__global__ void KeMatrixAddRows(real* output, int ldo,
real* table, int ldt,
template <int blockDimX, int blockDimY, int gridDimX, bool AddRow>
__global__ void KeMatrixAddRows(real* output,
int ldo,
real* table,
int ldt,
int* ids,
int numSamples,
int tableSize,
......@@ -31,8 +32,8 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
while (idy < numSamples) {
int tableId = ids[idy];
if ((0 <= tableId) && (tableId < tableSize)) {
real *out = output + idy * ldo;
real *tab = table + tableId * ldt;
real* out = output + idy * ldo;
real* tab = table + tableId * ldt;
for (int i = idx; i < dim; i += blockDimX) {
if (AddRow) {
paddle::paddleAtomicAdd(&tab[i], out[i]);
......@@ -45,8 +46,10 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
}
}
void hl_matrix_select_rows(real* output, int ldo,
real* table, int ldt,
void hl_matrix_select_rows(real* output,
int ldo,
real* table,
int ldt,
int* ids,
int numSamples,
int tableSize,
......@@ -57,14 +60,16 @@ void hl_matrix_select_rows(real* output, int ldo,
dim3 threads(128, 8);
dim3 grid(8, 1);
KeMatrixAddRows<128, 8, 8, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
(output, ldo, table, ldt, ids, numSamples, tableSize, dim);
KeMatrixAddRows<128, 8, 8, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
output, ldo, table, ldt, ids, numSamples, tableSize, dim);
CHECK_SYNC("hl_matrix_select_rows failed");
}
void hl_matrix_add_to_rows(real* table, int ldt,
real* input, int ldi,
void hl_matrix_add_to_rows(real* table,
int ldt,
real* input,
int ldi,
int* ids,
int numSamples,
int tableSize,
......@@ -75,16 +80,15 @@ void hl_matrix_add_to_rows(real* table, int ldt,
dim3 threads(128, 8);
dim3 grid(8, 1);
KeMatrixAddRows<128, 8, 8, 1><<< grid, threads, 0, STREAM_DEFAULT >>>
(input, ldi, table, ldt, ids, numSamples, tableSize, dim);
KeMatrixAddRows<128, 8, 8, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
input, ldi, table, ldt, ids, numSamples, tableSize, dim);
CHECK_SYNC("hl_matrix_add_to_rows failed");
}
template<class T, int blockDimX, int gridDimX>
__global__ void KeVectorSelect(T* dst, int sized,
const T* src, int sizes,
const int* ids, int sizei) {
template <class T, int blockDimX, int gridDimX>
__global__ void KeVectorSelect(
T* dst, int sized, const T* src, int sizes, const int* ids, int sizei) {
int idx = threadIdx.x + blockDimX * blockIdx.x;
while (idx < sizei) {
int index = ids[idx];
......@@ -95,9 +99,8 @@ __global__ void KeVectorSelect(T* dst, int sized,
}
template <class T>
void hl_vector_select_from(T* dst, int sized,
const T* src, int sizes,
const int* ids, int sizei) {
void hl_vector_select_from(
T* dst, int sized, const T* src, int sizes, const int* ids, int sizei) {
CHECK_NOTNULL(dst);
CHECK_NOTNULL(src);
CHECK_NOTNULL(ids);
......@@ -105,18 +108,17 @@ void hl_vector_select_from(T* dst, int sized,
dim3 threads(512, 1);
dim3 grid(8, 1);
KeVectorSelect<T, 512, 8><<< grid, threads, 0, STREAM_DEFAULT >>>
(dst, sized, src, sizes, ids, sizei);
KeVectorSelect<T, 512, 8><<<grid, threads, 0, STREAM_DEFAULT>>>(
dst, sized, src, sizes, ids, sizei);
CHECK_SYNC("hl_vector_select_from failed");
}
template
void hl_vector_select_from(real* dst, int sized,
const real* src, int sizes,
const int* ids, int sizei);
template
void hl_vector_select_from(int* dst, int sized,
const int* src, int sizes,
const int* ids, int sizei);
template void hl_vector_select_from(real* dst,
int sized,
const real* src,
int sizes,
const int* ids,
int sizei);
template void hl_vector_select_from(
int* dst, int sized, const int* src, int sizes, const int* ids, int sizei);
......@@ -12,45 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "hl_top_k.h"
#include "hl_sparse.ph"
#include "hl_top_k.h"
#include "paddle/utils/Logging.h"
// using namespace hppl;
struct Pair {
__device__ __forceinline__
Pair() {}
__device__ __forceinline__ Pair() {}
__device__ __forceinline__
Pair(real value, int id) : v_(value), id_(id) {}
__device__ __forceinline__ Pair(real value, int id) : v_(value), id_(id) {}
__device__ __forceinline__
void set(real value, int id) {
__device__ __forceinline__ void set(real value, int id) {
v_ = value;
id_ = id;
}
__device__ __forceinline__
void operator=(const Pair& in) {
__device__ __forceinline__ void operator=(const Pair& in) {
v_ = in.v_;
id_ = in.id_;
}
__device__ __forceinline__
bool operator<(const real value) const {
__device__ __forceinline__ bool operator<(const real value) const {
return (v_ < value);
}
__device__ __forceinline__
bool operator<(const Pair& in) const {
__device__ __forceinline__ bool operator<(const Pair& in) const {
return (v_ < in.v_) || ((v_ == in.v_) && (id_ > in.id_));
}
__device__ __forceinline__
bool operator>(const Pair& in) const {
__device__ __forceinline__ bool operator>(const Pair& in) const {
return (v_ > in.v_) || ((v_ == in.v_) && (id_ < in.id_));
}
......@@ -58,8 +50,9 @@ struct Pair {
int id_;
};
__device__ __forceinline__
void addTo(Pair topK[], const Pair &p, int beamSize) {
__device__ __forceinline__ void addTo(Pair topK[],
const Pair& p,
int beamSize) {
for (int k = beamSize - 2; k >= 0; k--) {
if (topK[k] < p) {
topK[k + 1] = topK[k];
......@@ -71,9 +64,8 @@ void addTo(Pair topK[], const Pair &p, int beamSize) {
topK[0] = p;
}
template<int beamSize>
__device__ __forceinline__
void addTo(Pair topK[], const Pair &p) {
template <int beamSize>
__device__ __forceinline__ void addTo(Pair topK[], const Pair& p) {
for (int k = beamSize - 2; k >= 0; k--) {
if (topK[k] < p) {
topK[k + 1] = topK[k];
......@@ -85,9 +77,9 @@ void addTo(Pair topK[], const Pair &p) {
topK[0] = p;
}
template<int blockSize>
__device__ __forceinline__
void getTopK(Pair topK[], real *src, int idx, int dim, int beamSize) {
template <int blockSize>
__device__ __forceinline__ void getTopK(
Pair topK[], real* src, int idx, int dim, int beamSize) {
while (idx < dim) {
if (topK[beamSize - 1] < src[idx]) {
Pair tmp(src[idx], idx);
......@@ -97,10 +89,9 @@ void getTopK(Pair topK[], real *src, int idx, int dim, int beamSize) {
}
}
template<int blockSize>
__device__ __forceinline__
void getTopK(Pair topK[], real *src, int idx, int dim,
const Pair& max, int beamSize) {
template <int blockSize>
__device__ __forceinline__ void getTopK(
Pair topK[], real* src, int idx, int dim, const Pair& max, int beamSize) {
while (idx < dim) {
if (topK[beamSize - 1] < src[idx]) {
Pair tmp(src[idx], idx);
......@@ -112,10 +103,9 @@ void getTopK(Pair topK[], real *src, int idx, int dim,
}
}
template<int blockSize>
__device__ __forceinline__
void getTopK(Pair topK[], real *val, int *col,
int idx, int dim, int beamSize) {
template <int blockSize>
__device__ __forceinline__ void getTopK(
Pair topK[], real* val, int* col, int idx, int dim, int beamSize) {
while (idx < dim) {
if (topK[beamSize - 1] < val[idx]) {
Pair tmp(val[idx], col[idx]);
......@@ -125,10 +115,14 @@ void getTopK(Pair topK[], real *val, int *col,
}
}
template<int blockSize>
__device__ __forceinline__
void getTopK(Pair topK[], real *val, int *col, int idx, int dim,
const Pair& max, int beamSize) {
template <int blockSize>
__device__ __forceinline__ void getTopK(Pair topK[],
real* val,
int* col,
int idx,
int dim,
const Pair& max,
int beamSize) {
while (idx < dim) {
if (topK[beamSize - 1] < val[idx]) {
Pair tmp(val[idx], col[idx]);
......@@ -140,12 +134,16 @@ void getTopK(Pair topK[], real *val, int *col, int idx, int dim,
}
}
template<int maxLength, int blockSize>
__device__ __forceinline__
void threadGetTopK(Pair topK[], int& beam, int beamSize,
real* src,
bool& firstStep, bool& isEmpty, Pair& max,
int dim, const int tid) {
template <int maxLength, int blockSize>
__device__ __forceinline__ void threadGetTopK(Pair topK[],
int& beam,
int beamSize,
real* src,
bool& firstStep,
bool& isEmpty,
Pair& max,
int dim,
const int tid) {
if (beam > 0) {
int length = beam < beamSize ? beam : beamSize;
if (firstStep) {
......@@ -160,8 +158,7 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
if (!isEmpty) {
getTopK<blockSize>(topK + maxLength - beam, src, tid, dim,
max, length);
getTopK<blockSize>(topK + maxLength - beam, src, tid, dim, max, length);
}
}
......@@ -171,12 +168,17 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
template<int maxLength, int blockSize>
__device__ __forceinline__
void threadGetTopK(Pair topK[], int& beam, int beamSize,
real* val, int* col,
bool& firstStep, bool& isEmpty, Pair& max,
int dim, const int tid) {
template <int maxLength, int blockSize>
__device__ __forceinline__ void threadGetTopK(Pair topK[],
int& beam,
int beamSize,
real* val,
int* col,
bool& firstStep,
bool& isEmpty,
Pair& max,
int dim,
const int tid) {
if (beam > 0) {
int length = beam < beamSize ? beam : beamSize;
if (firstStep) {
......@@ -191,8 +193,8 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
if (!isEmpty) {
getTopK<blockSize>(topK + maxLength - beam, val, col, tid, dim,
max, length);
getTopK<blockSize>(
topK + maxLength - beam, val, col, tid, dim, max, length);
}
}
......@@ -202,12 +204,16 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
template<int maxLength, int blockSize>
__device__ __forceinline__
void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
real** topVal, int** topIds,
int& beam, int& beamSize,
const int tid, const int warp) {
template <int maxLength, int blockSize>
__device__ __forceinline__ void blockReduce(Pair* shTopK,
int* maxId,
Pair topK[],
real** topVal,
int** topIds,
int& beam,
int& beamSize,
const int tid,
const int warp) {
while (true) {
__syncthreads();
if (tid < blockSize / 2) {
......@@ -218,7 +224,7 @@ void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
}
}
__syncthreads();
for (int stride = blockSize / 4; stride > 0; stride = stride/2) {
for (int stride = blockSize / 4; stride > 0; stride = stride / 2) {
if (tid < stride) {
if (shTopK[maxId[tid]] < shTopK[maxId[tid + stride]]) {
maxId[tid] = maxId[tid + stride];
......@@ -257,10 +263,12 @@ void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
* 3. go to the second setp, until one thread's topK value is null;
* 4. go to the first setp, until get the topK value.
*/
template<int maxLength, int blockSize>
__global__ void KeMatrixTopK(real* topVal, int ldv,
int * topIds,
real* src, int lds,
template <int maxLength, int blockSize>
__global__ void KeMatrixTopK(real* topVal,
int ldv,
int* topIds,
real* src,
int lds,
int dim,
int beamSize) {
__shared__ Pair shTopK[blockSize];
......@@ -271,7 +279,7 @@ __global__ void KeMatrixTopK(real* topVal, int ldv,
topVal += blockIdx.x * ldv;
topIds += blockIdx.x * beamSize;
Pair topK[maxLength]; // NOLINT
Pair topK[maxLength]; // NOLINT
int beam = maxLength;
Pair max;
bool isEmpty = false;
......@@ -281,18 +289,19 @@ __global__ void KeMatrixTopK(real* topVal, int ldv,
topK[k].set(-HL_FLOAT_MAX, -1);
}
while (beamSize) {
threadGetTopK<maxLength, blockSize>
(topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
threadGetTopK<maxLength, blockSize>(
topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
shTopK[tid] = topK[0];
blockReduce<maxLength, blockSize>
(shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
blockReduce<maxLength, blockSize>(
shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
}
}
template<int maxLength, int blockSize>
__global__ void KeSMatrixTopK(real* topVal, int ldv,
int * topIds,
template <int maxLength, int blockSize>
__global__ void KeSMatrixTopK(real* topVal,
int ldv,
int* topIds,
real* val,
int* row,
int* col,
......@@ -304,7 +313,7 @@ __global__ void KeSMatrixTopK(real* topVal, int ldv,
topVal += blockIdx.x * ldv;
topIds += blockIdx.x * beamSize;
Pair topK[maxLength]; // NOLINT
Pair topK[maxLength]; // NOLINT
int beam = maxLength;
Pair max;
bool isEmpty = false;
......@@ -330,18 +339,20 @@ __global__ void KeSMatrixTopK(real* topVal, int ldv,
topK[k].set(-HL_FLOAT_MAX, -1);
}
while (beamSize) {
threadGetTopK<maxLength, blockSize>
(topK, beam, beamSize, val, col, firstStep, isEmpty, max, dim, tid);
threadGetTopK<maxLength, blockSize>(
topK, beam, beamSize, val, col, firstStep, isEmpty, max, dim, tid);
shTopK[tid] = topK[0];
blockReduce<maxLength, blockSize>
(shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
blockReduce<maxLength, blockSize>(
shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
}
}
void hl_matrix_top_k(real* topVal, int ldv,
int * topIds,
real* src, int lds,
void hl_matrix_top_k(real* topVal,
int ldv,
int* topIds,
real* src,
int lds,
int dim,
int beamSize,
int numSamples) {
......@@ -353,33 +364,32 @@ void hl_matrix_top_k(real* topVal, int ldv,
dim3 threads(256, 1);
dim3 grid(numSamples, 1);
KeMatrixTopK<5, 256><<< grid, threads, 0, STREAM_DEFAULT >>>
(topVal, ldv, topIds, src, lds, dim, beamSize);
KeMatrixTopK<5, 256><<<grid, threads, 0, STREAM_DEFAULT>>>(
topVal, ldv, topIds, src, lds, dim, beamSize);
CHECK_SYNC("hl_matrix_top_k failed");
}
void hl_sparse_matrix_top_k(real* topVal, int ldv,
int * topIds,
void hl_sparse_matrix_top_k(real* topVal,
int ldv,
int* topIds,
hl_sparse_matrix_s src,
int beamSize,
int numSamples) {
CHECK_NOTNULL(topVal);
CHECK_NOTNULL(topIds);
CHECK_NOTNULL(src);
CHECK_EQ(src->format, HL_SPARSE_CSR)
<<"sparse matrix format error!";
CHECK_EQ(src->format, HL_SPARSE_CSR) << "sparse matrix format error!";
hl_csr_matrix csr = (hl_csr_matrix)src->matrix;
if (csr->csr_val == NULL || csr->csr_row == NULL ||
csr->csr_col == NULL) {
if (csr->csr_val == NULL || csr->csr_row == NULL || csr->csr_col == NULL) {
LOG(FATAL) << "parameter src is null!";
}
dim3 threads(256, 1);
dim3 grid(numSamples, 1);
KeSMatrixTopK<5, 256><<< grid, threads, 0, STREAM_DEFAULT >>>
(topVal, ldv, topIds, csr->csr_val, csr->csr_row, csr->csr_col, beamSize);
KeSMatrixTopK<5, 256><<<grid, threads, 0, STREAM_DEFAULT>>>(
topVal, ldv, topIds, csr->csr_val, csr->csr_row, csr->csr_col, beamSize);
CHECK_SYNC("hl_sparse_matrix_top_k failed");
}
......@@ -392,10 +402,12 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
* 3. go to the second setp, until one thread's topK value is null;
* 4. go to the first setp, until get the topK value.
*/
template<int maxLength, int blockSize>
__global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
int * topIds,
real* src, int lds,
template <int maxLength, int blockSize>
__global__ void KeMatrixTopKClassificationError(real* topVal,
int ldv,
int* topIds,
real* src,
int lds,
int dim,
int beamSize,
int* label,
......@@ -408,7 +420,7 @@ __global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
topVal += blockIdx.x * ldv;
topIds += blockIdx.x * beamSize;
Pair topK[maxLength]; // NOLINT
Pair topK[maxLength]; // NOLINT
int beam = maxLength;
Pair max;
bool isEmpty = false;
......@@ -420,34 +432,36 @@ __global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
}
while (beamSize) {
threadGetTopK<maxLength, blockSize>
(topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
threadGetTopK<maxLength, blockSize>(
topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);
shTopK[tid] = topK[0];
blockReduce<maxLength, blockSize>
(shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
blockReduce<maxLength, blockSize>(
shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
}
__syncthreads();
if (tid == 0) {
for (int i = 0; i < topkSize; i++) {
if (*--topIds == label[blockIdx.x]) {
recResult[blockIdx.x] = 0;
break;
}
recResult[blockIdx.x] = 1.0f;
if (*--topIds == label[blockIdx.x]) {
recResult[blockIdx.x] = 0;
break;
}
recResult[blockIdx.x] = 1.0f;
}
}
}
void hl_matrix_classification_error(real* topVal, int ldv,
int* topIds,
real* src, int lds,
int dim,
int topkSize,
int numSamples,
int* label,
real* recResult) {
void hl_matrix_classification_error(real* topVal,
int ldv,
int* topIds,
real* src,
int lds,
int dim,
int topkSize,
int numSamples,
int* label,
real* recResult) {
CHECK_NOTNULL(topVal);
CHECK_NOTNULL(topIds);
CHECK_NOTNULL(src);
......@@ -456,9 +470,8 @@ void hl_matrix_classification_error(real* topVal, int ldv,
dim3 threads(256, 1);
dim3 grid(numSamples, 1);
KeMatrixTopKClassificationError<5, 256>
<<< grid, threads, 0, STREAM_DEFAULT >>>
(topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);
KeMatrixTopKClassificationError<5, 256><<<grid, threads, 0, STREAM_DEFAULT>>>(
topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);
CHECK_SYNC("hl_matrix_top_k classification error failed");
}
......@@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax="proto2";
syntax = "proto2";
package paddle.framework;
// Attribute Type for paddle's Op.
// Op contains many attributes. Each type of attributes could be different.
// The AttrType will be shared between AttrDesc and AttrProto.
enum AttrType {
INT = 0;
FLOAT = 1;
STRING = 2;
INTS = 3;
FLOATS = 4;
STRINGS = 5;
INT = 0;
FLOAT = 1;
STRING = 2;
INTS = 3;
FLOATS = 4;
STRINGS = 5;
}
\ No newline at end of file
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax="proto2";
syntax = "proto2";
package paddle.framework;
import "attribute.proto";
......@@ -22,14 +22,14 @@ import "attribute.proto";
//
// e.g, for scale=3.0: name=scala, type=AttrType.FLOAT, value=3.0
message AttrDesc {
required string name = 1;
required AttrType type = 2;
optional int32 i = 3;
optional float f = 4;
optional string s = 5;
repeated int32 ints = 6;
repeated float floats = 7;
repeated string strings = 8;
required string name = 1;
required AttrType type = 2;
optional int32 i = 3;
optional float f = 4;
optional string s = 5;
repeated int32 ints = 6;
repeated float floats = 7;
repeated string strings = 8;
};
// Protocol Message to describe an Operator.
......@@ -42,15 +42,15 @@ message AttrDesc {
// 3rd-party language can build this proto message and call
// AddOp(const OpDesc& op_desc) of Paddle core to create an Operator.
message OpDesc {
// input names of this Operator.
repeated string inputs = 1;
// input names of this Operator.
repeated string inputs = 1;
// output names of this Operator.
repeated string outputs = 2;
// output names of this Operator.
repeated string outputs = 2;
// type of this Operator, such as "add", "sub", "fc".
required string type = 3;
// type of this Operator, such as "add", "sub", "fc".
required string type = 3;
// Attributes of this Operator. e.g., scale=3.0 in cosine op.
repeated AttrDesc attrs = 4;
// Attributes of this Operator. e.g., scale=3.0 in cosine op.
repeated AttrDesc attrs = 4;
};
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "MulOp.h"
#include "hl_base.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
......
......@@ -12,15 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "PadOp.h"
#include "hl_base.h"
namespace paddle {
__global__ void KePad(real* outputs, const real* inputs,
int inC, int inH, int inW,
int padc, int padh, int padw,
int outC, int outH, int outW, int nthreads) {
__global__ void KePad(real* outputs,
const real* inputs,
int inC,
int inH,
int inW,
int padc,
int padh,
int padw,
int outC,
int outH,
int outW,
int nthreads) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % inW;
......@@ -50,16 +58,33 @@ void Pad<DEVICE_TYPE_GPU>(real* outputs,
int outC = inC + cstart + cend;
int outH = inH + hstart + hend;
int outW = inW + wstart + wend;
KePad<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(outputs, inputs, inC, inH, inW, cstart, hstart, wstart,
outC, outH, outW, nth);
KePad<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(outputs,
inputs,
inC,
inH,
inW,
cstart,
hstart,
wstart,
outC,
outH,
outW,
nth);
CHECK_SYNC("Pad");
}
__global__ void KePadDiff(real* inGrad, const real* outGrad,
int inC, int inH, int inW,
int padc, int padh, int padw,
int outC, int outH, int outW, int nthreads) {
__global__ void KePadDiff(real* inGrad,
const real* outGrad,
int inC,
int inH,
int inW,
int padc,
int padh,
int padw,
int outC,
int outH,
int outW,
int nthreads) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < nthreads) {
const int w = idx % inW;
......@@ -89,9 +114,18 @@ void PadGrad<DEVICE_TYPE_GPU>(real* inGrad,
int outC = inC + cstart + cend;
int outH = inH + hstart + hend;
int outW = inW + wstart + wend;
KePadDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(inGrad, outGrad, inC, inH, inW, cstart, hstart, wstart,
outC, outH, outW, nth);
KePadDiff<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(inGrad,
outGrad,
inC,
inH,
inW,
cstart,
hstart,
wstart,
outC,
outH,
outW,
nth);
CHECK_SYNC("PadGrad");
}
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
./trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.proto
./trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.proto_data
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册