提交 a6f772be 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into batch_norm

...@@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME) ...@@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME)
target_circle_link_libraries(${TARGET_NAME} target_circle_link_libraries(${TARGET_NAME}
ARCHIVE_START ARCHIVE_START
paddle_gserver paddle_gserver
paddle_function
${METRIC_LIBS} ${METRIC_LIBS}
ARCHIVE_END ARCHIVE_END
paddle_pserver paddle_pserver
...@@ -106,6 +107,7 @@ function(link_paddle_exe TARGET_NAME) ...@@ -106,6 +107,7 @@ function(link_paddle_exe TARGET_NAME)
paddle_parameter paddle_parameter
paddle_proto paddle_proto
paddle_cuda paddle_cuda
paddle_test_main
${METRIC_LIBS} ${METRIC_LIBS}
${PROTOBUF_LIBRARY} ${PROTOBUF_LIBRARY}
${LIBGLOG_LIBRARY} ${LIBGLOG_LIBRARY}
......
...@@ -39,12 +39,20 @@ The general development workflow with Docker and Bazel is as follows: ...@@ -39,12 +39,20 @@ The general development workflow with Docker and Bazel is as follows:
code. This image contains all the development tools and code. This image contains all the development tools and
dependencies of PaddlePaddle. dependencies of PaddlePaddle.
.. code-block:: bash .. code-block:: bash
cd paddle cd paddle
docker build -t paddle:dev -f paddle/scripts/docker/Dockerfile . docker build -t paddle:dev -f paddle/scripts/docker/Dockerfile .
Sometimes docker build might suffer from a slow network connection to the official Ubuntu apt-source servers. In such a case, we can specify an apt-source mirror server that is geographically nearer to us. In the following example, we specify an apt-source server that responds quickly in China. You can specify the Ubuntu mirror with :code:`--build-arg UBUNTU_MIRROR`, as in the example below.
.. code-block:: bash
docker build \
--build-arg UBUNTU_MIRROR="http://mirrors.163.com" \
-t paddle:dev \
-f paddle/scripts/docker/Dockerfile .
3. Run the image as a container and mounting local source code 3. Run the image as a container and mounting local source code
directory into the container. This allows us to change the code on directory into the container. This allows us to change the code on
......
add_subdirectory(cuda) add_subdirectory(cuda)
add_subdirectory(function)
add_subdirectory(utils) add_subdirectory(utils)
add_subdirectory(math) add_subdirectory(math)
add_subdirectory(parameter) add_subdirectory(parameter)
......
...@@ -46,6 +46,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp ...@@ -46,6 +46,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
WORKING_DIRECTORY ${PROJ_ROOT}/paddle WORKING_DIRECTORY ${PROJ_ROOT}/paddle
DEPENDS python_swig_sources DEPENDS python_swig_sources
paddle_parameter paddle_parameter
paddle_function
paddle_math paddle_math
paddle_utils paddle_utils
paddle_gserver paddle_gserver
......
...@@ -30,8 +30,8 @@ try: ...@@ -30,8 +30,8 @@ try:
whole_end = "" whole_end = ""
LIB_DIRS = [ LIB_DIRS = [
"math", 'utils', 'parameter', "gserver", "api", "cuda", "pserver", "math", 'function', 'utils', 'parameter', "gserver", "api", "cuda",
"trainer" "pserver", "trainer"
] ]
PARENT_LIB_DIRS = ['proto'] PARENT_LIB_DIRS = ['proto']
...@@ -75,6 +75,7 @@ try: ...@@ -75,6 +75,7 @@ try:
libs = [ libs = [
whole_start, whole_start,
"-lpaddle_gserver", "-lpaddle_gserver",
"-lpaddle_function",
whole_end, whole_end,
"-lpaddle_pserver", "-lpaddle_pserver",
"-lpaddle_trainer_lib", "-lpaddle_trainer_lib",
......
...@@ -240,62 +240,6 @@ extern void hl_avgpool_backward(const int frameCnt, ...@@ -240,62 +240,6 @@ extern void hl_avgpool_backward(const int frameCnt,
real* backGrad, real* backGrad,
const int outStride); const int outStride);
/**
* @brief Cross-map response normalization, forward pass.
*
* @param[in] frameCnt batch size of input image.
* @param[in] in input data.
* @param[out] scale buffer that receives 1 + alpha * (windowed sum of squares).
* @param[out] out output data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] sizeX number of neighbouring channels in the window.
* @param[in] alpha factor applied to the windowed sum of squares.
* @param[in] beta exponent applied to scale; it is used as given, so
*                 callers presumably pass the negated power (confirm).
*
*/
extern void hl_CMRNorm_forward(size_t frameCnt,
const real* in,
real* scale,
real* out,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta);
/**
* @brief Cross-map response normalization, backward pass.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inV input data.
* @param[in] scale buffer produced by the forward pass.
* @param[in] outV output value of the forward pass.
* @param[in] outDiff output grad.
* @param[in,out] inDiff input grad; the result is added to its content.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] sizeX number of neighbouring channels in the window.
* @param[in] alpha forwarded to the kernel as the exponent applied to scale.
* @param[in] beta forwarded to the kernel as the ratio that weights the
*                 accumulated term.
*
*/
extern void hl_CMRNorm_backward(size_t frameCnt,
const real* inV,
const real* scale,
const real* outV,
const real* outDiff,
real* inDiff,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta);
/** /**
* @brief Bilinear interpolation forward. * @brief Bilinear interpolation forward.
* *
......
...@@ -117,30 +117,6 @@ inline void hl_avgpool_backward(const int frameCnt, ...@@ -117,30 +117,6 @@ inline void hl_avgpool_backward(const int frameCnt,
real* backGrad, real* backGrad,
const int outStride) {} const int outStride) {}
// Stub with the same signature as hl_CMRNorm_forward and an empty body: it
// touches none of its arguments.
// NOTE(review): presumably the variant used when GPU support is compiled
// out -- confirm against the enclosing header.
inline void hl_CMRNorm_forward(size_t frameCnt,
const real* in,
real* scale,
real* out,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta) {}
// Stub with the same signature as hl_CMRNorm_backward and an empty body:
// inDiff is left untouched.
// NOTE(review): presumably the variant used when GPU support is compiled
// out -- confirm against the enclosing header.
inline void hl_CMRNorm_backward(size_t frameCnt,
const real* inV,
const real* scale,
const real* outV,
const real* outDiff,
real* inDiff,
size_t channels,
size_t height,
size_t width,
size_t sizeX,
real alpha,
real beta) {}
inline void hl_bilinear_forward(const real* inData, inline void hl_bilinear_forward(const real* inData,
const size_t inImgH, const size_t inImgH,
const size_t inImgW, const size_t inImgW,
......
...@@ -381,164 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, ...@@ -381,164 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
CHECK_SYNC("hl_avgpool_backward failed"); CHECK_SYNC("hl_avgpool_backward failed");
} }
// Fills `scale` with 1 + alpha * (sum of squared inputs over a window of
// `size` neighbouring channels). Each thread with index < nthreads handles
// one (n, h, w) position and walks over all channels of that position,
// keeping a running sum so every input value is added once and removed once.
// Channel planes are `height * width` elements apart.
__global__ void KeCMRNormFillScale(size_t nthreads, const real* in,
real* scale, size_t channels,
size_t height, size_t width, size_t size,
real alpha) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
// find out the local offset
size_t w = index % width;
size_t h = (index / width) % height;
size_t n = index / width / height;
size_t offset = (n * channels * height + h) * width + w;
size_t step = height * width;
in += offset;
scale += offset;
size_t head = 0;
// pre_pad / post_pad: how far the window reaches below / above the channel
// whose scale is being written.
size_t pre_pad = (size - 1) / 2;
size_t post_pad = size - pre_pad - 1;
real accum_scale = 0;
// fill the scale at [n, :, h, w]
// accumulate values
while (head < post_pad) {
accum_scale += in[head * step] * in[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_scale += in[head * step] * in[head * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
// both add and subtract
while (head < channels) {
accum_scale += in[head * step] * in[head * step];
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
scale[(head - post_pad) * step] = 1. + accum_scale * alpha;
++head;
}
}
}
// Element-wise output: out[i] = in[i] * scale[i]^negative_beta for every
// i < nthreads; one thread per element.
__global__ void KeCMRNormOutput(size_t nthreads, const real* in,
const real* scale, real negative_beta,
real* out) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
out[index] = in[index] * pow(scale[index], negative_beta);
}
}
// Forward pass of cross-map response normalization.
// Launches two kernels on STREAM_DEFAULT with 1024 threads per block:
//   1. KeCMRNormFillScale, one thread per (frame, h, w) position, writes
//      `scale`;
//   2. KeCMRNormOutput, one thread per element, writes `out`.
// `beta` is handed to KeCMRNormOutput as its `negative_beta` argument without
// negation.
void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale,
real* out, size_t channels,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
// round up so that every position gets a thread
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
KeCMRNormFillScale<<<grid, threads, 0, STREAM_DEFAULT>>>
(threadsNum, in, scale, channels, height, width, sizeX, alpha);
threadsNum = frameCnt * height * width *channels;
blocksX = (threadsNum + 1024 -1) / 1024;
dim3 threads2(1024, 1);
dim3 grid2(blocksX, blocksY);
KeCMRNormOutput<<<grid2, threads2, 0, STREAM_DEFAULT>>>
(threadsNum, in, scale, beta, out);
CHECK_SYNC("hl_CMRNorm_forward");
}
// Backward kernel of cross-map response normalization.
// Each thread with index < nthreads handles one (n, h, w) position and walks
// over its channels, keeping a running sum of
// top_diff * top_data / scale over a window of `size` channels. For every
// channel it adds
//   top_diff * scale^negative_beta - cache_ratio * bottom_data * accum_ratio
// to bottom_diff (the gradient is accumulated, not overwritten).
__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data,
const real* top_data, const real* scale,
const real* top_diff, size_t channels,
size_t height, size_t width, size_t size,
real negative_beta, real cache_ratio,
real* bottom_diff ) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < nthreads) {
// find out the local offset
size_t w = index % width;
size_t h = (index / width) % height;
size_t n = index / width / height;
size_t offset = (n * channels * height + h) * width + w;
size_t step = height * width;
bottom_data += offset;
top_data += offset;
scale += offset;
top_diff += offset;
bottom_diff += offset;
int head = 0;
// Note: the padding split differs from KeCMRNormFillScale
// (here pre_pad = size - (size + 1) / 2).
int pre_pad = size - (size + 1) / 2;
int post_pad = size - pre_pad - 1;
real accum_ratio = 0;
// accumulate values
while (head < post_pad) {
accum_ratio += top_diff[head * step] *
top_data[head * step] / scale[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_ratio += top_diff[head * step] *
top_data[head * step] / scale[head * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// both add and subtract
while (head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
bottom_diff[(head - post_pad) * step] +=
top_diff[(head - post_pad) * step] *
pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
}
}
// Backward pass of cross-map response normalization.
// Launches KeCMRNormDiff on STREAM_DEFAULT with one thread per
// (frame, h, w) position and 1024 threads per block.
// `alpha` is passed as the kernel's `negative_beta` and `beta` as its
// `cache_ratio`; the kernel adds its result to `inDiff`.
void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
const real* scale,
const real* outV, const real* outDiff,
real *inDiff, size_t channels,
size_t height, size_t width, size_t sizeX,
real alpha, real beta) {
size_t threadsNum = frameCnt * height * width;
// round up so that every position gets a thread
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
size_t blocksY = 1;
dim3 threads(1024, 1);
dim3 grid(blocksX, blocksY);
KeCMRNormDiff <<<grid, threads, 0, STREAM_DEFAULT>>>
(threadsNum, inV, outV, scale, outDiff, channels,
height, width, sizeX, alpha, beta, inDiff);
CHECK_SYNC("hl_CMRNorm_backward");
}
__global__ void KeBilinearInterpFw(const real* in, __global__ void KeBilinearInterpFw(const real* in,
const size_t inImgH, const size_t inImgH,
const size_t inImgW, const size_t inImgW,
......
# Build script of the paddle_function library.
# Operator files are collected by naming convention (*_op.h / *_op.cpp);
# Function.h / Function.cpp are appended explicitly.
file(GLOB h_files . *_op.h)
file(GLOB cpp_files . *_op.cpp)
list(APPEND h_files Function.h)
list(APPEND cpp_files Function.cpp)
# CUDA sources (*_op_gpu.cu) are only compiled for GPU builds.
if(WITH_GPU)
file(GLOB cu_files . *_op_gpu.cu)
cuda_compile(cu_objs ${cu_files})
endif()
# cu_objs is only set inside the WITH_GPU branch above.
add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
# Static library built from TestMain.cpp alone.
add_library(paddle_test_main STATIC TestMain.cpp)
# The unit test is only registered for GPU builds.
if(WITH_GPU)
# TODO:
# file(GLOB test_files . *_op_test.cpp)
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
add_simple_unittest(cross_map_normal_op_test)
endif()
# Style checks for headers, sources and (GPU builds only) CUDA files.
add_style_check_target(paddle_function ${h_files})
add_style_check_target(paddle_function ${cpp_files})
if(WITH_GPU)
add_style_check_target(paddle_function ${cu_files})
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
namespace paddle {
/**
 * Returns the value stored under `key`, read through the `s` (size_t) member
 * of the union. Fails the CHECK when `key` has never been set.
 * The union carries no tag, so this does not verify that the entry was
 * written with set<size_t>.
 */
template <>
size_t FuncConfig::get<size_t>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
return it->second.s;
}
/**
 * Returns the value stored under `key`, read through the `r` (real) member
 * of the union. Fails the CHECK when `key` has never been set.
 * The union carries no tag, so this does not verify that the entry was
 * written with set<real>.
 */
template <>
real FuncConfig::get<real>(const std::string& key) const {
auto it = valueMap_.find(key);
CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
return it->second.r;
}
/**
 * Stores `v` under `key` as a size_t. A key may only be set once: the CHECK
 * fails when `key` is already present.
 * Returns *this so that calls can be chained.
 */
template <>
FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
valueMap_[key].s = v;
return *this;
}
/**
 * Stores `v` under `key` as a real. A key may only be set once: the CHECK
 * fails when `key` is already present.
 * Returns *this so that calls can be chained.
 */
template <>
FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
CHECK(valueMap_.count(key) == 0) << "Duplicated value: " << key;
valueMap_[key].r = v;
return *this;
}
// Definition of the static registrar declared in FunctionBase; functions are
// registered into it under "<type>-<device>" names by REGISTER_TYPED_FUNC.
ClassRegistrar<FunctionBase> FunctionBase::funcRegistrar_;
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <vector>
#include "paddle/math/Matrix.h"
#include "paddle/utils/ClassRegistrar.h"
namespace paddle {
// Device a function implementation targets. The enumerators are used as
// non-type template arguments (see MatrixT and REGISTER_TYPED_FUNC, which
// pastes the device name after the DEVICE_TYPE_ prefix).
enum DeviceType {
DEVICE_TYPE_UNSPECIFIED = 0,
DEVICE_TYPE_CPU = 1,
DEVICE_TYPE_GPU = 2,
};
// Maps a DeviceType to the matrix class of that device through the nested
// `type` alias. Only CPU and GPU are specialized; the primary template is
// declared but not defined, so any other DeviceType fails to compile when
// `type` is used.
template <DeviceType Device>
struct MatrixT;
template <>
struct MatrixT<DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
template <>
struct MatrixT<DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
// Extent of each dimension of a tensor.
typedef std::vector<size_t> Dims;
// Lightweight description of a buffer of `real` values plus its dimensions.
// The pointer is stored as given; the class defines no destructor, so it
// never frees the buffer. Both members are public and are read directly by
// callers in this commit.
class Tensor {
public:
Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {}
// Returns the stored pointer (mutable even through a const Tensor).
real* getData() const { return buf_; }
real* buf_;
Dims dims_;
};
// Argument list passed to FunctionBase::calc.
typedef std::vector<Tensor> Arguments;
// String-keyed configuration of a function.
// Each entry is an untagged union of size_t and real, so a value has to be
// read back with the same type it was written with; nothing records which
// member is active. In this commit get/set are specialized for size_t and
// real in Function.cpp.
class FuncConfig {
public:
union value {
size_t s;
real r;
};
// Returns the value stored under `key`.
template <typename T>
T get(const std::string& key) const;
// Stores `v` under `key` and returns *this so that calls can be chained.
template <typename T>
FuncConfig& set(const std::string& key, T v);
protected:
std::map<std::string, value> valueMap_;
};
// Base class of all registered functions.
// Both virtual hooks have empty default bodies; concrete functions override
// init() to read their configuration and calc() to do the computation.
class FunctionBase {
public:
virtual ~FunctionBase() {}
// Reads the parameters of the function from `config`.
virtual void init(const FuncConfig& config) {}
// Runs the function on the given input, output and in/out arguments.
virtual void calc(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) {}
// Registry used to create functions by their "<type>-<device>" name.
static ClassRegistrar<FunctionBase> funcRegistrar_;
};
// Builds the registry key "<typeName>-<deviceName>" as a string literal.
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
// Defines a static InitFunction whose lambda registers
// className<DEVICE_TYPE_<deviceName>> in FunctionBase::funcRegistrar_ under
// FUNC_NAME(typeName, deviceName).
// NOTE(review): the generated identifier starts with a double underscore,
// which is a name reserved for the implementation.
#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \
static InitFunction __reg_type_##typeName##deviceName([]() { \
FunctionBase::funcRegistrar_ \
.registerClass<className<DEVICE_TYPE_##deviceName>>( \
FUNC_NAME(typeName, deviceName)); \
})
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
#include "paddle/math/Vector.h"
#include "paddle/math/tests/TensorCheck.h"
namespace paddle {
class FunctionCompare {
public:
FunctionCompare(const std::string& name, const FuncConfig& config)
: cpu(FunctionBase::funcRegistrar_.createByType(name + "-CPU")),
gpu(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) {
cpu->init(config);
gpu->init(config);
}
void cmpWithArg(const Arguments& inputs,
const Arguments& outputs,
const Arguments& inouts) {
// init cpu and gpu arguments
auto initArgs = [=](
Arguments& cpuArgs, Arguments& gpuArgs, const Arguments& inArgs) {
for (auto arg : inArgs) {
size_t size = sizeof(real);
for (auto dim : arg.dims_) {
size *= dim;
}
cpuMemory.emplace_back(std::make_shared<CpuMemoryHandle>(size));
gpuMemory.emplace_back(std::make_shared<GpuMemoryHandle>(size));
cpuArgs.emplace_back(
Tensor((real*)cpuMemory.back()->getBuf(), arg.dims_));
gpuArgs.emplace_back(
Tensor((real*)gpuMemory.back()->getBuf(), arg.dims_));
// will use an api to refactor this code.
CpuVector cpuVector(size / sizeof(real),
(real*)cpuArgs.back().getData());
GpuVector gpuVector(size / sizeof(real),
(real*)gpuArgs.back().getData());
cpuVector.uniform(0.001, 1);
gpuVector.copyFrom(cpuVector);
}
};
initArgs(cpuInputs, gpuInputs, inputs);
initArgs(cpuOutputs, gpuOutputs, outputs);
initArgs(cpuInouts, gpuInouts, inouts);
// function calculate
cpu->calc(cpuInputs, cpuOutputs, cpuInouts);
gpu->calc(gpuInputs, gpuOutputs, gpuInouts);
// check outputs and inouts
auto checkArgs = [=](const Arguments& cpuArgs, const Arguments& gpuArgs) {
for (size_t i = 0; i < cpuArgs.size(); i++) {
auto cpu = cpuArgs[i];
auto gpu = gpuArgs[i];
size_t size = 1;
for (auto dim : cpu.dims_) {
size *= dim;
}
CpuVector cpuVector(size, (real*)cpu.getData());
GpuVector gpuVector(size, (real*)gpu.getData());
autotest::TensorCheckErr(cpuVector, gpuVector);
}
};
checkArgs(cpuOutputs, gpuOutputs);
checkArgs(cpuInouts, gpuInouts);
}
protected:
std::shared_ptr<FunctionBase> cpu;
std::shared_ptr<FunctionBase> gpu;
std::vector<CpuMemHandlePtr> cpuMemory;
std::vector<GpuMemHandlePtr> gpuMemory;
Arguments cpuInputs;
Arguments cpuOutputs;
Arguments cpuInouts;
Arguments gpuInputs;
Arguments gpuOutputs;
Arguments gpuInouts;
};
} // namespace paddle
using paddle::FunctionCompare;
using paddle::FuncConfig;
using paddle::Dims;
using paddle::Tensor;
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/utils/Util.h"
/**
 * Test entry point: hands the command line to googletest first, then to
 * paddle::initMain, and returns the result of running all tests.
 */
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  paddle::initMain(argc, argv);
  const int exitCode = RUN_ALL_TESTS();
  return exitCode;
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cross_map_normal_op.h"
#include "paddle/math/Vector.h"
namespace paddle {
/**
 * CPU forward pass of cross-map normalization.
 *
 *   f(x) = x * (1 + scale * SUM(x^2))^(-pow)
 *
 * where the sum runs over a window of `size` neighbouring channels. The
 * value 1 + scale * SUM(x^2) is written to `denoms` so that the backward
 * pass can reuse it.
 */
template <>
void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
                                     real* denoms,
                                     const real* inputs,
                                     size_t numSamples,
                                     size_t channels,
                                     size_t height,
                                     size_t width,
                                     size_t size,
                                     real scale,
                                     real pow) {
  const size_t imageSize = height * width;
  const size_t sampleSize = channels * imageSize;
  const size_t totalSize = numSamples * sampleSize;
  // CpuVector wraps a mutable pointer; the input is only read here.
  real* inputData = const_cast<real*>(inputs);

  CpuVector outputsV(totalSize, outputs);
  CpuVector inputsV(totalSize, inputData);
  CpuVector denomsV(totalSize, denoms);

  // Every denominator starts at 1; the scaled squares are added below.
  denomsV = denomsV.constant(1.0);

  // Offsets [start, end) of the channel window relative to the current
  // channel.
  const int window = static_cast<int>(size);
  const int numChannels = static_cast<int>(channels);
  const int start = -(window - 1) / 2;
  const int end = window + start;

  for (size_t sample = 0; sample < numSamples; sample++) {
    real* sampleDenom = denoms + sample * sampleSize;
    real* sampleInput = inputData + sample * sampleSize;
    for (int c = 0; c < numChannels; c++) {
      CpuVector denom(imageSize, sampleDenom + c * imageSize);
      for (int s = start; s < end; s++) {
        const int neighbor = c + s;
        // Window positions outside the channel range contribute nothing.
        if (neighbor < 0 || neighbor >= numChannels) {
          continue;
        }
        CpuVector input(imageSize, sampleInput + neighbor * imageSize);
        denom += input.square() * scale;
      }
    }
  }

  outputsV = inputsV * denomsV.pow(-pow);
}
/**
 * CPU backward pass of cross-map normalization.
 *
 * For every channel image the gradient is added to the existing content of
 * `inputsGrad`:
 *   - the direct term  denom^(-pow) * outputGrad, and
 *   - for each neighbouring channel in the window,
 *     (outputGrad * output * ratio / denom) * input
 *     with ratio = -2 * scale * pow.
 */
