提交 529f24c2 编写于 作者: H hedaoyuan

cpu cmrnorm

上级 b3f0f3d2
......@@ -381,57 +381,45 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
CHECK_SYNC("hl_avgpool_backward failed");
}
__global__ void KeCMRNormFillScale(size_t nthreads, const real* in,
__global__ void KeCMRNormFillScale(size_t imageSize, const real* in,
real* scale, size_t channels,
size_t height, size_t width, size_t size,
real alpha) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
// find out the local offset
size_t w = index % width;
size_t h = (index / width) % height;
size_t n = index / width / height;
size_t offset = (n * channels * height + h) * width + w;
size_t step = height * width;
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < imageSize) {
const int w = idx % width;
const int h = (idx / width) % height;
const int n = idx / width / height;
const int offset = (n * channels * height + h) * width + w;
in += offset;
scale += offset;
size_t head = 0;
size_t pre_pad = (size - 1) / 2;
size_t post_pad = size - pre_pad - 1;
real accum_scale = 0;
// fill the scale at [n, :, h, w]
// accumulate values
while (head < post_pad) {
accum_scale += in[head * step] * in[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_scale += in[head * step] * in[head * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
// both add and subtract
while (head < channels) {
accum_scale += in[head * step] * in[head * step];
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
const int step = height * width;
const int pre_pad = (size - 1) / 2;
const int post_pad = size - pre_pad - 1;
real accum = 0;
int index = 0;
while (index < channels + post_pad) {
if (index < channels) {
accum += in[index * step] * in[index * step];
}
if (index >= size) {
accum -= in[(index - size) * step] * in[(index - size) * step];
}
if (index >= post_pad) {
scale[(index - post_pad) * step] = 1. + accum * alpha;
}
++index;
}
}
}
__global__ void KeCMRNormOutput(size_t nthreads, const real* in,
const real* scale, real negative_beta,
real* out) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
__global__ void KeCMRNormOutput(size_t inputSize, const real* in,
const real* scale, real negative_beta,
real* out) {
const int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < inputSize) {
out[index] = in[index] * pow(scale[index], negative_beta);
}
}
......@@ -440,84 +428,60 @@ void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale,
real* out, size_t channels,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
KeCMRNormFillScale<<<grid, threads, 0, STREAM_DEFAULT>>>
(threadsNum, in, scale, channels, height, width, sizeX, alpha);
threadsNum = frameCnt * height * width *channels;
blocksX = (threadsNum + 1024 -1) / 1024;
dim3 threads2(1024, 1);
dim3 grid2(blocksX, blocksY);
KeCMRNormOutput<<<grid2, threads2, 0, STREAM_DEFAULT>>>
(threadsNum, in, scale, beta, out);
size_t imageSize = frameCnt * height * width;
int blockSize = 1024;
int gridSize = (imageSize + 1024 - 1) / 1024;
KeCMRNormFillScale<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(imageSize, in, scale, channels, height, width, sizeX, alpha);
size_t inputSize = frameCnt * height * width *channels;
blockSize = 1024;
gridSize = (inputSize + 1024 - 1) / 1024;
KeCMRNormOutput<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(inputSize, in, scale, beta, out);
CHECK_SYNC("hl_CMRNorm_forward");
}
__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data,
__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
const real* top_data, const real* scale,
const real* top_diff, size_t channels,
size_t height, size_t width, size_t size,
real negative_beta, real cache_ratio,
real* bottom_diff ) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
// find out the local offset
size_t w = index % width;
size_t h = (index / width) % height;
size_t n = index / width / height;
size_t offset = (n * channels * height + h) * width + w;
size_t step = height * width;
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < imageSize) {
const int w = idx % width;
const int h = (idx / width) % height;
const int n = idx / width / height;
const int offset = (n * channels * height + h) * width + w;
bottom_data += offset;
top_data += offset;
scale += offset;
top_diff += offset;
bottom_diff += offset;
int head = 0;
int pre_pad = size - (size + 1) / 2;
int post_pad = size - pre_pad - 1;
real accum_ratio = 0;
// accumulate values
while (head < post_pad) {
accum_ratio += top_diff[head * step] *
top_data[head * step] / scale[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_ratio += top_diff[head * step] *
top_data[head * step] / scale[head * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// both add and subtract
while (head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
const int step = height * width;
const int pre_pad = size - (size + 1) / 2;
const int post_pad = size - pre_pad - 1;
int index = 0;
real accum = 0;
while (index < channels + post_pad) {
if (index < channels) {
accum += top_diff[index * step] * top_data[index * step] /
scale[index * step];
}
if (index >= size) {
accum -= top_diff[(index - size) * step] *
top_data[(index - size) * step] / scale[(index - size) * step];
}
if (index >= post_pad) {
bottom_diff[(index - post_pad) * step] +=
top_diff[(index - post_pad) * step] *
pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(index - post_pad) * step] * accum;
}
++index;
}
}
}
......@@ -528,14 +492,12 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
real *inDiff, size_t channels,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
KeCMRNormDiff <<<grid, threads, 0, STREAM_DEFAULT>>>
(threadsNum, inV, outV, scale, outDiff, channels,
height, width, sizeX, alpha, beta, inDiff);
size_t imageSize = frameCnt * height * width;
int blockSize = 1024;
int gridSize = (imageSize + 1024 - 1) / 1024;
KeCMRNormDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(imageSize, inV, outV, scale, outDiff, channels,
height, width, sizeX, alpha, beta, inDiff);
CHECK_SYNC("hl_CMRNorm_backward");
}
......
......@@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) {
testLayerGrad(config, "norm", 100, trans, useGpu);
}
#ifndef PADDLE_ONLY_CPU
TEST(Layer, NormLayer) {
testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true);
testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false);
}
#endif
void setPoolConfig(TestConfig* config,
PoolConfig* pool,
......
......@@ -2227,52 +2227,43 @@ void CpuMatrix::crossMapNormalFwd(Matrix& input,
size_t sizeX,
float scale,
float pow) {
size_t num = input.getHeight();
CHECK(isContiguous());
CHECK(input.isContiguous());
CHECK(denoms.isContiguous());
CHECK_EQ(getHeight(), input.getHeight());
CHECK_EQ(getWidth(), input.getWidth());
CHECK_EQ(getHeight(), denoms.getHeight());
CHECK_EQ(getWidth(), denoms.getWidth());
size_t numSample = input.getHeight();
size_t numCols = input.getWidth();
size_t height = imgSizeH;
size_t width = imgSizeW;
size_t numCols = input.getWidth();
CHECK(height * width * channels == input.getWidth());
CHECK(denoms.getHeight() == input.getHeight() &&
denoms.getWidth() == input.getWidth() && input.getHeight() == height_ &&
input.getWidth() == width_);
real* imgData = input.getData();
real* diffData = input.getData();
real* targetData = getData();
size_t halfSize = sizeX / 2;
size_t imgPixels = height * width;
// use integral vector to implement the sum in local window
real* integralData =
(real*)malloc((channels + sizeX + 1) * sizeof(real)); // NOLINT // TODO:
for (size_t i = 0; i <= halfSize; i++) {
integralData[i] = 0;
}
for (size_t i = 0; i < num; i++) {
real* targetPtr = targetData + i * numCols;
real* imgPtr = imgData + i * numCols;
real* diffPtr = diffData + i * numCols;
for (size_t m = 0; m < height; m++) {
for (size_t n = 0; n < width; n++) {
for (size_t c = 0; c < channels; c++) {
integralData[c + halfSize + 1] =
integralData[c + halfSize] + _square(*(diffPtr + c * imgPixels));
}
for (size_t k = channels + halfSize + 1; k <= channels + sizeX; k++) {
integralData[k] = integralData[channels + halfSize];
CHECK(height * width * channels == numCols);
// TODO(hedaoyuan) After commit TensorExpress code,
// Reconstruction this code to remove the temporary memory.
CpuMatrix tmp(channels, height * width);
CpuMatrix tmp2(tmp.getData(), 1, channels * height * width);
denoms.zero();
const int start = -((int)sizeX - 1) / 2;
const int end = (int)sizeX + start;
for (size_t i = 0; i < numSample; i++) {
input.subMatrix(i, 1)->square2(tmp2);
CpuMatrix subDen(
denoms.subMatrix(i, 1)->getData(), channels, height * width);
for (int c = 0; c < (int)channels; c++) {
for (int s = start; s < end; s++) {
if (c + s >= 0 && c + s < (int)channels) {
subDen.subMatrix(c, 1)->add(*tmp.subMatrix(c + s, 1));
}
for (size_t k = 0; k < channels; k += 1) {
real a = integralData[k + sizeX] - integralData[k];
a = scale * a + 1;
targetPtr[k * imgPixels] = imgPtr[k * imgPixels] * _pow(a, -pow);
}
diffPtr++;
targetPtr++;
imgPtr++;
}
}
}
free(integralData);
integralData = NULL;
denoms.add(scale, (real)1);
this->pow2(denoms, -pow);
this->dotMul(input);
}
void CpuMatrix::crossMapNormalBwd(Matrix& localGrad,
......@@ -2282,19 +2273,63 @@ void CpuMatrix::crossMapNormalBwd(Matrix& localGrad,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t size,
size_t sizeX,
float scale,
float pow) {
LOG(FATAL) << "Not implemented";
CHECK(imgSizeH * imgSizeW * channels == preOutV.getWidth());
CHECK(denoms.getHeight() == preOutV.getHeight() &&
denoms.getWidth() == preOutV.getWidth() &&
preOutV.getHeight() == height_ && preOutV.getWidth() == width_);
CHECK(denoms.getHeight() == localGrad.getHeight() &&
denoms.getWidth() == localGrad.getWidth());
// NOLINT // TODO:
CHECK(isContiguous());
CHECK(localGrad.isContiguous());
CHECK(denoms.isContiguous());
CHECK(preOutV.isContiguous());
CHECK(localOutV.isContiguous());
CHECK_EQ(getHeight(), localGrad.getHeight());
CHECK_EQ(getWidth(), localGrad.getWidth());
CHECK_EQ(getHeight(), denoms.getHeight());
CHECK_EQ(getWidth(), denoms.getWidth());
CHECK_EQ(getHeight(), preOutV.getHeight());
CHECK_EQ(getWidth(), preOutV.getWidth());
CHECK_EQ(getHeight(), localOutV.getHeight());
CHECK_EQ(getWidth(), localOutV.getWidth());
size_t numSample = getHeight();
size_t numCols = getWidth();
size_t height = imgSizeH;
size_t width = imgSizeW;
CHECK(height * width * channels == numCols);
// TODO(hedaoyuan) After commit TensorExpress code,
// Reconstruction this code to remove the temporary memory.
CpuMatrix tmp(1, height * width);
const int start = -((int)sizeX) / 2;
const int end = (int)sizeX + start;
const real ratio = -(real)2 * scale * pow;
for (size_t i = 0; i < numSample; i++) {
CpuMatrix inputDiff(
this->subMatrix(i, 1)->getData(), channels, height * width);
CpuMatrix outDiff(
localGrad.subMatrix(i, 1)->getData(), channels, height * width);
CpuMatrix input(
preOutV.subMatrix(i, 1)->getData(), channels, height * width);
CpuMatrix output(
localOutV.subMatrix(i, 1)->getData(), channels, height * width);
CpuMatrix subDen(
denoms.subMatrix(i, 1)->getData(), channels, height * width);
for (int c = 0; c < (int)channels; c++) {
tmp.pow2(*subDen.subMatrix(c, 1), -pow);
inputDiff.subMatrix(c, 1)
->addDotMul(tmp, *outDiff.subMatrix(c, 1), (real)1, (real)1);
for (int s = start; s < end; s++) {
if (c + s >= 0 && c + s < (int)channels) {
tmp.dotMul(*outDiff.subMatrix(c + s, 1), *output.subMatrix(c + s, 1));
tmp.mulScalar(ratio);
tmp.dotDiv(tmp, *subDen.subMatrix(c + s, 1));
tmp.dotMul(*input.subMatrix(c, 1));
inputDiff.subMatrix(c, 1)->add(tmp);
}
}
}
}
}
/**
......
......@@ -1261,6 +1261,121 @@ TEST(Matrix, MaxOutFwdBwd) {
}
}
}
void testCrossMapNormalFwd(
int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) {
float scale = 1.5;
float pow = 0.5;
int width = imgSizeH * imgSizeW * channels;
MatrixPtr input = CpuMatrix::create(numSamples, width, false, false);
MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false);
MatrixPtr target = CpuMatrix::create(numSamples, width, false, false);
MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true);
MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true);
MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true);
input->randomizeUniform();
target->randomizeUniform();
inputGpu->copyFrom(*input);
targetGpu->copyFrom(*target);
target->crossMapNormalFwd(
*input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow);
targetGpu->crossMapNormalFwd(
*inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow);
TensorCheckErr(*target, *targetGpu);
TensorCheckErr(*denorms, *denormsGpu);
}
TEST(Matrix, crossMapNormalFwd) {
for (auto numSamples : {5, 32}) {
for (auto channels : {1, 5, 32}) {
for (auto imgSizeH : {5, 33, 100}) {
for (auto imgSizeW : {5, 32, 96}) {
for (auto sizeX : {1, 2, 3, 5, 7}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
<< " sizeX=" << sizeX;
testCrossMapNormalFwd(
numSamples, channels, imgSizeH, imgSizeW, sizeX);
}
}
}
}
}
}
void testCrossMapNormalBwd(
int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) {
float scale = 1.5;
float pow = 0.5;
size_t width = imgSizeH * imgSizeW * channels;
MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false);
MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false);
MatrixPtr output = CpuMatrix::create(numSamples, width, false, false);
MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false);
MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false);
localGrad->randomizeUniform();
denoms->randomizeUniform();
preOutV->randomizeUniform();
localOutV->randomizeUniform();
output->randomizeUniform();
denoms->add(0.01);
MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true);
MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true);
MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true);
MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true);
MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true);
localGradGpu->copyFrom(*localGrad);
denomsGpu->copyFrom(*denoms);
preOutVGpu->copyFrom(*preOutV);
localOutVGpu->copyFrom(*localOutV);
outputGpu->copyFrom(*output);
output->crossMapNormalBwd(*localGrad,
*denoms,
*preOutV,
*localOutV,
channels,
imgSizeH,
imgSizeW,
sizeX,
scale,
pow);
outputGpu->crossMapNormalBwd(*localGradGpu,
*denomsGpu,
*preOutVGpu,
*localOutVGpu,
channels,
imgSizeH,
imgSizeW,
sizeX,
scale,
pow);
TensorCheckErr(*output, *outputGpu);
}
TEST(Matrix, crossMapNormalBwd) {
for (auto numSamples : {5, 32}) {
for (auto channels : {1, 5, 32}) {
for (auto imgSizeH : {5, 33, 100}) {
for (auto imgSizeW : {5, 32, 96}) {
for (auto sizeX : {1, 2, 3, 5, 7}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
<< " sizeX=" << sizeX;
testCrossMapNormalBwd(
numSamples, channels, imgSizeH, imgSizeW, sizeX);
}
}
}
}
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册