提交 42e12179 编写于 作者: T tianbingsz 提交者: GitHub

Merge pull request #854 from hedaoyuan/cmrnorm

Cmrnorm
......@@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME)
target_circle_link_libraries(${TARGET_NAME}
ARCHIVE_START
paddle_gserver
paddle_function
${METRIC_LIBS}
ARCHIVE_END
paddle_pserver
......@@ -106,6 +107,7 @@ function(link_paddle_exe TARGET_NAME)
paddle_parameter
paddle_proto
paddle_cuda
paddle_test_main
${METRIC_LIBS}
${PROTOBUF_LIBRARY}
${LIBGLOG_LIBRARY}
......
add_subdirectory(cuda)
add_subdirectory(function)
add_subdirectory(utils)
add_subdirectory(math)
add_subdirectory(parameter)
......
......@@ -46,6 +46,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
WORKING_DIRECTORY ${PROJ_ROOT}/paddle
DEPENDS python_swig_sources
paddle_parameter
paddle_function
paddle_math
paddle_utils
paddle_gserver
......
......@@ -30,8 +30,8 @@ try:
whole_end = ""
LIB_DIRS = [
"math", 'utils', 'parameter', "gserver", "api", "cuda", "pserver",
"trainer"
"math", 'function', 'utils', 'parameter', "gserver", "api", "cuda",
"pserver", "trainer"
]
PARENT_LIB_DIRS = ['proto']
......@@ -75,6 +75,7 @@ try:
libs = [
whole_start,
"-lpaddle_gserver",
"-lpaddle_function",
whole_end,
"-lpaddle_pserver",
"-lpaddle_trainer_lib",
......
......@@ -240,62 +240,6 @@ extern void hl_avgpool_backward(const int frameCnt,
real* backGrad,
const int outStride);
/**
* @brief Cross-map-respose normalize forward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] in input data.
* @param[in] scale buffer.
* @param[out] out output data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] sizeX size.
* @param[in] alpha scale.
* @param[in] beta scale.
*
*/
extern void hl_CMRNorm_forward(size_t frameCnt,
const real* in,
real* scale,
real* out,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta);
/**
* @brief Cross-map-respose normalize backward.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inV input data.
* @param[in] scale buffer.
* @param[out] outV output value.
* @param[out] outDiff output grad.
* @param[out] inDiff input grad.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] sizeX size.
* @param[in] alpha scale.
* @param[in] beta scale.
*
*/
extern void hl_CMRNorm_backward(size_t frameCnt,
const real* inV,
const real* scale,
const real* outV,
const real* outDiff,
real* inDiff,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta);
/**
* @brief Bilinear interpolation forward.
*
......
......@@ -117,30 +117,6 @@ inline void hl_avgpool_backward(const int frameCnt,
real* backGrad,
const int outStride) {}
inline void hl_CMRNorm_forward(size_t frameCnt,
const real* in,
real* scale,
real* out,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta) {}
inline void hl_CMRNorm_backward(size_t frameCnt,
const real* inV,
const real* scale,
const real* outV,
const real* outDiff,
real* inDiff,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta) {}
inline void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
......
......@@ -381,164 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
CHECK_SYNC("hl_avgpool_backward failed");
}
__global__ void KeCMRNormFillScale(size_t nthreads, const real* in,
real* scale, size_t channels,
size_t height, size_t width, size_t size,
real alpha) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
// find out the local offset
size_t w = index % width;
size_t h = (index / width) % height;
size_t n = index / width / height;
size_t offset = (n * channels * height + h) * width + w;
size_t step = height * width;
in += offset;
scale += offset;
size_t head = 0;
size_t pre_pad = (size - 1) / 2;
size_t post_pad = size - pre_pad - 1;
real accum_scale = 0;
// fill the scale at [n, :, h, w]
// accumulate values
while (head < post_pad) {
accum_scale += in[head * step] * in[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_scale += in[head * step] * in[head * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
// both add and subtract
while (head < channels) {
accum_scale += in[head * step] * in[head * step];
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
}
}
__global__ void KeCMRNormOutput(size_t nthreads, const real* in,
const real* scale, real negative_beta,
real* out) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
out[index] = in[index] * pow(scale[index], negative_beta);
}
}
void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale,
real* out, size_t channels,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
KeCMRNormFillScale<<<grid, threads, 0, STREAM_DEFAULT>>>
(threadsNum, in, scale, channels, height, width, sizeX, alpha);
threadsNum = frameCnt * height * width *channels;
blocksX = (threadsNum + 1024 -1) / 1024;
dim3 threads2(1024, 1);
dim3 grid2(blocksX, blocksY);
KeCMRNormOutput<<<grid2, threads2, 0, STREAM_DEFAULT>>>
(threadsNum, in, scale, beta, out);
CHECK_SYNC("hl_CMRNorm_forward");
}
__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data,
const real* top_data, const real* scale,
const real* top_diff, size_t channels,
size_t height, size_t width, size_t size,
real negative_beta, real cache_ratio,
real* bottom_diff ) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
// find out the local offset
size_t w = index % width;
size_t h = (index / width) % height;
size_t n = index / width / height;
size_t offset = (n * channels * height + h) * width + w;
size_t step = height * width;
bottom_data += offset;
top_data += offset;
scale += offset;
top_diff += offset;
bottom_diff += offset;
int head = 0;
int pre_pad = size - (size + 1) / 2;
int post_pad = size - pre_pad - 1;
real accum_ratio = 0;
// accumulate values
while (head < post_pad) {
accum_ratio += top_diff[head * step] *
top_data[head * step] / scale[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_ratio += top_diff[head * step] *
top_data[head * step] / scale[head * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// both add and subtract
while (head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
}
}
void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
const real* scale,
const real* outV, const real* outDiff,
real *inDiff, size_t channels,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
KeCMRNormDiff <<<grid, threads, 0, STREAM_DEFAULT>>>
(threadsNum, inV, outV, scale, outDiff, channels,
height, width, sizeX, alpha, beta, inDiff);
CHECK_SYNC("hl_CMRNorm_backward");
}
__global__ void KeBilinearInterpFw(const real* in,
const size_t inImgH,
const size_t inImgW,
......
file(GLOB h_files . *_op.h)
file(GLOB cpp_files . *_op.cpp)
list(APPEND h_files Function.h)
list(APPEND cpp_files Function.cpp)
if(WITH_GPU)
file(GLOB cu_files . *_op_gpu.cu)
cuda_compile(cu_objs ${cu_files})
endif()
add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
add_library(paddle_test_main STATIC TestMain.cpp)
if(WITH_GPU)
# TODO:
# file(GLOB test_files . *_op_test.cpp)
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
add_simple_unittest(cross_map_normal_op_test)
endif()
add_style_check_target(paddle_function ${h_files})
add_style_check_target(paddle_function ${cpp_files})
if(WITH_GPU)
add_style_check_target(paddle_function ${cu_files})
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
namespace paddle {
template <>
size_t FuncConfig::get<size_t>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
return it->second.s;
}
template <>
real FuncConfig::get<real>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
return it->second.r;
}
template <>
FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
valueMap_[key].s = v;
return *this;
}
template <>
FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
valueMap_[key].r = v;
return *this;
}
ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <vector>
#include "paddle/math/Matrix.h"
#include "paddle/utils/ClassRegistrar.h"
namespace paddle {
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2,
};
template <DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
typedef std::vector<size_t> Dims;
class Tensor {
public:
Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
real* getData() const { return buf_; }
real* buf_;
Dims dims_;
};
typedef std::vector<Tensor> Arguments;
class FuncConfig {
public:
union value {
size_t s;
real r;
};
template <typename T>
T get(const std::string& key) const;
template <typename T>
FuncConfig& set(const std::string& key, T v);
protected:
std::map<std::string, value> valueMap_;
};
class FunctionBase {
public:
virtual ~FunctionBase() {}
virtual void init(const FuncConfig& config) {}
virtual void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) {}
static ClassRegistrar<FunctionBase> funcRegistrar_;
};
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \
static InitFunction __reg_type_##typeName##deviceName([]() { \
FunctionBase::funcRegistrar_ \
.registerClass<className<DEVICE_TYPE_##deviceName>>( \
FUNC_NAME(typeName, deviceName)); \
})
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
#include "paddle/math/Vector.h"
#include "paddle/math/tests/TensorCheck.h"
namespace paddle {
class FunctionCompare {
public:
FunctionCompare(const std::string& name, const FuncConfig& config)
: cpu(FunctionBase::funcRegistrar_.createByType(name + "-CPU")),
gpu(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) {
cpu->init(config);
gpu->init(config);
}
void cmpWithArg(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) {
// init cpu and gpu arguments
auto initArgs = [=](
Arguments& cpuArgs, Arguments& gpuArgs, const Arguments& inArgs) {
for (auto arg : inArgs) {
size_t size = sizeof(real);
for (auto dim : arg.dims_) {
size *= dim;
}
cpuMemory.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory.emplace_back(std::make_shared<GpuMemoryHandle>(size));
cpuArgs.emplace_back(
Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_));
gpuArgs.emplace_back(
Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_));
// will use an api to refactor this code.
CpuVector cpuVector(size / sizeof(real),
(real*)cpuArgs.back().getData());
GpuVector gpuVector(size / sizeof(real),
(real*)gpuArgs.back().getData());
cpuVector.uniform(0.001, 1);
gpuVector.copyFrom(cpuVector);
}
};
initArgs(cpuInputs, gpuInputs, inputs);
initArgs(cpuOutputs, gpuOutputs, outputs);
initArgs(cpuInouts, gpuInouts, inouts);
// function calculate
cpu->calc(cpuInputs, cpuOutputs, cpuInouts);
gpu->calc(gpuInputs, gpuOutputs, gpuInouts);
// check outputs and inouts
auto checkArgs = [=](const Arguments& cpuArgs, const Arguments& gpuArgs) {
for (size_t i = 0; i < cpuArgs.size(); i++) {
auto cpu = cpuArgs[i];
auto gpu = gpuArgs[i];
size_t size = 1;
for (auto dim : cpu.dims_) {
size *= dim;
}
CpuVector cpuVector(size, (real*)cpu.getData());
GpuVector gpuVector(size, (real*)gpu.getData());
autotest::TensorCheckErr(cpuVector, gpuVector);
}
};
checkArgs(cpuOutputs, gpuOutputs);
checkArgs(cpuInouts, gpuInouts);
}
protected:
std::shared_ptr<FunctionBase> cpu;
std::shared_ptr<FunctionBase> gpu;
std::vector<CpuMemHandlePtr> cpuMemory;
std::vector<GpuMemHandlePtr> gpuMemory;
Arguments cpuInputs;
Arguments cpuOutputs;
Arguments cpuInouts;
Arguments gpuInputs;
Arguments gpuOutputs;
Arguments gpuInouts;
};
} // namespace paddle
using paddle::FunctionCompare;
using paddle::FuncConfig;
using paddle::Dims;
using paddle::Tensor;
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/utils/Util.h"
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
paddle::initMain(argc, argv);
return RUN_ALL_TESTS();
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cross_map_normal_op.h"
#include "paddle/math/Vector.h"
namespace paddle {
template <>
void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
real* denoms,
const real* inputs,
size_t numSamples,
size_t channels,
size_t height,
size_t width,
size_t size,
real scale,
real pow) {
size_t oneImage = height * width;
size_t oneSample = channels * oneImage;
CpuVector outputsV(numSamples * oneSample, outputs);
CpuVector inputsV(numSamples * oneSample, const_cast<real*>(inputs));
CpuVector denomsV(numSamples * oneSample, denoms);
// f(x) = x * ( 1 + scale * SUM((x)^2) )^(-pow)
// x represents inputs
// f(x) represents outputs
// denoms save the intermediate result for backward
denomsV = denomsV.constant(1.0);
const int start = -((int)size - 1) / 2;
const int end = (int)size + start;
for (size_t i = 0; i < numSamples; i++) {
real* oneDenom = denoms + i * oneSample;
real* oneInput = const_cast<real*>(inputs) + i * oneSample;
for (int c = 0; c < (int)channels; c++) {
CpuVector denom(oneImage, oneDenom + c * oneImage);
for (int s = start; s < end; s++) {
if (c + s >= 0 && c + s < (int)channels) {
CpuVector input(oneImage, oneInput + (c + s) * oneImage);
denom += input.square() * scale;
}
}
}
}
outputsV = inputsV * denomsV.pow(-pow);
}
template <>
void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
const real* inputsValue,
const real* outputsValue,
const real* outputsGrad,
const real* denoms,
size_t numSamples,
size_t channels,
size_t height,
size_t width,
size_t size,
real scale,
real pow) {
size_t oneSample = channels * height * width;
std::function<CpuVector(real*, size_t)> oneImage = [=](real* data,
size_t offset) {
return CpuVector(height * width, data + offset);
};
const int start = -((int)size) / 2;
const int end = (int)size + start;
const real ratio = -(real)2 * scale * pow;
for (size_t i = 0; i < numSamples; i++) {
size_t sOffset = i * oneSample;
real* oneInputGrad = inputsGrad + sOffset;
real* oneInputValue = const_cast<real*>(inputsValue) + sOffset;
real* oneDenom = const_cast<real*>(denoms) + sOffset;
real* oneOutputGrad = const_cast<real*>(outputsGrad) + sOffset;
real* oneOutputValue = const_cast<real*>(outputsValue) + sOffset;
for (int c = 0; c < (int)channels; c++) {
size_t cOffset = c * height * width;
CpuVector inputGrad = oneImage(oneInputGrad, cOffset);
CpuVector inputValue = oneImage(oneInputValue, cOffset);
CpuVector denom = oneImage(oneDenom, cOffset);
CpuVector outputGrad = oneImage(oneOutputGrad, cOffset);
inputGrad = inputGrad + denom.pow(-pow) * outputGrad;
for (int s = start; s < end; s++) {
if (c + s >= 0 && c + s < (int)channels) {
size_t offset = (c + s) * height * width;
CpuVector output = oneImage(oneOutputValue, offset);
CpuVector outputGrad = oneImage(oneOutputGrad, offset);
CpuVector denom = oneImage(oneDenom, offset);
inputGrad += ((outputGrad * output * ratio) / denom) * inputValue;
}
}
}
}
}
/**
* \param inputs[0] input value.
* \param outputs[0] output value.
* \param outputs[1] denoms.
*/
template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(1, inputs.size());
CHECK_EQ(2, outputs.size());
CHECK_EQ(0, inouts.size());
CHECK_EQ(inputs[0].dims_.size(), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
}
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
CrossMapNormal<Device>(outputs[0].getData(),
outputs[1].getData(),
inputs[0].getData(),
samples,
channels,
height,
width,
size_,
scale_,
pow_);
}
private:
size_t size_;
real scale_;
real pow_;
};
/**
* \param inputs[0] input value.
* \param inputs[1] output value.
* \param inputs[2] output grad.
* \param inputs[3] denoms.
* \param outputs[0] input grad.
*/
template <DeviceType Device>
class CrossMapNormalGradFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow");
}
void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) override {
CHECK_EQ(4, inputs.size());
CHECK_EQ(1, outputs.size());
CHECK_EQ(0, inouts.size());
CHECK_EQ(inputs[0].dims_.size(), 4);
for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
}
size_t samples = inputs[0].dims_[0];
size_t channels = inputs[0].dims_[1];
size_t height = inputs[0].dims_[2];
size_t width = inputs[0].dims_[3];
CrossMapNormalGrad<Device>(outputs[0].getData(),
inputs[0].getData(),
inputs[1].getData(),
inputs[2].getData(),
inputs[3].getData(),
samples,
channels,
height,
width,
size_,
scale_,
pow_);
}
private:
size_t size_;
real scale_;
real pow_;
};
REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
* \brief Cross map respose normalize forward.
* The data structure of image data is NCHW.
*
* \param[out] outputs output data.
* \param[in] denoms denoms buffer.
* \param[in] inputs input data.
* \param[in] numSamples batch size of input image.
* \param[in] channels number of channel.
* \param[in] height image height.
* \param[in] width image width.
* \param[in] size size.
* \param[in] scale scale.
* \param[in] pow scale.
*
*/
template <DeviceType Device>
void CrossMapNormal(real* outputs,
real* denoms,
const real* inputs,
size_t numSamples,
size_t channels,
size_t height,
size_t width,
size_t size,
real scale,
real pow);
/**
* \brief Cross map respose normalize backward.
* The data structure of image data is NCHW.
*
* \param[out] inputsGrad input grad.
* \param[in] inputsValue input value.
* \param[out] outputsValue output value.
* \param[out] outputsGrad output grad.
* \param[in] denoms denoms buffer.
* \param[in] numSamples batch size of input image.
* \param[in] channels number of channel.
* \param[in] height image height.
* \param[in] width image width.
* \param[in] size size.
* \param[in] scale scale.
* \param[in] pow scale.
*
*/
template <DeviceType Device>
void CrossMapNormalGrad(real* inputsGrad,
const real* inputsValue,
const real* outputsValue,
const real* outputsGrad,
const real* denoms,
size_t numSamples,
size_t channels,
size_t height,
size_t width,
size_t size,
real scale,
real pow);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "cross_map_normal_op.h"
namespace paddle {
__global__ void KeCMRNormFillScale(size_t imageSize, const real* in,
real* scale, size_t channels,
size_t height, size_t width, size_t size,
real alpha) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < imageSize) {
const int w = idx % width;
const int h = (idx / width) % height;
const int n = idx / width / height;
const int offset = (n * channels * height + h) * width + w;
in += offset;
scale += offset;
const int step = height * width;
const int pre_pad = (size - 1) / 2;
const int post_pad = size - pre_pad - 1;
real accum = 0;
int index = 0;
while (index < channels + post_pad) {
if (index < channels) {
accum += in[index * step] * in[index * step];
}
if (index >= size) {
accum -= in[(index - size) * step] * in[(index - size) * step];
}
if (index >= post_pad) {
scale[(index - post_pad) * step] = 1. + accum * alpha;
}
++index;
}
}
}
__global__ void KeCMRNormOutput(size_t inputSize, const real* in,
const real* scale, real negative_beta,
real* out) {
const int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < inputSize) {
out[index] = in[index] * pow(scale[index], negative_beta);
}
}
template <>
void CrossMapNormal<DEVICE_TYPE_GPU>(real* outputs,
real* denoms,
const real* inputs,
size_t numSamples,
size_t channels,
size_t height,
size_t width,
size_t size,
real scale,
real pow) {
size_t imageSize = numSamples * height * width;
int blockSize = 1024;
int gridSize = (imageSize + 1024 - 1) / 1024;
KeCMRNormFillScale<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(imageSize, inputs, denoms, channels, height, width, size, scale);
size_t inputSize = numSamples * height * width *channels;
blockSize = 1024;
gridSize = (inputSize + 1024 - 1) / 1024;
KeCMRNormOutput<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(inputSize, inputs, denoms, -pow, outputs);
CHECK_SYNC("CrossMapNormal");
}
__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
const real* top_data, const real* scale,
const real* top_diff, size_t channels,
size_t height, size_t width, size_t size,
real negative_beta, real cache_ratio,
real* bottom_diff ) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < imageSize) {
const int w = idx % width;
const int h = (idx / width) % height;
const int n = idx / width / height;
const int offset = (n * channels * height + h) * width + w;
bottom_data += offset;
top_data += offset;
scale += offset;
top_diff += offset;
bottom_diff += offset;
const int step = height * width;
const int pre_pad = size - (size + 1) / 2;
const int post_pad = size - pre_pad - 1;
int index = 0;
real accum = 0;
while (index < channels + post_pad) {
if (index < channels) {
accum += top_diff[index * step] * top_data[index * step] /
scale[index * step];
}
if (index >= size) {
accum -= top_diff[(index - size) * step] *
top_data[(index - size) * step] / scale[(index - size) * step];
}
if (index >= post_pad) {
bottom_diff[(index - post_pad) * step] +=
top_diff[(index - post_pad) * step] *
pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(index - post_pad) * step] * accum;
}
++index;
}
}
}
template <>
void CrossMapNormalGrad<DEVICE_TYPE_GPU>(real* inputsGrad,
const real* inputsValue,
const real* outputsValue,
const real* outputsGrad,
const real* denoms,
size_t numSamples,
size_t channels,
size_t height,
size_t width,
size_t size,
real scale,
real pow) {
size_t imageSize = numSamples * height * width;
int blockSize = 1024;
int gridSize = (imageSize + 1024 - 1) / 1024;
KeCMRNormDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
(imageSize, inputsValue, outputsValue, denoms, outputsGrad, channels,
height, width, size, -pow, 2.0f * pow * scale, inputsGrad);
CHECK_SYNC("CrossMapNormalGrad");
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
TEST(CrossMapNormal, real) {
for (size_t numSamples : {5, 32}) {
for (size_t channels : {1, 5, 32}) {
for (size_t imgSizeH : {5, 33, 100}) {
for (size_t imgSizeW : {5, 32, 96}) {
for (size_t size : {1, 2, 3, 5, 7}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
<< " size=" << size;
FunctionCompare compare("CrossMapNormal",
FuncConfig()
.set("size", size)
.set("scale", (real)1.5)
.set("pow", (real)0.5));
Dims dims{numSamples, channels, imgSizeH, imgSizeW};
compare.cmpWithArg({Tensor(nullptr, dims)},
{Tensor(nullptr, dims), Tensor(nullptr, dims)},
{});
}
}
}
}
}
}
TEST(CrossMapNormalGrad, real) {
for (size_t numSamples : {5, 32}) {
for (size_t channels : {1, 5, 32}) {
for (size_t imgSizeH : {5, 33, 100}) {
for (size_t imgSizeW : {5, 32, 96}) {
for (size_t size : {1, 2, 3, 5, 7}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
<< " size=" << size;
FunctionCompare compare("CrossMapNormalGrad",
FuncConfig()
.set("size", size)
.set("scale", (real)1.5)
.set("pow", (real)0.5));
Dims dims{numSamples, channels, imgSizeH, imgSizeW};
compare.cmpWithArg({Tensor(nullptr, dims),
Tensor(nullptr, dims),
Tensor(nullptr, dims),
Tensor(nullptr, dims)},
{Tensor(nullptr, dims)},
{});
}
}
}
}
}
}
......@@ -27,16 +27,12 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM GSERVER_HEADER
layers/CudnnConvLayer.h
layers/CudnnPoolLayer.h
layers/CudnnBatchNormLayer.h
layers/NormProjectionLayer.h
layers/NormLayer.h)
layers/CudnnBatchNormLayer.h)
list(REMOVE_ITEM GSERVER_SOURCES
layers/CudnnConvLayer.cpp
layers/CudnnPoolLayer.cpp
layers/CudnnBatchNormLayer.cpp
layers/NormProjectionLayer.cpp
layers/NormLayer.cpp)
layers/CudnnBatchNormLayer.cpp)
compile_cu_as_cpp(layers/LstmCompute.cu)
compile_cu_as_cpp(layers/GruCompute.cu)
endif()
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <functional>
#include <memory>
#include "ModelConfig.pb.h"
#include "paddle/function/Function.h"
#include "paddle/math/CpuSparseMatrix.h"
#include "paddle/parameter/Parameter.h"
#include "paddle/utils/ClassRegistrar.h"
......@@ -100,6 +101,11 @@ protected:
/// Mark input grad in(true) or out(false) of backward function.
std::vector<bool> markInBackward_;
/// Layer forward function
std::vector<std::shared_ptr<FunctionBase>> forward_;
/// Layer backward function
std::vector<std::shared_ptr<FunctionBase>> backward_;
public:
/**
* Wait until all input value ready.
......@@ -126,6 +132,26 @@ public:
virtual void markAllInputGrad();
protected:
/**
* Create layer function. Function is called in forward or backward.
* \param function, Layer::forward_ or Layer::backward_
* \param name, function name
* \param config, initialization configuration for the function
*/
void createFunction(std::vector<std::shared_ptr<FunctionBase>>& function,
const std::string& name,
const FuncConfig& config) {
if (useGpu_) {
function.emplace_back(
FunctionBase::funcRegistrar_.createByType(name + "-GPU"));
} else {
function.emplace_back(
FunctionBase::funcRegistrar_.createByType(name + "-CPU"));
}
auto& func = function.back();
func->init(config);
}
/**
* Notify specified layer the output grad ready.
* Called in the backward function.
......
......@@ -45,6 +45,15 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
/* the size of inputs for norm-layer is 1 */
CHECK_EQ(config_.inputs_size(), 1);
createFunction(
forward_,
"CrossMapNormal",
FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
createFunction(
backward_,
"CrossMapNormalGrad",
FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
return true;
}
......@@ -54,7 +63,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
/* malloc memory for the output_ if necessary */
/* note: one sample correspond to one row */
MatrixPtr input = inputLayers_[0]->getOutputValue();
int batchSize = input->getHeight();
size_t batchSize = input->getHeight();
int size = getSize();
resetOutput(batchSize, size);
......@@ -62,10 +71,11 @@ void CMRProjectionNormLayer::forward(PassType passType) {
Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
denoms_->zeroMem();
outV->crossMapNormalFwd(
*input, imgSizeH_, imgSizeW_, *denoms_, channels_, size_, scale_, pow_);
dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
forward_[0]->calc(
{Tensor(input->getData(), dims_)},
{Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
{});
}
void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
......@@ -80,15 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
MatrixPtr localOutV = getOutputValue();
MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
preOutGrad->crossMapNormalBwd(*localGrad,
*denoms_,
*preOutV,
*localOutV,
channels_,
imgSizeH_,
imgSizeW_,
size_,
scale_,
pow_);
backward_[0]->calc({Tensor(preOutV->getData(), dims_),
Tensor(localOutV->getData(), dims_),
Tensor(localGrad->getData(), dims_),
Tensor(denoms_->getData(), dims_)},
{Tensor(preOutGrad->getData(), dims_)},
{});
}
} // namespace paddle
......@@ -39,5 +39,8 @@ public:
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
protected:
Dims dims_;
};
} // namespace paddle
......@@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) {
testLayerGrad(config, "norm", 100, trans, useGpu);
}
#ifndef PADDLE_ONLY_CPU
TEST(Layer, NormLayer) {
testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true);
testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false);
}
#endif
void setPoolConfig(TestConfig* config,
PoolConfig* pool,
......
......@@ -1262,69 +1262,6 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
outGrad.getStride());
}
void GpuMatrix::crossMapNormalFwd(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
Matrix& denoms,
size_t channels,
size_t sizeX,
float scale,
float pow) {
size_t num = input.getHeight();
size_t height = imgSizeH;
size_t width = imgSizeW;
CHECK(height * width * channels == input.getWidth());
CHECK(denoms.getHeight() == input.getHeight() &&
denoms.getWidth() == input.getWidth() && input.getHeight() == height_ &&
input.getWidth() == width_);
hl_CMRNorm_forward(num,
input.getData(),
denoms.getData(),
data_,
channels,
height,
width,
sizeX,
scale,
-pow);
}
void GpuMatrix::crossMapNormalBwd(Matrix& localGrad,
Matrix& denoms,
Matrix& preOutV,
Matrix& localOutV,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
float scale,
float pow) {
size_t num = preOutV.getHeight();
size_t height = imgSizeH;
size_t width = imgSizeW;
CHECK(width * height * channels == preOutV.getWidth());
CHECK(denoms.getHeight() == preOutV.getHeight() &&
denoms.getWidth() == preOutV.getWidth() &&
preOutV.getHeight() == height_ && preOutV.getWidth() == width_);
CHECK(denoms.getHeight() == localGrad.getHeight() &&
denoms.getWidth() == localGrad.getWidth());
hl_CMRNorm_backward(num,
preOutV.getData(),
denoms.getData(),
localOutV.getData(),
localGrad.getData(),
data_,
channels,
height,
width,
sizeX,
-pow,
2.0f * pow * scale);
}
void GpuMatrix::maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index) {
......@@ -2192,84 +2129,6 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
}
}
void CpuMatrix::crossMapNormalFwd(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
Matrix& denoms,
size_t channels,
size_t sizeX,
float scale,
float pow) {
size_t num = input.getHeight();
size_t height = imgSizeH;
size_t width = imgSizeW;
size_t numCols = input.getWidth();
CHECK(height * width * channels == input.getWidth());
CHECK(denoms.getHeight() == input.getHeight() &&
denoms.getWidth() == input.getWidth() && input.getHeight() == height_ &&
input.getWidth() == width_);
real* imgData = input.getData();
real* diffData = input.getData();
real* targetData = getData();
size_t halfSize = sizeX / 2;
size_t imgPixels = height * width;
// use integral vector to implement the sum in local window
real* integralData =
(real*)malloc((channels + sizeX + 1) * sizeof(real)); // NOLINT // TODO:
for (size_t i = 0; i <= halfSize; i++) {
integralData[i] = 0;
}
for (size_t i = 0; i < num; i++) {
real* targetPtr = targetData + i * numCols;
real* imgPtr = imgData + i * numCols;
real* diffPtr = diffData + i * numCols;
for (size_t m = 0; m < height; m++) {
for (size_t n = 0; n < width; n++) {
for (size_t c = 0; c < channels; c++) {
integralData[c + halfSize + 1] =
integralData[c + halfSize] + _square(*(diffPtr + c * imgPixels));
}
for (size_t k = channels + halfSize + 1; k <= channels + sizeX; k++) {
integralData[k] = integralData[channels + halfSize];
}
for (size_t k = 0; k < channels; k += 1) {
real a = integralData[k + sizeX] - integralData[k];
a = scale * a + 1;
targetPtr[k * imgPixels] = imgPtr[k * imgPixels] * _pow(a, -pow);
}
diffPtr++;
targetPtr++;
imgPtr++;
}
}
}
free(integralData);
integralData = NULL;
}
void CpuMatrix::crossMapNormalBwd(Matrix& localGrad,
Matrix& denoms,
Matrix& preOutV,
Matrix& localOutV,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t size,
float scale,
float pow) {
LOG(FATAL) << "Not implemented";
CHECK(imgSizeH * imgSizeW * channels == preOutV.getWidth());
CHECK(denoms.getHeight() == preOutV.getHeight() &&
denoms.getWidth() == preOutV.getWidth() &&
preOutV.getHeight() == height_ && preOutV.getWidth() == width_);
CHECK(denoms.getHeight() == localGrad.getHeight() &&
denoms.getWidth() == localGrad.getWidth());
// NOLINT // TODO:
}
/**
* Input: one or more sequences. Each sequence contains some instances.
* Output: output size is the number of input sequences (NOT input instances).
......
......@@ -952,31 +952,6 @@ public:
LOG(FATAL) << "Not implemeted";
}
/// normalize-operation.
virtual void crossMapNormalFwd(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
Matrix& denoms,
size_t channels,
size_t sizeX,
float scale,
float pow) {
LOG(FATAL) << "Not implemeted";
}
virtual void crossMapNormalBwd(Matrix& localGrad,
Matrix& denoms,
Matrix& preOutV,
Matrix& localOutV,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t size,
float scale,
float pow) {
LOG(FATAL) << "Not implemeted";
}
/**
* Input: one or more sequences. Each sequence contains some instances.
*
......@@ -1459,26 +1434,6 @@ public:
size_t paddingH,
size_t paddingW);
void crossMapNormalFwd(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
Matrix& denoms,
size_t channels,
size_t sizeX,
float scale,
float pow);
void crossMapNormalBwd(Matrix& localGrad,
Matrix& denoms,
Matrix& preOutV,
Matrix& localOutV,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
float scale,
float pow);
void maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index);
......@@ -1685,26 +1640,6 @@ public:
size_t paddingH,
size_t paddingW);
void crossMapNormalFwd(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
Matrix& denoms,
size_t channels,
size_t sizeX,
float scale,
float pow);
void crossMapNormalBwd(Matrix& localGrad,
Matrix& denoms,
Matrix& preOutV,
Matrix& localOutV,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
float scale,
float pow);
void maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册