template <>
void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
                                         const real* inputsValue,
                                         const real* outputsValue,
                                         const real* outputsGrad,
                                         const real* denoms,
                                         size_t numSamples,
                                         size_t channels,
                                         size_t height,
                                         size_t width,
                                         size_t size,
                                         real scale,
                                         real pow) {
  const size_t imageSize = height * width;
  const size_t sampleSize = channels * imageSize;

  // View of one channel image starting at data + offset.
  auto imageAt = [imageSize](real* data, size_t offset) {
    return CpuVector(imageSize, data + offset);
  };

  // Offsets [start, end) of the channel window relative to the current
  // channel; note that the split differs from the forward pass.
  const int window = static_cast<int>(size);
  const int numChannels = static_cast<int>(channels);
  const int start = -window / 2;
  const int end = window + start;
  const real ratio = -static_cast<real>(2) * scale * pow;

  for (size_t sample = 0; sample < numSamples; sample++) {
    const size_t sampleOffset = sample * sampleSize;
    // CpuVector wraps mutable pointers; the const inputs are only read.
    real* sampleInputGrad = inputsGrad + sampleOffset;
    real* sampleInputValue = const_cast<real*>(inputsValue) + sampleOffset;
    real* sampleDenom = const_cast<real*>(denoms) + sampleOffset;
    real* sampleOutputGrad = const_cast<real*>(outputsGrad) + sampleOffset;
    real* sampleOutputValue = const_cast<real*>(outputsValue) + sampleOffset;

    for (int c = 0; c < numChannels; c++) {
      const size_t channelOffset = c * imageSize;
      CpuVector inputGrad = imageAt(sampleInputGrad, channelOffset);
      CpuVector inputValue = imageAt(sampleInputValue, channelOffset);
      CpuVector denom = imageAt(sampleDenom, channelOffset);
      CpuVector outputGrad = imageAt(sampleOutputGrad, channelOffset);

      inputGrad = inputGrad + denom.pow(-pow) * outputGrad;
      for (int s = start; s < end; s++) {
        const int neighbor = c + s;
        // Window positions outside the channel range contribute nothing.
        if (neighbor < 0 || neighbor >= numChannels) {
          continue;
        }
        const size_t neighborOffset = neighbor * imageSize;
        CpuVector neighborOutput = imageAt(sampleOutputValue, neighborOffset);
        CpuVector neighborGrad = imageAt(sampleOutputGrad, neighborOffset);
        CpuVector neighborDenom = imageAt(sampleDenom, neighborOffset);

        inputGrad +=
            ((neighborGrad * neighborOutput * ratio) / neighborDenom) *
            inputValue;
      }
    }
  }
}
/**
 * Function wrapper around CrossMapNormal<Device>.
 *
 * Configuration keys read by init(): "size" (size_t), "scale" (real) and
 * "pow" (real).
 *
 * \param inputs[0] input value.
 * \param outputs[0] output value.
 * \param outputs[1] denoms.
 *
 * All three tensors must have the same four dimensions, which are read as
 * (samples, channels, height, width).
 */
template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    size_ = config.get<size_t>("size");
    scale_ = config.get<real>("scale");
    pow_ = config.get<real>("pow");
  }

  void calc(const Arguments& inputs,
            const Arguments& outputs,
            const Arguments& inouts) override {
    CHECK_EQ(1, inputs.size());
    CHECK_EQ(2, outputs.size());
    CHECK_EQ(0, inouts.size());

    CHECK_EQ(inputs[0].dims_.size(), 4);
    // Compare the ranks first: the per-dimension loop below indexes the
    // output dims with the input's rank and would otherwise read past the
    // end of a shorter dims vector.
    CHECK_EQ(inputs[0].dims_.size(), outputs[0].dims_.size());
    CHECK_EQ(inputs[0].dims_.size(), outputs[1].dims_.size());
    for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
      CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
      CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]);
    }

    size_t samples = inputs[0].dims_[0];
    size_t channels = inputs[0].dims_[1];
    size_t height = inputs[0].dims_[2];
    size_t width = inputs[0].dims_[3];

    CrossMapNormal<Device>(outputs[0].getData(),
                           outputs[1].getData(),
                           inputs[0].getData(),
                           samples,
                           channels,
                           height,
                           width,
                           size_,
                           scale_,
                           pow_);
  }

private:
  // Zero-initialized so that calc() before init() reads defined values.
  size_t size_ = 0;
  real scale_ = 0;
  real pow_ = 0;
};
/**
 * Function wrapper around CrossMapNormalGrad<Device>.
 *
 * Configuration keys read by init(): "size" (size_t), "scale" (real) and
 * "pow" (real).
 *
 * \param inputs[0] input value.
 * \param inputs[1] output value.
 * \param inputs[2] output grad.
 * \param inputs[3] denoms.
 * \param outputs[0] input grad. Both kernels in this commit add the
 *                   gradient to the existing content of this buffer.
 *
 * All five tensors must have the same four dimensions, which are read as
 * (samples, channels, height, width).
 */
template <DeviceType Device>
class CrossMapNormalGradFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    size_ = config.get<size_t>("size");
    scale_ = config.get<real>("scale");
    pow_ = config.get<real>("pow");
  }

  void calc(const Arguments& inputs,
            const Arguments& outputs,
            const Arguments& inouts) override {
    CHECK_EQ(4, inputs.size());
    CHECK_EQ(1, outputs.size());
    CHECK_EQ(0, inouts.size());

    CHECK_EQ(inputs[0].dims_.size(), 4);
    // Compare the ranks first: the per-dimension loop below indexes every
    // dims vector with the rank of inputs[0] and would otherwise read past
    // the end of a shorter one.
    CHECK_EQ(inputs[0].dims_.size(), inputs[1].dims_.size());
    CHECK_EQ(inputs[0].dims_.size(), inputs[2].dims_.size());
    CHECK_EQ(inputs[0].dims_.size(), inputs[3].dims_.size());
    CHECK_EQ(inputs[0].dims_.size(), outputs[0].dims_.size());
    for (size_t i = 0; i < inputs[0].dims_.size(); i++) {
      CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]);
      CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]);
      CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]);
      CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]);
    }

    size_t samples = inputs[0].dims_[0];
    size_t channels = inputs[0].dims_[1];
    size_t height = inputs[0].dims_[2];
    size_t width = inputs[0].dims_[3];

    CrossMapNormalGrad<Device>(outputs[0].getData(),
                               inputs[0].getData(),
                               inputs[1].getData(),
                               inputs[2].getData(),
                               inputs[3].getData(),
                               samples,
                               channels,
                               height,
                               width,
                               size_,
                               scale_,
                               pow_);
  }

private:
  // Zero-initialized so that calc() before init() reads defined values.
  size_t size_ = 0;
  real scale_ = 0;
  real pow_ = 0;
};
// Register the forward and backward functions as "CrossMapNormal-CPU" and
// "CrossMapNormalGrad-CPU".
REGISTER_TYPED_FUNC(CrossMapNormal, CPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, CPU, CrossMapNormalGradFunc);
// The "-GPU" variants are registered unless PADDLE_ONLY_CPU is defined.
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(CrossMapNormal, GPU, CrossMapNormalFunc);
REGISTER_TYPED_FUNC(CrossMapNormalGrad, GPU, CrossMapNormalGradFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
* \brief Cross map response normalize forward.
* The data structure of image data is NCHW.
*
* Computes outputs = inputs * denoms^(-pow), where
* denoms = 1 + scale * SUM(inputs^2) over a window of `size` neighbouring
* channels.
*
* \param[out] outputs output data.
* \param[out] denoms denoms buffer; written here and reused by the
*                    backward pass.
* \param[in] inputs input data.
* \param[in] numSamples batch size of input image.
* \param[in] channels number of channel.
* \param[in] height image height.
* \param[in] width image width.
* \param[in] size number of neighbouring channels in the window.
* \param[in] scale factor applied to the windowed sum of squares.
* \param[in] pow exponent; the implementations negate it themselves.
*
*/
template <DeviceType Device>
void CrossMapNormal(real* outputs,
real* denoms,
const real* inputs,
size_t numSamples,
size_t channels,
size_t height,
size_t width,
size_t size,
real scale,
real pow);
/**
* \brief Cross map response normalize backward.
* The data structure of image data is NCHW.
*
* \param[in,out] inputsGrad input grad; the CPU and GPU implementations in
*                           this commit add the gradient to its content.
* \param[in] inputsValue input value.
* \param[in] outputsValue output value of the forward pass.
* \param[in] outputsGrad output grad.
* \param[in] denoms denoms buffer written by the forward pass.
* \param[in] numSamples batch size of input image.
* \param[in] channels number of channel.
* \param[in] height image height.
* \param[in] width image width.
* \param[in] size number of neighbouring channels in the window.
* \param[in] scale factor applied to the windowed sum of squares.
* \param[in] pow exponent; the implementations negate it themselves.
*
*/
template <DeviceType Device>
void CrossMapNormalGrad(real* inputsGrad,
const real* inputsValue,
const real* outputsValue,
const real* outputsGrad,
const real* denoms,
size_t numSamples,
size_t channels,
size_t height,
size_t width,
size_t size,
real scale,
real pow);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "cross_map_normal_op.h"
namespace paddle {
/**
 * Fills `scale` with 1 + alpha * (sum of squared inputs over a window of
 * `size` neighbouring channels).
 *
 * Each thread with idx < imageSize handles one (n, h, w) position and walks
 * over all channels of that position with a running sum, so every input is
 * added once and removed once. Channel planes are height * width elements
 * apart.
 *
 * All indices are kept in size_t: narrowing the element offset to int
 * overflows once a tensor holds more than INT_MAX elements.
 */
__global__ void KeCMRNormFillScale(size_t imageSize, const real* in,
                                   real* scale, size_t channels,
                                   size_t height, size_t width, size_t size,
                                   real alpha) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < imageSize) {
    const size_t w = idx % width;
    const size_t h = (idx / width) % height;
    const size_t n = idx / width / height;
    const size_t offset = (n * channels * height + h) * width + w;

    in += offset;
    scale += offset;
    const size_t step = height * width;
    // How far the window reaches below / above the channel being written.
    const size_t pre_pad = (size - 1) / 2;
    const size_t post_pad = size - pre_pad - 1;

    real accum = 0;
    for (size_t index = 0; index < channels + post_pad; ++index) {
      // Channel entering the window.
      if (index < channels) {
        accum += in[index * step] * in[index * step];
      }
      // Channel leaving the window.
      if (index >= size) {
        accum -= in[(index - size) * step] * in[(index - size) * step];
      }
      // The window is now centred on channel (index - post_pad).
      if (index >= post_pad) {
        scale[(index - post_pad) * step] = 1. + accum * alpha;
      }
    }
  }
}
/**
 * Element-wise output: out[i] = in[i] * scale[i]^negative_beta for every
 * i < inputSize, one thread per element.
 *
 * The index is kept in size_t: as an int it turns negative for inputs with
 * more than INT_MAX elements and those elements are silently skipped.
 */
__global__ void KeCMRNormOutput(size_t inputSize, const real* in,
                                const real* scale, real negative_beta,
                                real* out) {
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index < inputSize) {
    out[index] = in[index] * pow(scale[index], negative_beta);
  }
}
/**
 * GPU forward pass of cross-map normalization.
 *
 * Launches two kernels on STREAM_DEFAULT:
 *   1. KeCMRNormFillScale, one thread per (sample, h, w) position, writes
 *      `denoms`;
 *   2. KeCMRNormOutput, one thread per element, writes
 *      outputs = inputs * denoms^(-pow).
 *
 * Returns without launching anything when the input has no elements: a grid
 * with zero blocks is not a valid launch configuration.
 */
template <>
void CrossMapNormal<DEVICE_TYPE_GPU>(real* outputs,
                                     real* denoms,
                                     const real* inputs,
                                     size_t numSamples,
                                     size_t channels,
                                     size_t height,
                                     size_t width,
                                     size_t size,
                                     real scale,
                                     real pow) {
  const int blockSize = 1024;
  const size_t imageSize = numSamples * height * width;
  const size_t inputSize = imageSize * channels;
  if (inputSize == 0) {
    return;
  }

  // Round up so that every position / element gets a thread.
  int gridSize = static_cast<int>((imageSize + blockSize - 1) / blockSize);
  KeCMRNormFillScale<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
    (imageSize, inputs, denoms, channels, height, width, size, scale);

  gridSize = static_cast<int>((inputSize + blockSize - 1) / blockSize);
  KeCMRNormOutput<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
    (inputSize, inputs, denoms, -pow, outputs);

  CHECK_SYNC("CrossMapNormal");
}
/**
 * Backward kernel of cross-map normalization.
 *
 * Each thread with idx < imageSize handles one (n, h, w) position and walks
 * over its channels, keeping a running sum of top_diff * top_data / scale
 * over a window of `size` channels. For every channel it adds
 *   top_diff * scale^negative_beta - cache_ratio * bottom_data * accum
 * to bottom_diff (the gradient is accumulated, not overwritten).
 *
 * All indices are kept in size_t: narrowing the element offset to int
 * overflows once a tensor holds more than INT_MAX elements.
 */
__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
                              const real* top_data, const real* scale,
                              const real* top_diff, size_t channels,
                              size_t height, size_t width, size_t size,
                              real negative_beta, real cache_ratio,
                              real* bottom_diff ) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < imageSize) {
    const size_t w = idx % width;
    const size_t h = (idx / width) % height;
    const size_t n = idx / width / height;
    const size_t offset = (n * channels * height + h) * width + w;
    bottom_data += offset;
    top_data += offset;
    scale += offset;
    top_diff += offset;
    bottom_diff += offset;

    const size_t step = height * width;
    // Note: the padding split differs from KeCMRNormFillScale.
    const size_t pre_pad = size - (size + 1) / 2;
    const size_t post_pad = size - pre_pad - 1;

    real accum = 0;
    for (size_t index = 0; index < channels + post_pad; ++index) {
      // Channel entering the window.
      if (index < channels) {
        accum += top_diff[index * step] * top_data[index * step] /
                 scale[index * step];
      }
      // Channel leaving the window.
      if (index >= size) {
        const size_t tail = (index - size) * step;
        accum -= top_diff[tail] * top_data[tail] / scale[tail];
      }
      // The window now belongs to channel (index - post_pad).
      if (index >= post_pad) {
        const size_t cur = (index - post_pad) * step;
        bottom_diff[cur] += top_diff[cur] * pow(scale[cur], negative_beta) -
                            cache_ratio * bottom_data[cur] * accum;
      }
    }
  }
}
/**
 * GPU backward pass of cross-map normalization.
 *
 * Launches KeCMRNormDiff on STREAM_DEFAULT with one thread per
 * (sample, h, w) position. The kernel receives -pow as its exponent and
 * 2 * pow * scale as its ratio, and adds the gradient to `inputsGrad`.
 *
 * Returns without launching anything when there is no position to process:
 * a grid with zero blocks is not a valid launch configuration.
 */
template <>
void CrossMapNormalGrad<DEVICE_TYPE_GPU>(real* inputsGrad,
                                         const real* inputsValue,
                                         const real* outputsValue,
                                         const real* outputsGrad,
                                         const real* denoms,
                                         size_t numSamples,
                                         size_t channels,
                                         size_t height,
                                         size_t width,
                                         size_t size,
                                         real scale,
                                         real pow) {
  const int blockSize = 1024;
  const size_t imageSize = numSamples * height * width;
  if (imageSize == 0) {
    return;
  }

  // Round up so that every position gets a thread.
  const int gridSize =
      static_cast<int>((imageSize + blockSize - 1) / blockSize);
  KeCMRNormDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
    (imageSize, inputsValue, outputsValue, denoms, outputsGrad, channels,
    height, width, size, -pow, 2.0f * pow * scale, inputsGrad);
  CHECK_SYNC("CrossMapNormalGrad");
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
// Sweeps batch size, channel count, image height/width and normalization
// window size. Each combination is handed to FunctionCompare for the
// "CrossMapNormal" function with one input tensor and two output tensors,
// all sharing the same 4-D shape. Tensors are created with null data
// pointers; presumably FunctionCompare allocates and fills them itself.
TEST(CrossMapNormal, real) {
  for (size_t batch : {5, 32}) {
    for (size_t numChannels : {1, 5, 32}) {
      for (size_t rows : {5, 33, 100}) {
        for (size_t cols : {5, 32, 96}) {
          for (size_t window : {1, 2, 3, 5, 7}) {
            VLOG(3) << " numSamples=" << batch << " channels=" << numChannels
                    << " imgSizeH=" << rows << " imgSizeW=" << cols
                    << " size=" << window;

            FunctionCompare compare("CrossMapNormal",
                                    FuncConfig()
                                        .set("size", window)
                                        .set("scale", static_cast<real>(1.5))
                                        .set("pow", static_cast<real>(0.5)));

            Dims shape{batch, numChannels, rows, cols};
            compare.cmpWithArg(
                {Tensor(nullptr, shape)},
                {Tensor(nullptr, shape), Tensor(nullptr, shape)},
                {});
          }
        }
      }
    }
  }
}
// Same parameter sweep as the forward test, for the "CrossMapNormalGrad"
// function: four input tensors and one output tensor, all sharing the same
// 4-D shape. Tensors are created with null data pointers; presumably
// FunctionCompare allocates and fills them itself.
TEST(CrossMapNormalGrad, real) {
  for (size_t batch : {5, 32}) {
    for (size_t numChannels : {1, 5, 32}) {
      for (size_t rows : {5, 33, 100}) {
        for (size_t cols : {5, 32, 96}) {
          for (size_t window : {1, 2, 3, 5, 7}) {
            VLOG(3) << " numSamples=" << batch << " channels=" << numChannels
                    << " imgSizeH=" << rows << " imgSizeW=" << cols
                    << " size=" << window;

            FunctionCompare compare("CrossMapNormalGrad",
                                    FuncConfig()
                                        .set("size", window)
                                        .set("scale", static_cast<real>(1.5))
                                        .set("pow", static_cast<real>(0.5)));

            Dims shape{batch, numChannels, rows, cols};
            compare.cmpWithArg({Tensor(nullptr, shape),
                                Tensor(nullptr, shape),
                                Tensor(nullptr, shape),
                                Tensor(nullptr, shape)},
                               {Tensor(nullptr, shape)},
                               {});
          }
        }
      }
    }
  }
}
...@@ -27,16 +27,12 @@ if(NOT WITH_GPU) ...@@ -27,16 +27,12 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM GSERVER_HEADER list(REMOVE_ITEM GSERVER_HEADER
layers/CudnnConvLayer.h layers/CudnnConvLayer.h
layers/CudnnPoolLayer.h layers/CudnnPoolLayer.h
layers/CudnnBatchNormLayer.h layers/CudnnBatchNormLayer.h)
layers/NormProjectionLayer.h
layers/NormLayer.h)
list(REMOVE_ITEM GSERVER_SOURCES list(REMOVE_ITEM GSERVER_SOURCES
layers/CudnnConvLayer.cpp layers/CudnnConvLayer.cpp
layers/CudnnPoolLayer.cpp layers/CudnnPoolLayer.cpp
layers/CudnnBatchNormLayer.cpp layers/CudnnBatchNormLayer.cpp)
layers/NormProjectionLayer.cpp
layers/NormLayer.cpp)
compile_cu_as_cpp(layers/LstmCompute.cu) compile_cu_as_cpp(layers/LstmCompute.cu)
compile_cu_as_cpp(layers/GruCompute.cu) compile_cu_as_cpp(layers/GruCompute.cu)
endif() endif()
......
...@@ -78,7 +78,7 @@ public: ...@@ -78,7 +78,7 @@ public:
useGpu(arguments[0].deviceId)); useGpu(arguments[0].deviceId));
errorMat->zeroMem(); errorMat->zeroMem();
if (label != nullptr) { if (label != nullptr) {
errorMat->classificationError(output, label); errorMat->classificationError(*output, *label);
} else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) || } else if (dynamic_cast<CpuSparseMatrix*>(multiBinaryLabel.get()) ||
dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) { dynamic_cast<GpuSparseMatrix*>(multiBinaryLabel.get())) {
errorMat->classificationErrorMulti( errorMat->classificationErrorMulti(
......
...@@ -90,8 +90,8 @@ void ContextProjection::forward() { ...@@ -90,8 +90,8 @@ void ContextProjection::forward() {
REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str()); REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str());
bool isPadding = config_.trainable_padding(); bool isPadding = config_.trainable_padding();
out_->value->contextProjectionForward( out_->value->contextProjectionForward(
in_->value, *(in_->value),
state_ ? state_ : isPadding ? weight_->getW() : nullptr, state_ ? state_.get() : isPadding ? weight_->getW().get() : nullptr,
*startPositions, *startPositions,
config_.context_length(), config_.context_length(),
config_.context_start(), config_.context_start(),
...@@ -128,8 +128,8 @@ void ContextProjection::backward(const UpdateCallback& callback) { ...@@ -128,8 +128,8 @@ void ContextProjection::backward(const UpdateCallback& callback) {
bool isPadding = config_.trainable_padding(); bool isPadding = config_.trainable_padding();
if (!out_->grad->useGpu()) { if (!out_->grad->useGpu()) {
out_->grad->contextProjectionBackward( out_->grad->contextProjectionBackward(
in_->grad, in_->grad.get(),
isPadding ? weight_->getWGrad() : nullptr, isPadding ? weight_->getWGrad().get() : nullptr,
*startPositions, *startPositions,
config_.context_length(), config_.context_length(),
config_.context_start(), config_.context_start(),
...@@ -137,7 +137,7 @@ void ContextProjection::backward(const UpdateCallback& callback) { ...@@ -137,7 +137,7 @@ void ContextProjection::backward(const UpdateCallback& callback) {
isPadding); isPadding);
} else { } else {
if (in_->grad) { if (in_->grad) {
out_->grad->contextProjectionBackwardData(in_->grad, out_->grad->contextProjectionBackwardData(*(in_->grad),
*startPositions, *startPositions,
config_.context_length(), config_.context_length(),
config_.context_start()); config_.context_start());
...@@ -145,7 +145,7 @@ void ContextProjection::backward(const UpdateCallback& callback) { ...@@ -145,7 +145,7 @@ void ContextProjection::backward(const UpdateCallback& callback) {
if (isPadding && weight_->getWGrad()) { if (isPadding && weight_->getWGrad()) {
out_->grad->contextProjectionBackwardWeight( out_->grad->contextProjectionBackwardWeight(
weight_->getWGrad(), *(weight_->getWGrad()),
*startPositions, *startPositions,
config_.context_length(), config_.context_length(),
config_.context_start(), config_.context_start(),
......
...@@ -113,7 +113,7 @@ void ConvexCombinationLayer::forward(PassType passType) { ...@@ -113,7 +113,7 @@ void ConvexCombinationLayer::forward(PassType passType) {
tmpRow0->setData(inV0->getData() + i * weightDim); tmpRow0->setData(inV0->getData() + i * weightDim);
tmpRow1->setData(outV->getData() + i * dataDim); tmpRow1->setData(outV->getData() + i * dataDim);
tmpRow1->mul(tmpRow0, tmpMtx0, 1, 0); tmpRow1->mul(*tmpRow0, *tmpMtx0, 1, 0);
} }
} }
...@@ -136,7 +136,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) { ...@@ -136,7 +136,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) {
tmpRow1->setData(outG->getData() + i * dataDim); tmpRow1->setData(outG->getData() + i * dataDim);
tmpMtx0->setData(inV1->getData() + i * weightDim * dataDim); tmpMtx0->setData(inV1->getData() + i * weightDim * dataDim);
tmpRow0->mul(tmpRow1, tmpMtx0->getTranspose(), 1, 1); tmpRow0->mul(*tmpRow1, *(tmpMtx0->getTranspose()), 1, 1);
} }
} }
...@@ -146,7 +146,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) { ...@@ -146,7 +146,7 @@ void ConvexCombinationLayer::backward(const UpdateCallback& callback) {
tmpRow1->setData(outG->getData() + i * dataDim); tmpRow1->setData(outG->getData() + i * dataDim);
tmpMtx0->setData(inG1->getData() + i * weightDim * dataDim); tmpMtx0->setData(inG1->getData() + i * weightDim * dataDim);
tmpMtx0->mul(tmpRow0->getTranspose(), tmpRow1, 1, 1); tmpMtx0->mul(*(tmpRow0->getTranspose()), *tmpRow1, 1, 1);
} }
} }
} }
......
...@@ -150,7 +150,7 @@ void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image, ...@@ -150,7 +150,7 @@ void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image,
Matrix::create(wgtData, subM, subK, false, useGpu_); // mark transpose Matrix::create(wgtData, subM, subK, false, useGpu_); // mark transpose
MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_); MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_);
MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_); MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_);
C->mul(A, B, 1, 1); C->mul(*A, *B, 1, 1);
A->clear(); A->clear();
B->clear(); B->clear();
...@@ -185,7 +185,7 @@ void ExpandConvBaseLayer::bpropActs(MatrixPtr out, ...@@ -185,7 +185,7 @@ void ExpandConvBaseLayer::bpropActs(MatrixPtr out,
MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_); MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_);
MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_); MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_);
MatrixPtr A = Matrix::create(wgtData, subM, subK, true, useGpu_); MatrixPtr A = Matrix::create(wgtData, subM, subK, true, useGpu_);
C->mul(A, B); // mul C->mul(*A, *B); // mul
// clear the temporary matrix // clear the temporary matrix
A->clear(); A->clear();
...@@ -252,7 +252,7 @@ void ExpandConvBaseLayer::bpropWeights(MatrixPtr image, ...@@ -252,7 +252,7 @@ void ExpandConvBaseLayer::bpropWeights(MatrixPtr image,
MatrixPtr A = Matrix::create(expandInData, subK, subN, true, useGpu_); MatrixPtr A = Matrix::create(expandInData, subK, subN, true, useGpu_);
MatrixPtr B = Matrix::create(gradData, subM, subN, false, useGpu_); MatrixPtr B = Matrix::create(gradData, subM, subN, false, useGpu_);
MatrixPtr C = Matrix::create(wGradData, subM, subK, false, useGpu_); MatrixPtr C = Matrix::create(wGradData, subM, subK, false, useGpu_);
C->mul(B, A, 1, 1); C->mul(*B, *A, 1, 1);
A->clear(); A->clear();
B->clear(); B->clear();
......
...@@ -28,7 +28,7 @@ FullMatrixProjection::FullMatrixProjection(const ProjectionConfig& config, ...@@ -28,7 +28,7 @@ FullMatrixProjection::FullMatrixProjection(const ProjectionConfig& config,
void FullMatrixProjection::forward() { void FullMatrixProjection::forward() {
REGISTER_TIMER_INFO("FwMulTimer", getName().c_str()); REGISTER_TIMER_INFO("FwMulTimer", getName().c_str());
out_->value->mul(in_->value, weight_->getW(), 1, 1); out_->value->mul(*(in_->value), *(weight_->getW()), 1, 1);
} }
void FullMatrixProjection::backward(const UpdateCallback& callback) { void FullMatrixProjection::backward(const UpdateCallback& callback) {
...@@ -37,7 +37,8 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) { ...@@ -37,7 +37,8 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) {
/* Calculate the W-gradient for the current layer */ /* Calculate the W-gradient for the current layer */
if (weight_->getWGrad()) { if (weight_->getWGrad()) {
REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); REGISTER_TIMER_INFO("GradMulTimer", getName().c_str());
weight_->getWGrad()->mul(in_->value->getTranspose(), out_->grad, 1, 1); weight_->getWGrad()->mul(
*(in_->value->getTranspose()), *(out_->grad), 1, 1);
} }
// If callback does not change value, backward propagation error // If callback does not change value, backward propagation error
...@@ -47,7 +48,7 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) { ...@@ -47,7 +48,7 @@ void FullMatrixProjection::backward(const UpdateCallback& callback) {
/* Calculate the input layers error */ /* Calculate the input layers error */
if (in_->grad) { if (in_->grad) {
REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); REGISTER_TIMER_INFO("BpMulTimer", getName().c_str());
in_->grad->mul(out_->grad, weight_->getW()->getTranspose(), 1, 1); in_->grad->mul(*(out_->grad), *(weight_->getW()->getTranspose()), 1, 1);
} }
hl_set_sync_flag(syncFlag); hl_set_sync_flag(syncFlag);
......
...@@ -84,8 +84,8 @@ void FullyConnectedLayer::forward(PassType passType) { ...@@ -84,8 +84,8 @@ void FullyConnectedLayer::forward(PassType passType) {
auto input = getInput(i); auto input = getInput(i);
CHECK(input.value) << "The input of 'fc' layer must be matrix"; CHECK(input.value) << "The input of 'fc' layer must be matrix";
REGISTER_TIMER_INFO("FwMulTimer", getName().c_str()); REGISTER_TIMER_INFO("FwMulTimer", getName().c_str());
i == 0 ? outV->mul(input.value, weights_[i]->getW(), 1, 0) i == 0 ? outV->mul(*input.value, *weights_[i]->getW(), 1, 0)
: outV->mul(input.value, weights_[i]->getW(), 1, 1); : outV->mul(*input.value, *weights_[i]->getW(), 1, 1);
} }
/* add the bias-vector */ /* add the bias-vector */
...@@ -123,7 +123,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) { ...@@ -123,7 +123,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) {
MatrixPtr oGrad = getOutputGrad(); MatrixPtr oGrad = getOutputGrad();
{ {
REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); REGISTER_TIMER_INFO("GradMulTimer", getName().c_str());
weights_[i]->getWGrad()->mul(input_T, oGrad, 1, 1); weights_[i]->getWGrad()->mul(*input_T, *oGrad, 1, 1);
} }
} }
...@@ -136,7 +136,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) { ...@@ -136,7 +136,7 @@ void FullyConnectedLayer::backward(const UpdateCallback& callback) {
if (NULL != preGrad) { if (NULL != preGrad) {
MatrixPtr weights_T = weights_[i]->getW()->getTranspose(); MatrixPtr weights_T = weights_[i]->getW()->getTranspose();
REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); REGISTER_TIMER_INFO("BpMulTimer", getName().c_str());
preGrad->mul(getOutputGrad(), weights_T, 1, 1); preGrad->mul(*getOutputGrad(), *weights_T, 1, 1);
} }
hl_set_sync_flag(syncFlag); hl_set_sync_flag(syncFlag);
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "ModelConfig.pb.h" #include "ModelConfig.pb.h"
#include "paddle/function/Function.h"
#include "paddle/math/CpuSparseMatrix.h" #include "paddle/math/CpuSparseMatrix.h"
#include "paddle/parameter/Parameter.h" #include "paddle/parameter/Parameter.h"
#include "paddle/utils/ClassRegistrar.h" #include "paddle/utils/ClassRegistrar.h"
...@@ -100,6 +101,11 @@ protected: ...@@ -100,6 +101,11 @@ protected:
/// Mark input grad in(true) or out(false) of backward function. /// Mark input grad in(true) or out(false) of backward function.
std::vector<bool> markInBackward_; std::vector<bool> markInBackward_;
/// Layer forward function
std::vector<std::shared_ptr<FunctionBase>> forward_;
/// Layer backward function
std::vector<std::shared_ptr<FunctionBase>> backward_;
public: public:
/** /**
* Wait until all input value ready. * Wait until all input value ready.
...@@ -126,6 +132,26 @@ public: ...@@ -126,6 +132,26 @@ public:
virtual void markAllInputGrad(); virtual void markAllInputGrad();
protected: protected:
/**
 * Instantiate a registered function, initialize it and append it to a
 * function list. The registry key is the function name followed by "-GPU"
 * when useGpu_ is set and "-CPU" otherwise.
 * \param function, list to append to (Layer::forward_ or Layer::backward_)
 * \param name, function name without the device suffix
 * \param config, initialization configuration for the function
 */
void createFunction(std::vector<std::shared_ptr<FunctionBase>>& function,
                    const std::string& name,
                    const FuncConfig& config) {
  // Pick the device-specific variant of the requested function.
  const std::string typeName = name + (useGpu_ ? "-GPU" : "-CPU");
  function.emplace_back(FunctionBase::funcRegistrar_.createByType(typeName));
  // Configure the instance that was just appended.
  function.back()->init(config);
}
/** /**
* Notify specified layer the output grad ready. * Notify specified layer the output grad ready.
* Called in the backward function. * Called in the backward function.
......
...@@ -59,7 +59,7 @@ real LinearChainCRF::forward(real* x, int* s, int length) { ...@@ -59,7 +59,7 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
matX->rowMax(*maxX_); matX->rowMax(*maxX_);
expX_->assign(*matX); expX_->assign(*matX);
// subtract max to avoid overflow or underflow // subtract max to avoid overflow or underflow
expX_->mul(maxX_, ones_, (real)-1, (real)1); expX_->mul(*maxX_, *ones_, (real)-1, (real)1);
expX_->exp2(); expX_->exp2();
real* a = a_->getData(); real* a = a_->getData();
......
...@@ -316,7 +316,7 @@ void LstmLayer::forwardSequence(int batchSize, ...@@ -316,7 +316,7 @@ void LstmLayer::forwardSequence(int batchSize,
} }
if (prevOutput_) { if (prevOutput_) {
frameGate->setData(lstmValue.gateValue); frameGate->setData(lstmValue.gateValue);
frameGate->mul(prevOutput_, weight_->getW(), 1, 1); frameGate->mul(*prevOutput_, *weight_->getW(), 1, 1);
} }
} }
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
...@@ -338,7 +338,7 @@ void LstmLayer::forwardSequence(int batchSize, ...@@ -338,7 +338,7 @@ void LstmLayer::forwardSequence(int batchSize,
frameOutput->setData(lstmValue.outputValue); frameOutput->setData(lstmValue.outputValue);
nextFrame(reversed_, getSize()); nextFrame(reversed_, getSize());
frameGate->setData(lstmValue.gateValue); frameGate->setData(lstmValue.gateValue);
frameGate->mul(frameOutput, weight_->getW(), 1, 1); frameGate->mul(*frameOutput, *weight_->getW(), 1, 1);
} }
} }
if (n != numSequences - 1) { if (n != numSequences - 1) {
...@@ -348,7 +348,7 @@ void LstmLayer::forwardSequence(int batchSize, ...@@ -348,7 +348,7 @@ void LstmLayer::forwardSequence(int batchSize,
if (!reversed_) { if (!reversed_) {
if (!prevState_) lstmValue.prevStateValue = nullptr; if (!prevState_) lstmValue.prevStateValue = nullptr;
if (prevOutput_) { if (prevOutput_) {
frameGate->mul(frameOutput, weight_->getW(), 1, 1); frameGate->mul(*frameOutput, *weight_->getW(), 1, 1);
} }
} else { } else {
lstmValue.prevStateValue = nullptr; lstmValue.prevStateValue = nullptr;
...@@ -470,7 +470,7 @@ void LstmLayer::backwardSequence(int batchSize, ...@@ -470,7 +470,7 @@ void LstmLayer::backwardSequence(int batchSize,
frameGate->setData(lstmGrad.gateGrad); frameGate->setData(lstmGrad.gateGrad);
nextFrame(reversed_, getSize()); nextFrame(reversed_, getSize());
frameOutput->setData(lstmGrad.outputGrad); frameOutput->setData(lstmGrad.outputGrad);
frameOutput->mul(frameGate, weightT, 1, 1); frameOutput->mul(*frameGate, *weightT, 1, 1);
} else { } else {
nextFrame(reversed_, getSize()); nextFrame(reversed_, getSize());
} }
...@@ -479,14 +479,14 @@ void LstmLayer::backwardSequence(int batchSize, ...@@ -479,14 +479,14 @@ void LstmLayer::backwardSequence(int batchSize,
if (weight_->getWGrad()) { if (weight_->getWGrad()) {
if (!reversed_) { if (!reversed_) {
weight_->getWGrad()->mul( weight_->getWGrad()->mul(
output_.value->subMatrix(start, length - 1)->getTranspose(), *output_.value->subMatrix(start, length - 1)->getTranspose(),
gate_.grad->subMatrix(start + 1, length - 1), *gate_.grad->subMatrix(start + 1, length - 1),
1, 1,
1); 1);
} else { } else {
weight_->getWGrad()->mul( weight_->getWGrad()->mul(
output_.value->subMatrix(start + 1, length - 1)->getTranspose(), *output_.value->subMatrix(start + 1, length - 1)->getTranspose(),
gate_.grad->subMatrix(start, length - 1), *gate_.grad->subMatrix(start, length - 1),
1, 1,
1); 1);
} }
...@@ -541,7 +541,7 @@ void LstmLayer::forwardBatch(int batchSize, ...@@ -541,7 +541,7 @@ void LstmLayer::forwardBatch(int batchSize,
if (n != 0) { if (n != 0) {
MatrixPtr batch1 = batchValue_->getBatchValue(n - 1, batchSize); MatrixPtr batch1 = batchValue_->getBatchValue(n - 1, batchSize);
gateValue->mul(batch1, weight_->getW(), 1, 1); gateValue->mul(*batch1, *weight_->getW(), 1, 1);
} else if (prevOutput_) { } else if (prevOutput_) {
Matrix::resizeOrCreate(prevBatchOutput2_, Matrix::resizeOrCreate(prevBatchOutput2_,
gateValue->getHeight(), gateValue->getHeight(),
...@@ -549,7 +549,7 @@ void LstmLayer::forwardBatch(int batchSize, ...@@ -549,7 +549,7 @@ void LstmLayer::forwardBatch(int batchSize,
false, false,
useGpu_); useGpu_);
batchValue_->prevOutput2Batch(*prevOutput_, *prevBatchOutput2_); batchValue_->prevOutput2Batch(*prevOutput_, *prevBatchOutput2_);
gateValue->mul(prevBatchOutput2_, weight_->getW(), 1, 1); gateValue->mul(*prevBatchOutput2_, *weight_->getW(), 1, 1);
batchValue_->prevOutput2Batch(*prevState_, batchValue_->prevOutput2Batch(*prevState_,
*totalState_->subMatrix(0, numSequences)); *totalState_->subMatrix(0, numSequences));
...@@ -672,16 +672,16 @@ void LstmLayer::backwardBatch(int batchSize, ...@@ -672,16 +672,16 @@ void LstmLayer::backwardBatch(int batchSize,
if (n != 0) { if (n != 0) {
MatrixPtr tmp = batchGrad_->getBatchValue(n - 1, batchSize); MatrixPtr tmp = batchGrad_->getBatchValue(n - 1, batchSize);
tmp->mul(gateGrad, weightT, 1, 1); tmp->mul(*gateGrad, *weightT, 1, 1);
} }
if (n != 0 && weight_->getWGrad()) { if (n != 0 && weight_->getWGrad()) {
/* backward weight */ /* backward weight */
MatrixPtr outputValue = batchValue_->getBatchValue(n - 1, batchSize); MatrixPtr outputValue = batchValue_->getBatchValue(n - 1, batchSize);
weight_->getWGrad()->mul(outputValue->getTranspose(), gateGrad, 1, 1); weight_->getWGrad()->mul(*outputValue->getTranspose(), *gateGrad, 1, 1);
} else if (prevOutput_ && weight_->getWGrad()) { } else if (prevOutput_ && weight_->getWGrad()) {
weight_->getWGrad()->mul( weight_->getWGrad()->mul(
prevBatchOutput2_->getTranspose(), gateGrad, 1, 1); *prevBatchOutput2_->getTranspose(), *gateGrad, 1, 1);
} }
} }
} }
......
...@@ -547,7 +547,7 @@ void MDLstmLayer::forwardOneSequence(int start, CoordIterator& coordIter) { ...@@ -547,7 +547,7 @@ void MDLstmLayer::forwardOneSequence(int start, CoordIterator& coordIter) {
if (coordIter.getPrePos(delays_, i, prePos)) { if (coordIter.getPrePos(delays_, i, prePos)) {
int preOffset = coordIter.offset(prePos); int preOffset = coordIter.offset(prePos);
frameGate_[start + offset].value->mul( frameGate_[start + offset].value->mul(
frameOutput_[start + preOffset].value, weight_->getW(), 1.0, 1.0); *frameOutput_[start + preOffset].value, *weight_->getW(), 1.0, 1.0);
} }
} }
forwardGate2OutputSequence(start, coordIter); forwardGate2OutputSequence(start, coordIter);
...@@ -747,11 +747,11 @@ void MDLstmLayer::backwardOneSequence(int start, CoordIterator& coordIter) { ...@@ -747,11 +747,11 @@ void MDLstmLayer::backwardOneSequence(int start, CoordIterator& coordIter) {
if (coordIter.getPrePos(delays_, i, prePos)) { if (coordIter.getPrePos(delays_, i, prePos)) {
int preOffset = coordIter.offset(prePos); int preOffset = coordIter.offset(prePos);
frameOutput_[start + preOffset].grad->mul( frameOutput_[start + preOffset].grad->mul(
frameGate_[start + offset].grad, weightT, 1.0, 1.0); *frameGate_[start + offset].grad, *weightT, 1.0, 1.0);
if (weight_->getWGrad()) { if (weight_->getWGrad()) {
weight_->getWGrad()->mul( weight_->getWGrad()->mul(
frameOutput_[start + preOffset].value->getTranspose(), *frameOutput_[start + preOffset].value->getTranspose(),
frameGate_[start + offset].grad, *frameGate_[start + offset].grad,
1.0, 1.0,
1.0); 1.0);
} }
......
...@@ -45,6 +45,15 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, ...@@ -45,6 +45,15 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
/* the size of inputs for norm-layer is 1 */ /* the size of inputs for norm-layer is 1 */
CHECK_EQ(config_.inputs_size(), 1); CHECK_EQ(config_.inputs_size(), 1);
createFunction(
forward_,
"CrossMapNormal",
FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
createFunction(
backward_,
"CrossMapNormalGrad",
FuncConfig().set("size", size_).set("scale", scale_).set("pow", pow_));
return true; return true;
} }
...@@ -54,7 +63,7 @@ void CMRProjectionNormLayer::forward(PassType passType) { ...@@ -54,7 +63,7 @@ void CMRProjectionNormLayer::forward(PassType passType) {
/* malloc memory for the output_ if necessary */ /* malloc memory for the output_ if necessary */
/* note: one sample correspond to one row */ /* note: one sample correspond to one row */
MatrixPtr input = inputLayers_[0]->getOutputValue(); MatrixPtr input = inputLayers_[0]->getOutputValue();
int batchSize = input->getHeight(); size_t batchSize = input->getHeight();
int size = getSize(); int size = getSize();
resetOutput(batchSize, size); resetOutput(batchSize, size);
...@@ -62,10 +71,11 @@ void CMRProjectionNormLayer::forward(PassType passType) { ...@@ -62,10 +71,11 @@ void CMRProjectionNormLayer::forward(PassType passType) {
Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_);
denoms_->zeroMem(); dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_};
forward_[0]->calc(
outV->crossMapNormalFwd( {Tensor(input->getData(), dims_)},
*input, imgSizeH_, imgSizeW_, *denoms_, channels_, size_, scale_, pow_); {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)},
{});
} }
void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
...@@ -80,15 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { ...@@ -80,15 +90,11 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
MatrixPtr localOutV = getOutputValue(); MatrixPtr localOutV = getOutputValue();
MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); MatrixPtr preOutV = inputLayers_[0]->getOutputValue();
preOutGrad->crossMapNormalBwd(*localGrad, backward_[0]->calc({Tensor(preOutV->getData(), dims_),
*denoms_, Tensor(localOutV->getData(), dims_),
*preOutV, Tensor(localGrad->getData(), dims_),
*localOutV, Tensor(denoms_->getData(), dims_)},
channels_, {Tensor(preOutGrad->getData(), dims_)},
imgSizeH_, {});
imgSizeW_,
size_,
scale_,
pow_);
} }
} // namespace paddle } // namespace paddle
...@@ -39,5 +39,8 @@ public: ...@@ -39,5 +39,8 @@ public:
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType); void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr); void backward(const UpdateCallback& callback = nullptr);
protected:
Dims dims_;
}; };
} // namespace paddle } // namespace paddle
...@@ -96,7 +96,7 @@ void OuterProdLayer::forward(PassType passType) { ...@@ -96,7 +96,7 @@ void OuterProdLayer::forward(PassType passType) {
tmpRow0->setData(inV0->getData() + i * dim0); tmpRow0->setData(inV0->getData() + i * dim0);
tmpRow1->setData(inV1->getData() + i * dim1); tmpRow1->setData(inV1->getData() + i * dim1);
tmpMtx0->mul(tmpRow0->getTranspose(), tmpRow1); tmpMtx0->mul(*tmpRow0->getTranspose(), *tmpRow1);
} }
} }
} }
...@@ -121,7 +121,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) { ...@@ -121,7 +121,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) {
tmpRow0->setData(inG0->getData() + i * dim0); tmpRow0->setData(inG0->getData() + i * dim0);
tmpRow1->setData(inV1->getData() + i * dim1); tmpRow1->setData(inV1->getData() + i * dim1);
tmpRow0->mul(tmpRow1, tmpMtx0->getTranspose(), 1, 1); tmpRow0->mul(*tmpRow1, *tmpMtx0->getTranspose(), 1, 1);
} }
} }
...@@ -131,7 +131,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) { ...@@ -131,7 +131,7 @@ void OuterProdLayer::backward(const UpdateCallback& callback) {
tmpRow0->setData(inV0->getData() + i * dim0); tmpRow0->setData(inV0->getData() + i * dim0);
tmpRow1->setData(inG1->getData() + i * dim1); tmpRow1->setData(inG1->getData() + i * dim1);
tmpRow1->mul(tmpRow0, tmpMtx0, 1, 1); tmpRow1->mul(*tmpRow0, *tmpMtx0, 1, 1);
} }
} }
} }
......
...@@ -215,12 +215,12 @@ void RecurrentLayer::forwardSequence(int batchSize, ...@@ -215,12 +215,12 @@ void RecurrentLayer::forwardSequence(int batchSize,
void RecurrentLayer::forwardOneSequence(int start, int length) { void RecurrentLayer::forwardOneSequence(int start, int length) {
if (!reversed_) { if (!reversed_) {
if (prevOutput_) { if (prevOutput_) {
frameOutput_[start].value->mul(prevOutput_, weight_->getW(), 1, 1); frameOutput_[start].value->mul(*prevOutput_, *weight_->getW(), 1, 1);
} }
activation_->forward(frameOutput_[start]); activation_->forward(frameOutput_[start]);
for (int i = 1; i < length; ++i) { for (int i = 1; i < length; ++i) {
frameOutput_[start + i].value->mul( frameOutput_[start + i].value->mul(
frameOutput_[start + i - 1].value, weight_->getW(), 1, 1); *frameOutput_[start + i - 1].value, *weight_->getW(), 1, 1);
activation_->forward(frameOutput_[start + i]); activation_->forward(frameOutput_[start + i]);
} }
if (prevOutput_) { if (prevOutput_) {
...@@ -230,7 +230,7 @@ void RecurrentLayer::forwardOneSequence(int start, int length) { ...@@ -230,7 +230,7 @@ void RecurrentLayer::forwardOneSequence(int start, int length) {
activation_->forward(frameOutput_[start + length - 1]); activation_->forward(frameOutput_[start + length - 1]);
for (int i = length - 2; i >= 0; --i) { for (int i = length - 2; i >= 0; --i) {
frameOutput_[start + i].value->mul( frameOutput_[start + i].value->mul(
frameOutput_[start + i + 1].value, weight_->getW(), 1, 1); *frameOutput_[start + i + 1].value, *weight_->getW(), 1, 1);
activation_->forward(frameOutput_[start + i]); activation_->forward(frameOutput_[start + i]);
} }
} }
...@@ -282,13 +282,13 @@ void RecurrentLayer::backwardOneSequence(int start, int length) { ...@@ -282,13 +282,13 @@ void RecurrentLayer::backwardOneSequence(int start, int length) {
for (int i = length - 1; i > 0; --i) { for (int i = length - 1; i > 0; --i) {
activation_->backward(frameOutput_[start + i]); activation_->backward(frameOutput_[start + i]);
frameOutput_[start + i - 1].grad->mul( frameOutput_[start + i - 1].grad->mul(
frameOutput_[start + i].grad, weightT, 1, 1); *frameOutput_[start + i].grad, *weightT, 1, 1);
} }
activation_->backward(frameOutput_[start]); activation_->backward(frameOutput_[start]);
if (weight_->getWGrad()) { if (weight_->getWGrad()) {
weight_->getWGrad()->mul( weight_->getWGrad()->mul(
output_.value->subMatrix(start, length - 1)->getTranspose(), *output_.value->subMatrix(start, length - 1)->getTranspose(),
output_.grad->subMatrix(start + 1, length - 1), *output_.grad->subMatrix(start + 1, length - 1),
1, 1,
1); 1);
} }
...@@ -296,13 +296,13 @@ void RecurrentLayer::backwardOneSequence(int start, int length) { ...@@ -296,13 +296,13 @@ void RecurrentLayer::backwardOneSequence(int start, int length) {
for (int i = 0; i < length - 1; ++i) { for (int i = 0; i < length - 1; ++i) {
activation_->backward(frameOutput_[start + i]); activation_->backward(frameOutput_[start + i]);
frameOutput_[start + i + 1].grad->mul( frameOutput_[start + i + 1].grad->mul(
frameOutput_[start + i].grad, weightT, 1, 1); *frameOutput_[start + i].grad, *weightT, 1, 1);
} }
activation_->backward(frameOutput_[start + length - 1]); activation_->backward(frameOutput_[start + length - 1]);
if (weight_->getWGrad()) { if (weight_->getWGrad()) {
weight_->getWGrad()->mul( weight_->getWGrad()->mul(
output_.value->subMatrix(start + 1, length - 1)->getTranspose(), *output_.value->subMatrix(start + 1, length - 1)->getTranspose(),
output_.grad->subMatrix(start, length - 1), *output_.grad->subMatrix(start, length - 1),
1, 1,
1); 1);
} }
...@@ -329,7 +329,7 @@ void RecurrentLayer::forwardBatch(int batchSize, ...@@ -329,7 +329,7 @@ void RecurrentLayer::forwardBatch(int batchSize,
if (n != 0) { if (n != 0) {
MatrixPtr batch1 = MatrixPtr batch1 =
batchValue_->getBatchValue(n - 1, batch2->getHeight()); batchValue_->getBatchValue(n - 1, batch2->getHeight());
batch2->mul(batch1, weight_->getW(), 1, 1); batch2->mul(*batch1, *weight_->getW(), 1, 1);
} }
Argument arg; Argument arg;
arg.value = batch2; arg.value = batch2;
...@@ -367,14 +367,14 @@ void RecurrentLayer::backwardBatch(int batchSize, ...@@ -367,14 +367,14 @@ void RecurrentLayer::backwardBatch(int batchSize,
if (n != 0) { if (n != 0) {
batch1 = batchGrad_->getBatchValue(n - 1, batch2->getHeight()); batch1 = batchGrad_->getBatchValue(n - 1, batch2->getHeight());
batch1->mul(batch2, weightT, 1, 1); batch1->mul(*batch2, *weightT, 1, 1);
} }
if (backwardByBatch && weight_->getWGrad()) { if (backwardByBatch && weight_->getWGrad()) {
if (n != 0) { if (n != 0) {
/* backward weight */ /* backward weight */
batch1 = batchValue_->getBatchValue(n - 1, batch2->getHeight()); batch1 = batchValue_->getBatchValue(n - 1, batch2->getHeight());
weight_->getWGrad()->mul(batch1->getTranspose(), batch2, 1, 1); weight_->getWGrad()->mul(*batch1->getTranspose(), *batch2, 1, 1);
} }
} }
} }
...@@ -389,14 +389,14 @@ void RecurrentLayer::backwardBatch(int batchSize, ...@@ -389,14 +389,14 @@ void RecurrentLayer::backwardBatch(int batchSize,
int len = starts[seq + 1] - starts[seq]; int len = starts[seq + 1] - starts[seq];
if (!reversed_) { if (!reversed_) {
weight_->getWGrad()->mul( weight_->getWGrad()->mul(
output_.value->subMatrix(starts[seq], len - 1)->getTranspose(), *output_.value->subMatrix(starts[seq], len - 1)->getTranspose(),
output_.grad->subMatrix(starts[seq] + 1, len - 1), *output_.grad->subMatrix(starts[seq] + 1, len - 1),
1, 1,
1); 1);
} else { } else {
weight_->getWGrad()->mul( weight_->getWGrad()->mul(
output_.value->subMatrix(starts[seq] + 1, len - 1)->getTranspose(), *output_.value->subMatrix(starts[seq] + 1, len - 1)->getTranspose(),
output_.grad->subMatrix(starts[seq], len - 1), *output_.grad->subMatrix(starts[seq], len - 1),
1, 1,
1); 1);
} }
......
...@@ -155,20 +155,20 @@ void SelectiveFullyConnectedLayer::forward(PassType passType) { ...@@ -155,20 +155,20 @@ void SelectiveFullyConnectedLayer::forward(PassType passType) {
// manully compute the multiplication of // manully compute the multiplication of
// the input vector and the selected rows. // the input vector and the selected rows.
REGISTER_TIMER("selective.plain"); REGISTER_TIMER("selective.plain");
interOutput_->mul(input, weight->getTranspose(), 1, scaleT); interOutput_->mul(*input, *weight->getTranspose(), 1, scaleT);
} else { } else {
// if the indecies is not sparse enough, // if the indecies is not sparse enough,
// use full mul instead // use full mul instead
REGISTER_TIMER("selective.mul"); REGISTER_TIMER("selective.mul");
if (fullOutput_) { if (fullOutput_) {
interOutput_->mul(input, weight->getTranspose(), 1, scaleT); interOutput_->mul(*input, *weight->getTranspose(), 1, scaleT);
} else { } else {
Matrix::resizeOrCreate(mmat_, Matrix::resizeOrCreate(mmat_,
hsize, hsize,
wsize, wsize,
/*trans=*/false, /*trans=*/false,
/*useGpu=*/useGpu_); /*useGpu=*/useGpu_);
mmat_->mul(input, weight->getTranspose()); mmat_->mul(*input, *weight->getTranspose());
interOutput_->add3(mmat_); interOutput_->add3(mmat_);
} }
} }
...@@ -242,14 +242,14 @@ void SelectiveFullyConnectedLayer::backward(const UpdateCallback& callback) { ...@@ -242,14 +242,14 @@ void SelectiveFullyConnectedLayer::backward(const UpdateCallback& callback) {
MatrixPtr preGrad = getInputGrad(i); MatrixPtr preGrad = getInputGrad(i);
if (preGrad) { if (preGrad) {
REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); REGISTER_TIMER_INFO("BpMulTimer", getName().c_str());
preGrad->mul(interOutGrad_, weights_[i]->getW(), 1, 1); preGrad->mul(*interOutGrad_, *weights_[i]->getW(), 1, 1);
} }
MatrixPtr wGrad = weights_[i]->getWGrad(); MatrixPtr wGrad = weights_[i]->getWGrad();
if (wGrad) { if (wGrad) {
REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); REGISTER_TIMER_INFO("GradMulTimer", getName().c_str());
MatrixPtr input = getInputValue(i); MatrixPtr input = getInputValue(i);
wGrad->mul(interOutGrad_->getTranspose(), input, 1, 1); wGrad->mul(*interOutGrad_->getTranspose(), *input, 1, 1);
} }
{ {
......
...@@ -77,7 +77,7 @@ void TensorLayer::forward(PassType passType) { ...@@ -77,7 +77,7 @@ void TensorLayer::forward(PassType passType) {
REGISTER_TIMER_INFO("TensorFwMulTimer", getName().c_str()); REGISTER_TIMER_INFO("TensorFwMulTimer", getName().c_str());
for (size_t i = 0; i < getSize(); ++i) { for (size_t i = 0; i < getSize(); ++i) {
MatrixPtr weights = weights_[i]->getW(); MatrixPtr weights = weights_[i]->getW();
tmpMat->mul(input1, weights, 1, 0); tmpMat->mul(*input1, *weights, 1, 0);
outV->rowDotMul(i, *tmpMat, *input2); outV->rowDotMul(i, *tmpMat, *input2);
} }
} }
...@@ -112,7 +112,7 @@ void TensorLayer::backward(const UpdateCallback& callback) { ...@@ -112,7 +112,7 @@ void TensorLayer::backward(const UpdateCallback& callback) {
if (weights_[i]->getWGrad()) { if (weights_[i]->getWGrad()) {
tmpMat->rowScale(i, *input1, *oGrad); tmpMat->rowScale(i, *input1, *oGrad);
MatrixPtr input1_T = tmpMat->getTranspose(); MatrixPtr input1_T = tmpMat->getTranspose();
weights_[i]->getWGrad()->mul(input1_T, input2, 1, 1); weights_[i]->getWGrad()->mul(*input1_T, *input2, 1, 1);
} }
} }
} }
...@@ -130,11 +130,11 @@ void TensorLayer::backward(const UpdateCallback& callback) { ...@@ -130,11 +130,11 @@ void TensorLayer::backward(const UpdateCallback& callback) {
if (NULL != preGrad1) { /* (grad * e2) * trans(W) */ if (NULL != preGrad1) { /* (grad * e2) * trans(W) */
tmpMat->rowScale(i, *input2, *oGrad); tmpMat->rowScale(i, *input2, *oGrad);
MatrixPtr weights_T = weights->getTranspose(); MatrixPtr weights_T = weights->getTranspose();
preGrad1->mul(tmpMat, weights_T, 1, 1); preGrad1->mul(*tmpMat, *weights_T, 1, 1);
} }
if (NULL != preGrad2) { /* (grad * e1) * W */ if (NULL != preGrad2) { /* (grad * e1) * W */
tmpMat->rowScale(i, *input1, *oGrad); tmpMat->rowScale(i, *input1, *oGrad);
preGrad2->mul(tmpMat, weights, 1, 1); preGrad2->mul(*tmpMat, *weights, 1, 1);
} }
} }
} }
......
...@@ -46,7 +46,7 @@ TransposedFullMatrixProjection::TransposedFullMatrixProjection( ...@@ -46,7 +46,7 @@ TransposedFullMatrixProjection::TransposedFullMatrixProjection(
void TransposedFullMatrixProjection::forward() { void TransposedFullMatrixProjection::forward() {
REGISTER_TIMER_INFO("FwMulTimer", getName().c_str()); REGISTER_TIMER_INFO("FwMulTimer", getName().c_str());
out_->value->mul(in_->value, weight_->getW()->getTranspose(), 1, 1); out_->value->mul(*(in_->value), *(weight_->getW()->getTranspose()), 1, 1);
} }
void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) {
...@@ -55,7 +55,8 @@ void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { ...@@ -55,7 +55,8 @@ void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) {
/* Calculate the W-gradient for the current layer */ /* Calculate the W-gradient for the current layer */
if (weight_->getWGrad()) { if (weight_->getWGrad()) {
REGISTER_TIMER_INFO("GradMulTimer", getName().c_str()); REGISTER_TIMER_INFO("GradMulTimer", getName().c_str());
weight_->getWGrad()->mul(out_->grad->getTranspose(), in_->value, 1, 1); weight_->getWGrad()->mul(
*(out_->grad->getTranspose()), *(in_->value), 1, 1);
} }
// If callback does not change value, backprop error asynchronously so that // If callback does not change value, backprop error asynchronously so that
...@@ -69,7 +70,7 @@ void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) { ...@@ -69,7 +70,7 @@ void TransposedFullMatrixProjection::backward(const UpdateCallback& callback) {
/* Calculate the input layers error */ /* Calculate the input layers error */
if (in_->grad) { if (in_->grad) {
REGISTER_TIMER_INFO("BpMulTimer", getName().c_str()); REGISTER_TIMER_INFO("BpMulTimer", getName().c_str());
in_->grad->mul(out_->grad, weight_->getW(), 1, 1); in_->grad->mul(*(out_->grad), *(weight_->getW()), 1, 1);
} }
hl_set_sync_flag(syncFlag); hl_set_sync_flag(syncFlag);
......
...@@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) { ...@@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) {
testLayerGrad(config, "norm", 100, trans, useGpu); testLayerGrad(config, "norm", 100, trans, useGpu);
} }
#ifndef PADDLE_ONLY_CPU
TEST(Layer, NormLayer) { TEST(Layer, NormLayer) {
testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true); testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true);
testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false);
} }
#endif
void setPoolConfig(TestConfig* config, void setPoolConfig(TestConfig* config,
PoolConfig* pool, PoolConfig* pool,
......
...@@ -163,15 +163,16 @@ MatrixPtr CpuSparseMatrix::getTranspose() { ...@@ -163,15 +163,16 @@ MatrixPtr CpuSparseMatrix::getTranspose() {
SparseValueType CpuSparseMatrix::getValueType() { return valueType_; } SparseValueType CpuSparseMatrix::getValueType() { return valueType_; }
void CpuSparseMatrix::mul(MatrixPtr a, MatrixPtr b, real scaleAB, real scaleT) { void CpuSparseMatrix::mul(const Matrix& a,
const Matrix& b,
real scaleAB,
real scaleT) {
CHECK(!isTransposed()) << "Not supported"; CHECK(!isTransposed()) << "Not supported";
const auto a_ptr = dynamic_cast<const CpuMatrix*>(&a);
const auto b_ptr = dynamic_cast<const CpuMatrix*>(&b);
if (dynamic_cast<CpuMatrix*>(a.get()) && dynamic_cast<CpuMatrix*>(b.get())) { if (a_ptr && b_ptr) {
CpuMatrix::mul(dynamic_cast<CpuMatrix*>(a.get()), CpuMatrix::mul((CpuMatrix*)a_ptr, (CpuMatrix*)b_ptr, this, scaleAB, scaleT);
dynamic_cast<CpuMatrix*>(b.get()),
this,
scaleAB,
scaleT);
} else { } else {
LOG(FATAL) << "not supported"; LOG(FATAL) << "not supported";
} }
......
...@@ -203,7 +203,7 @@ public: ...@@ -203,7 +203,7 @@ public:
/// mem MUST be alloced outside (memAlloc=false) /// mem MUST be alloced outside (memAlloc=false)
void transpose(MatrixPtr matTrans, bool memAlloc); void transpose(MatrixPtr matTrans, bool memAlloc);
void mul(MatrixPtr A, MatrixPtr B, real alpha, real beta); void mul(const Matrix& A, const Matrix& B, real alpha, real beta);
/** /**
* @brief sparseMatrix += denseMatrix * @brief sparseMatrix += denseMatrix
......
此差异已折叠。
...@@ -444,8 +444,8 @@ public: ...@@ -444,8 +444,8 @@ public:
* this = scaleAB*(a*b) + scaleT*this * this = scaleAB*(a*b) + scaleT*this
* @endcode * @endcode
*/ */
virtual void mul(const MatrixPtr a, virtual void mul(const Matrix& a,
const MatrixPtr b, const Matrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT) {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
...@@ -643,7 +643,7 @@ public: ...@@ -643,7 +643,7 @@ public:
* this = a*b * this = a*b
* @endcode * @endcode
*/ */
virtual void mul(const MatrixPtr a, const MatrixPtr b) { virtual void mul(const Matrix& a, const Matrix& b) {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
...@@ -835,7 +835,7 @@ public: ...@@ -835,7 +835,7 @@ public:
* *
* output[i] = 0 if row i is correct. * output[i] = 0 if row i is correct.
*/ */
virtual void classificationError(MatrixPtr output, IVectorPtr label) { virtual void classificationError(Matrix& output, IVector& label) {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
...@@ -952,31 +952,6 @@ public: ...@@ -952,31 +952,6 @@ public:
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
/// normalize-operation.
virtual void crossMapNormalFwd(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
Matrix& denoms,
size_t channels,
size_t sizeX,
float scale,
float pow) {
LOG(FATAL) << "Not implemeted";
}
virtual void crossMapNormalBwd(Matrix& localGrad,
Matrix& denoms,
Matrix& preOutV,
Matrix& localOutV,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t size,
float scale,
float pow) {
LOG(FATAL) << "Not implemeted";
}
/** /**
* Input: one or more sequences. Each sequence contains some instances. * Input: one or more sequences. Each sequence contains some instances.
* *
...@@ -997,8 +972,8 @@ public: ...@@ -997,8 +972,8 @@ public:
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
virtual void contextProjectionForward(MatrixPtr input, virtual void contextProjectionForward(Matrix& input,
MatrixPtr weight, Matrix* weight,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart, int contextStart,
...@@ -1007,8 +982,8 @@ public: ...@@ -1007,8 +982,8 @@ public:
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
virtual void contextProjectionBackward(MatrixPtr inputGrad, virtual void contextProjectionBackward(Matrix* inputGrad,
MatrixPtr weightGrad, Matrix* weightGrad,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart, int contextStart,
...@@ -1017,14 +992,14 @@ public: ...@@ -1017,14 +992,14 @@ public:
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
virtual void contextProjectionBackwardData(MatrixPtr inputGrad, virtual void contextProjectionBackwardData(Matrix& inputGrad,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart) { int contextStart) {
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
virtual void contextProjectionBackwardWeight(MatrixPtr weightGrad, virtual void contextProjectionBackwardWeight(Matrix& weightGrad,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart, int contextStart,
...@@ -1272,14 +1247,14 @@ public: ...@@ -1272,14 +1247,14 @@ public:
* this = scaleAB*(a*b) + scaleT*this * this = scaleAB*(a*b) + scaleT*this
* @endcode * @endcode
*/ */
void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT);
/** /**
* @code * @code
* this = a*b * this = a*b
* @endcode * @endcode
*/ */
void mul(const MatrixPtr a, const MatrixPtr b); void mul(const Matrix& a, const Matrix& b);
void mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT); void mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT);
...@@ -1373,7 +1348,7 @@ public: ...@@ -1373,7 +1348,7 @@ public:
void check(std::ostream& os, Matrix& refMat, bool printDiff = true); void check(std::ostream& os, Matrix& refMat, bool printDiff = true);
void randomizeUniform(); void randomizeUniform();
void classificationError(MatrixPtr output, IVectorPtr label); void classificationError(Matrix& output, IVector& label);
void convExpand(Matrix& feature, void convExpand(Matrix& feature,
int feaImgHeight, int feaImgHeight,
...@@ -1459,26 +1434,6 @@ public: ...@@ -1459,26 +1434,6 @@ public:
size_t paddingH, size_t paddingH,
size_t paddingW); size_t paddingW);
void crossMapNormalFwd(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
Matrix& denoms,
size_t channels,
size_t sizeX,
float scale,
float pow);
void crossMapNormalBwd(Matrix& localGrad,
Matrix& denoms,
Matrix& preOutV,
Matrix& localOutV,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
float scale,
float pow);
void maxSequenceForward(Matrix& input, void maxSequenceForward(Matrix& input,
const IVector& sequence, const IVector& sequence,
IVector& index); IVector& index);
...@@ -1487,20 +1442,20 @@ public: ...@@ -1487,20 +1442,20 @@ public:
const IVector& sequence, const IVector& sequence,
IVector& index); IVector& index);
void contextProjectionForward(MatrixPtr input, void contextProjectionForward(Matrix& input,
MatrixPtr weight, Matrix* weight,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart, int contextStart,
size_t beginPad, size_t beginPad,
bool isPadding); bool isPadding);
void contextProjectionBackwardData(MatrixPtr inputGrad, void contextProjectionBackwardData(Matrix& inputGrad,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart); int contextStart);
void contextProjectionBackwardWeight(MatrixPtr weightGrad, void contextProjectionBackwardWeight(Matrix& weightGrad,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart, int contextStart,
...@@ -1685,26 +1640,6 @@ public: ...@@ -1685,26 +1640,6 @@ public:
size_t paddingH, size_t paddingH,
size_t paddingW); size_t paddingW);
void crossMapNormalFwd(Matrix& input,
size_t imgSizeH,
size_t imgSizeW,
Matrix& denoms,
size_t channels,
size_t sizeX,
float scale,
float pow);
void crossMapNormalBwd(Matrix& localGrad,
Matrix& denoms,
Matrix& preOutV,
Matrix& localOutV,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
float scale,
float pow);
void maxSequenceForward(Matrix& input, void maxSequenceForward(Matrix& input,
const IVector& sequence, const IVector& sequence,
IVector& index); IVector& index);
...@@ -1713,16 +1648,16 @@ public: ...@@ -1713,16 +1648,16 @@ public:
const IVector& sequence, const IVector& sequence,
IVector& index); IVector& index);
void contextProjectionForward(MatrixPtr input, void contextProjectionForward(Matrix& input,
MatrixPtr weight, Matrix* weight,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart, int contextStart,
size_t beginPad, size_t beginPad,
bool isPadding); bool isPadding);
void contextProjectionBackward(MatrixPtr inputGrad, void contextProjectionBackward(Matrix* inputGrad,
MatrixPtr weightGrad, Matrix* weightGrad,
const IVector& sequence, const IVector& sequence,
int contextLength, int contextLength,
int contextStart, int contextStart,
...@@ -1784,7 +1719,7 @@ public: ...@@ -1784,7 +1719,7 @@ public:
void addColumnVector(const Matrix& b); void addColumnVector(const Matrix& b);
void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT);
void mul(CpuMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); void mul(CpuMatrix* a, CpuMatrix* b, real scaleAB, real scaleT);
void mul(CpuMatrix* a, CpuSparseMatrix* b, real scaleAB, real scaleT); void mul(CpuMatrix* a, CpuSparseMatrix* b, real scaleAB, real scaleT);
...@@ -1807,7 +1742,7 @@ public: ...@@ -1807,7 +1742,7 @@ public:
virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT); virtual void mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, real scaleT);
void mul(const MatrixPtr a, const MatrixPtr b); void mul(const Matrix& a, const Matrix& b);
void rightMul(Matrix& b, real scaleAB, real scaleT); void rightMul(Matrix& b, real scaleAB, real scaleT);
void rightMul(Matrix& b); void rightMul(Matrix& b);
...@@ -1881,7 +1816,7 @@ public: ...@@ -1881,7 +1816,7 @@ public:
void randomizeUniform(); void randomizeUniform();
void classificationError(MatrixPtr output, IVectorPtr label); void classificationError(Matrix& output, IVector& label);
void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec); void addByBitCode(size_t numClasses, const IVector& codes, const Matrix& vec);
......
...@@ -571,49 +571,48 @@ void GpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) { ...@@ -571,49 +571,48 @@ void GpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
hl_stream_synchronize(stream); hl_stream_synchronize(stream);
} }
void GpuSparseMatrix::mul(const GpuMatrixPtr a, void GpuSparseMatrix::mul(const GpuMatrix& a,
const GpuMatrixPtr b, const GpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT) {
CHECK(a->useGpu_ && b->useGpu_) << "type not match"; CHECK(a.useGpu_ && b.useGpu_) << "type not match";
CHECK(!trans_) << "trans not supported"; CHECK(!trans_) << "trans not supported";
real* A_d = a->getData(); real* A_d = (real*)a.getData();
real* B_d = b->getData(); real* B_d = (real*)b.getData();
hl_sparse_matrix_s C_d = sMatrix_.get(); hl_sparse_matrix_s C_d = sMatrix_.get();
hl_trans_op_t a_trans = a->trans_ ? HPPL_OP_T : HPPL_OP_N; hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
hl_trans_op_t b_trans = b->trans_ ? HPPL_OP_T : HPPL_OP_N; hl_trans_op_t b_trans = b.trans_ ? HPPL_OP_T : HPPL_OP_N;
if (!a->trans_ && !b->trans_) { if (!a.trans_ && !b.trans_) {
CHECK(height_ == a->getHeight()); CHECK(height_ == a.getHeight());
CHECK(width_ == b->getWidth()); CHECK(width_ == b.getWidth());
CHECK(a->getWidth() == b->getHeight()); CHECK(a.getWidth() == b.getHeight());
} else if (a->trans_ && !b->trans_) { } else if (a.trans_ && !b.trans_) {
CHECK(height_ == a->getWidth()); CHECK(height_ == a.getWidth());
CHECK(width_ == b->getWidth()); CHECK(width_ == b.getWidth());
CHECK(a->getHeight() == b->getHeight()); CHECK(a.getHeight() == b.getHeight());
} else if (!a->trans_ && b->trans_) { } else if (!a.trans_ && b.trans_) {
CHECK(height_ == a->getHeight()); CHECK(height_ == a.getHeight());
CHECK(width_ == b->getHeight()); CHECK(width_ == b.getHeight());
CHECK(a->getWidth() == b->getWidth()); CHECK(a.getWidth() == b.getWidth());
} else { } else {
LOG(INFO) << "Not support"; LOG(INFO) << "Not support";
} }
int dimM = height_; int dimM = height_;
int dimN = width_; int dimN = width_;
int dimK = !b->trans_ ? b->getHeight() : b->getWidth(); int dimK = !b.trans_ ? b.getHeight() : b.getWidth();
hl_sparse_matrix_mul( hl_sparse_matrix_mul(
A_d, a_trans, B_d, b_trans, C_d, dimM, dimN, dimK, scaleAB, scaleT); A_d, a_trans, B_d, b_trans, C_d, dimM, dimN, dimK, scaleAB, scaleT);
} }
void GpuSparseMatrix::mul(const MatrixPtr a, void GpuSparseMatrix::mul(const Matrix& a,
const MatrixPtr b, const Matrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT) {
if (std::dynamic_pointer_cast<GpuMatrix>(a) && const auto a_ptr = dynamic_cast<const GpuMatrix*>(&a);
std::dynamic_pointer_cast<GpuMatrix>(b)) { const auto b_ptr = dynamic_cast<const GpuMatrix*>(&b);
GpuMatrixPtr a_ptr = std::dynamic_pointer_cast<GpuMatrix>(a); if (a_ptr && b_ptr) {
GpuMatrixPtr b_ptr = std::dynamic_pointer_cast<GpuMatrix>(b); mul(*a_ptr, *b_ptr, scaleAB, scaleT);
mul(a_ptr, b_ptr, scaleAB, scaleT);
} else { } else {
LOG(FATAL) << "not supported"; LOG(FATAL) << "not supported";
} }
......
...@@ -104,10 +104,7 @@ public: ...@@ -104,10 +104,7 @@ public:
size_t newNnz, size_t newNnz,
SparseValueType valueType); SparseValueType valueType);
void mul(const GpuMatrixPtr a, void mul(const GpuMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT);
const GpuMatrixPtr b,
real scaleAB,
real scaleT);
/// B = A , B.trans = !A.trans /// B = A , B.trans = !A.trans
MatrixPtr getTranspose(); MatrixPtr getTranspose();
...@@ -218,7 +215,7 @@ protected: ...@@ -218,7 +215,7 @@ protected:
void copyRow(int offsets, size_t colNum, const sparse_float_value_t* row); void copyRow(int offsets, size_t colNum, const sparse_float_value_t* row);
public: public:
void mul(const MatrixPtr a, const MatrixPtr b, real scaleAB, real scaleT); void mul(const Matrix& a, const Matrix& b, real scaleAB, real scaleT);
void copyFrom(CpuSparseMatrix& src, hl_stream_t stream); void copyFrom(CpuSparseMatrix& src, hl_stream_t stream);
void copyFrom(GpuSparseMatrix& src, hl_stream_t stream); void copyFrom(GpuSparseMatrix& src, hl_stream_t stream);
......
...@@ -33,8 +33,8 @@ TEST(Matrix, CopyCpuMatrixToSparseMatrix) { ...@@ -33,8 +33,8 @@ TEST(Matrix, CopyCpuMatrixToSparseMatrix) {
ret2(new CpuMatrix(HEIGHT, WIDTH_TEST)); ret2(new CpuMatrix(HEIGHT, WIDTH_TEST));
ret1->zeroMem(); ret1->zeroMem();
ret2->zeroMem(); ret2->zeroMem();
ret1->mul(testMatrix, mulCpuMatrix, 1.0, 1.0); ret1->mul(*testMatrix, *mulCpuMatrix, 1.0, 1.0);
ret2->mul(testCpuMatrix, mulCpuMatrix, 1.0, 1.0); ret2->mul(*testCpuMatrix, *mulCpuMatrix, 1.0, 1.0);
checkMatrixEqual(ret1, ret2); checkMatrixEqual(ret1, ret2);
} }
...@@ -147,9 +147,9 @@ void test_sparse_matrix_mul(MatrixPara paraA, ...@@ -147,9 +147,9 @@ void test_sparse_matrix_mul(MatrixPara paraA,
hl_stream_synchronize(stream); hl_stream_synchronize(stream);
/*matrix mul*/ /*matrix mul*/
cpuMatrixC->mul(cpuMatrixA, cpuMatrixB, 1.0, 1.0); cpuMatrixC->mul(*cpuMatrixA, *cpuMatrixB, 1.0, 1.0);
gpuMatrixC->mul(gpuMatrixA, gpuMatrixB, 1.0, 1.0); gpuMatrixC->mul(*gpuMatrixA, *gpuMatrixB, 1.0, 1.0);
cpuDenseC->mul(cpuDenseA, cpuDenseB, 1.0, 1.0); cpuDenseC->mul(*cpuDenseA, *cpuDenseB, 1.0, 1.0);
gpuMatrixC_d2h->copyFrom(*gpuMatrixC, stream); gpuMatrixC_d2h->copyFrom(*gpuMatrixC, stream);
hl_stream_synchronize(stream); hl_stream_synchronize(stream);
...@@ -224,8 +224,8 @@ TEST(Matrix, CopySparseMatrixToGpuSparseMatrix) { ...@@ -224,8 +224,8 @@ TEST(Matrix, CopySparseMatrixToGpuSparseMatrix) {
MatrixPtr ret2(new GpuMatrix(HEIGHT, WIDTH_TEST)); MatrixPtr ret2(new GpuMatrix(HEIGHT, WIDTH_TEST));
ret1->zeroMem(); ret1->zeroMem();
ret2->zeroMem(); ret2->zeroMem();
ret1->mul(testMatrix, mulCpuMatrix, 1.0, 1.0); ret1->mul(*testMatrix, *mulCpuMatrix, 1.0, 1.0);
ret2->mul(testGpuMatrix, mulGpuMatrix, 1.0, 1.0); ret2->mul(*testGpuMatrix, *mulGpuMatrix, 1.0, 1.0);
checkMatrixEqual(ret1, ret2); checkMatrixEqual(ret1, ret2);
} }
......
...@@ -65,16 +65,16 @@ void testMatrixProjectionForward(int contextStart, ...@@ -65,16 +65,16 @@ void testMatrixProjectionForward(int contextStart,
// calculate // calculate
int beginPad = std::max(0, -contextStart); int beginPad = std::max(0, -contextStart);
cpuOutput->contextProjectionForward(cpuInput, cpuOutput->contextProjectionForward(*cpuInput,
cpuWeight, cpuWeight.get(),
*cpuSequence, *cpuSequence,
contextLength, contextLength,
contextStart, contextStart,
beginPad, beginPad,
padding); padding);
gpuOutput->contextProjectionForward(gpuInput, gpuOutput->contextProjectionForward(*gpuInput,
gpuWeight, gpuWeight.get(),
*gpuSequence, *gpuSequence,
contextLength, contextLength,
contextStart, contextStart,
...@@ -120,17 +120,17 @@ void testMatrixProjectionBackward(int contextStart, ...@@ -120,17 +120,17 @@ void testMatrixProjectionBackward(int contextStart,
// calculate // calculate
int beginPad = std::max(0, -contextStart); int beginPad = std::max(0, -contextStart);
cpuOutputGrad->contextProjectionBackward(cpuInputGrad, cpuOutputGrad->contextProjectionBackward(cpuInputGrad.get(),
cpuWeightGrad, cpuWeightGrad.get(),
*cpuSequence, *cpuSequence,
contextLength, contextLength,
contextStart, contextStart,
beginPad, beginPad,
padding); padding);
gpuOutputGrad->contextProjectionBackwardData( gpuOutputGrad->contextProjectionBackwardData(
gpuInputGrad, *gpuSequence, contextLength, contextStart); *gpuInputGrad, *gpuSequence, contextLength, contextStart);
if (padding) { if (padding) {
gpuOutputGrad->contextProjectionBackwardWeight(gpuWeightGrad, gpuOutputGrad->contextProjectionBackwardWeight(*gpuWeightGrad,
*gpuSequence, *gpuSequence,
contextLength, contextLength,
contextStart, contextStart,
...@@ -318,7 +318,7 @@ void testMatrixInverse(int height) { ...@@ -318,7 +318,7 @@ void testMatrixInverse(int height) {
cpu->randomizeUniform(); cpu->randomizeUniform();
MatrixPtr cpuT = cpu->getTranspose(); MatrixPtr cpuT = cpu->getTranspose();
MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height); MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
outputCheck->mul(cpu, cpuT); outputCheck->mul(*cpu, *cpuT);
cpu->setDiag(1.0); cpu->setDiag(1.0);
cpu->add(*outputCheck); cpu->add(*outputCheck);
...@@ -328,7 +328,7 @@ void testMatrixInverse(int height) { ...@@ -328,7 +328,7 @@ void testMatrixInverse(int height) {
TensorCheckErr(*cpuI, *gpuI); TensorCheckErr(*cpuI, *gpuI);
outputCheck->mul(cpu, cpuI); outputCheck->mul(*cpu, *cpuI);
cpu->setDiag(1.0); cpu->setDiag(1.0);
TensorCheckErr(*cpu, *outputCheck); TensorCheckErr(*cpu, *outputCheck);
} }
...@@ -509,8 +509,8 @@ void testMatrixMul(bool transa, bool transb, int dimM, int dimN, int dimK) { ...@@ -509,8 +509,8 @@ void testMatrixMul(bool transa, bool transb, int dimM, int dimN, int dimK) {
gpuB->copyFrom(*cpuB); gpuB->copyFrom(*cpuB);
gpuC->copyFrom(*cpuC); gpuC->copyFrom(*cpuC);
cpuC->mul(cpuA, cpuB, alpha, beta); cpuC->mul(*cpuA, *cpuB, alpha, beta);
gpuC->mul(gpuA, gpuB, alpha, beta); gpuC->mul(*gpuA, *gpuB, alpha, beta);
TensorCheckErr(*cpuC, *gpuC); TensorCheckErr(*cpuC, *gpuC);
} }
...@@ -581,8 +581,8 @@ void testSubMatrixMul(bool transa, bool transb, int dimM, int dimN, int dimK) { ...@@ -581,8 +581,8 @@ void testSubMatrixMul(bool transa, bool transb, int dimM, int dimN, int dimK) {
MatrixPtr subCpuC = cpuC->subMatrix(startM, endM, startN, endN); MatrixPtr subCpuC = cpuC->subMatrix(startM, endM, startN, endN);
MatrixPtr subGpuC = gpuC->subMatrix(startM, endM, startN, endN); MatrixPtr subGpuC = gpuC->subMatrix(startM, endM, startN, endN);
subCpuC->mul(subCpuA, subCpuB, alpha, beta); subCpuC->mul(*subCpuA, *subCpuB, alpha, beta);
subGpuC->mul(subGpuA, subGpuB, alpha, beta); subGpuC->mul(*subGpuA, *subGpuB, alpha, beta);
TensorCheckErr(*cpuC, *gpuC); TensorCheckErr(*cpuC, *gpuC);
} }
...@@ -939,8 +939,8 @@ void testClassificationError(int numSamples, int dim) { ...@@ -939,8 +939,8 @@ void testClassificationError(int numSamples, int dim) {
gpuOutput->copyFrom(*cpuOutput); gpuOutput->copyFrom(*cpuOutput);
gpuLabel->copyFrom(*cpuLabel); gpuLabel->copyFrom(*cpuLabel);
cpuError->classificationError(cpuOutput, cpuLabel); cpuError->classificationError(*cpuOutput, *cpuLabel);
gpuError->classificationError(gpuOutput, gpuLabel); gpuError->classificationError(*gpuOutput, *gpuLabel);
TensorCheckEqual(*cpuError, *gpuError); TensorCheckEqual(*cpuError, *gpuError);
} }
......
...@@ -102,8 +102,8 @@ void testSpMatrixMul(int M, int N, int K, real rate) { ...@@ -102,8 +102,8 @@ void testSpMatrixMul(int M, int N, int K, real rate) {
gpuC->copyFrom(*cpuC, stream); gpuC->copyFrom(*cpuC, stream);
hl_stream_synchronize(stream); hl_stream_synchronize(stream);
cpuC->mul(cpuA, cpuB->getTranspose(), 1, 1); cpuC->mul(*cpuA, *cpuB->getTranspose(), 1, 1);
gpuC->mul(gpuA, gpuB->getTranspose(), 1, 1); gpuC->mul(*gpuA, *gpuB->getTranspose(), 1, 1);
MatrixPtr outputCheck(new CpuSparseMatrix(M, N, nnz)); MatrixPtr outputCheck(new CpuSparseMatrix(M, N, nnz));
outputCheck->copyFrom(*gpuC, stream); outputCheck->copyFrom(*gpuC, stream);
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
void ParameterUpdater::init(std::vector<ParameterPtr>& parameters) { void ParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters; parameters_ = parameters;
for (ParameterType type : getParameterTypes()) { for (ParameterType type : getParameterTypes()) {
for (auto& para : parameters) { for (auto& para : parameters) {
......
...@@ -32,7 +32,7 @@ public: ...@@ -32,7 +32,7 @@ public:
parameterTypes_.push_back(type); parameterTypes_.push_back(type);
} }
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
// called by Trainer when starting a new pass // called by Trainer when starting a new pass
virtual void startPass() {} virtual void startPass() {}
...@@ -105,7 +105,7 @@ public: ...@@ -105,7 +105,7 @@ public:
ParameterUpdaterComposite() {} ParameterUpdaterComposite() {}
virtual ~ParameterUpdaterComposite() {} virtual ~ParameterUpdaterComposite() {}
virtual void init(std::vector<ParameterPtr>& parameters) = 0; virtual void init(const std::vector<ParameterPtr>& parameters) = 0;
virtual void startPass() { virtual void startPass() {
syncThreadPool_->execPlusOwner( syncThreadPool_->execPlusOwner(
......
...@@ -2,6 +2,8 @@ FROM ubuntu:14.04 ...@@ -2,6 +2,8 @@ FROM ubuntu:14.04
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com> MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
ARG UBUNTU_MIRROR
RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi'
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y cmake libprotobuf-dev protobuf-compiler git \ && apt-get install -y cmake libprotobuf-dev protobuf-compiler git \
libgoogle-glog-dev libgflags-dev libgtest-dev \ libgoogle-glog-dev libgflags-dev libgtest-dev \
......
...@@ -2,6 +2,8 @@ FROM nvidia/cuda:7.5-cudnn5-devel-ubuntu14.04 ...@@ -2,6 +2,8 @@ FROM nvidia/cuda:7.5-cudnn5-devel-ubuntu14.04
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com> MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
ARG UBUNTU_MIRROR
RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi'
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y cmake libprotobuf-dev protobuf-compiler git \ && apt-get install -y cmake libprotobuf-dev protobuf-compiler git \
libgoogle-glog-dev libgflags-dev libgtest-dev \ libgoogle-glog-dev libgflags-dev libgtest-dev \
......
...@@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager( ...@@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager(
updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); }); updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); });
} }
void SgdUpdaterWithCpuAverager::init(std::vector<ParameterPtr>& parameters) { void SgdUpdaterWithCpuAverager::init(
const std::vector<ParameterPtr>& parameters) {
SgdLocalUpdater::init(parameters); SgdLocalUpdater::init(parameters);
averager_->init(parameters_.size(), nullptr); averager_->init(parameters_.size(), nullptr);
copyEvents_.resize(parameters_.size()); copyEvents_.resize(parameters_.size());
......
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
* be initialized. * be initialized.
* @param parameters The parameter need to be initialized. * @param parameters The parameter need to be initialized.
*/ */
virtual void init(std::vector<ParameterPtr>& parameters) { virtual void init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
optimizer_->init(parameters_.size(), nullptr); optimizer_->init(parameters_.size(), nullptr);
// check no L1 decay in parameter configs // check no L1 decay in parameter configs
...@@ -208,7 +208,7 @@ public: ...@@ -208,7 +208,7 @@ public:
* @brief init. Initialize cpu parameters, model average optimizer. * @brief init. Initialize cpu parameters, model average optimizer.
* @param parameters * @param parameters
*/ */
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize) { virtual PassType startBatch(int64_t batchSize) {
averager_->startBatch(-1UL); averager_->startBatch(-1UL);
......
...@@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater( ...@@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater(
addParameterType(PARAMETER_MOMENTUM); addParameterType(PARAMETER_MOMENTUM);
} }
void RemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) { void RemoteParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
if (localUpdater_) { if (localUpdater_) {
...@@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater( ...@@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater(
testing_(testing), testing_(testing),
useApplyInPserver_(false) {} useApplyInPserver_(false) {}
void SparseRemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) { void SparseRemoteParameterUpdater::init(
const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
parameterClient_.reset(new ParameterClient2( parameterClient_.reset(new ParameterClient2(
...@@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote( ...@@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote(
} }
void SparseRemoteParameterUpdaterComposite::init( void SparseRemoteParameterUpdaterComposite::init(
std::vector<ParameterPtr>& parameters) { const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters; parameters_ = parameters;
std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS]; std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS];
......
...@@ -67,7 +67,7 @@ public: ...@@ -67,7 +67,7 @@ public:
/** /**
* initialize the internal parameter client and itself. * initialize the internal parameter client and itself.
*/ */
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
/** /**
* @brief start batch * @brief start batch
* *
...@@ -274,7 +274,7 @@ public: ...@@ -274,7 +274,7 @@ public:
} }
/// initialization /// initialization
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
/// stateful batch control /// stateful batch control
virtual PassType startBatch(int64_t batchSize); virtual PassType startBatch(int64_t batchSize);
...@@ -360,7 +360,7 @@ public: ...@@ -360,7 +360,7 @@ public:
} }
/// initialization of dense and sparse updaters /// initialization of dense and sparse updaters
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
}; };
class ParameterUpdaterCreators { class ParameterUpdaterCreators {
......
...@@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig) ...@@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig)
} }
} }
void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) { void SgdThreadUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters); ParameterUpdater::init(parameters);
// calc max parameter id // calc max parameter id
......
...@@ -49,7 +49,7 @@ public: ...@@ -49,7 +49,7 @@ public:
// Use the finishPass() function of the base optimizer. // Use the finishPass() function of the base optimizer.
virtual bool finishPass(real cost); virtual bool finishPass(real cost);
virtual void init(std::vector<ParameterPtr>& parameters); virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize); virtual PassType startBatch(int64_t batchSize);
// Call finishBatch for each optimizer. // Call finishBatch for each optimizer.
virtual void finishBatch(real cost); virtual void finishBatch(real cost);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册