提交 17fe8322 编写于 作者: H hedaoyuan 提交者: GitHub

Merge pull request #2282 from hedaoyuan/convolution

Add convolution Function
...@@ -14,8 +14,8 @@ add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) ...@@ -14,8 +14,8 @@ add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
add_dependencies(paddle_function ${external_project_dependencies}) add_dependencies(paddle_function ${external_project_dependencies})
add_dependencies(paddle_function gen_proto_cpp) add_dependencies(paddle_function gen_proto_cpp)
if(WITH_GPU)
if(WITH_TESTING) if(WITH_TESTING)
if(WITH_GPU)
# TODO: # TODO:
# file(GLOB test_files . *OpTest.cpp) # file(GLOB test_files . *OpTest.cpp)
# add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files}) # add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files})
...@@ -30,6 +30,8 @@ if(WITH_TESTING) ...@@ -30,6 +30,8 @@ if(WITH_TESTING)
add_simple_unittest(CosSimOpTest) add_simple_unittest(CosSimOpTest)
add_simple_unittest(RowConvOpTest) add_simple_unittest(RowConvOpTest)
endif() endif()
add_simple_unittest(ConvOpTest)
endif() endif()
add_style_check_target(paddle_function ${h_files}) add_style_check_target(paddle_function ${h_files})
......
...@@ -28,7 +28,7 @@ void testMatrixProjectionForward(int context_start, ...@@ -28,7 +28,7 @@ void testMatrixProjectionForward(int context_start,
std::max(0, (int)(context_start + context_length - 1)); std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false; if (pad == 0) is_padding = false;
FunctionCompare test( CpuGpuFuncCompare test(
"ContextProjectionForward", "ContextProjectionForward",
FuncConfig() FuncConfig()
.set("context_length", context_length) .set("context_length", context_length)
...@@ -60,7 +60,7 @@ void testMatrixProjectionBackward(int context_start, ...@@ -60,7 +60,7 @@ void testMatrixProjectionBackward(int context_start,
std::max(0, (int)(context_start + context_length - 1)); std::max(0, (int)(context_start + context_length - 1));
if (pad == 0) is_padding = false; if (pad == 0) is_padding = false;
FunctionCompare test( CpuGpuFuncCompare test(
"ContextProjectionBackward", "ContextProjectionBackward",
FuncConfig() FuncConfig()
.set("context_length", context_length) .set("context_length", context_length)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/*
* \brief Based on the ConvFunctionBase class, the forward calculation,
* backward input calculation and backward filter calculation
* of convolution operations can be implemented.
*
* Arguments of forward and backward calculation:
* 1. Forward calculation of convolution.
* inputs = {INPUT, FILTER}, outputs = {OUTPUT}
* The first and second input arguments are input image and filter data.
* The output argument is output image.
*
* 2. Backward input calculation of convolution.
* inputs = {OUTPUT_GRAD, FILTER}, outputs = {INPUT_GRAD}
* The first and second input arguments are output grad image
* and filter data.
* The output argument is input grad image.
*
* 3. Backward filter calculation of convolution.
* inputs = {OUTPUT_GRAD, INPUT}, outputs = {FILTER_GRAD}
* The first and second input arguments are output grad image
* and input image.
* The output argument is filter grad.
*
* Arguments format of input, filter and output:
* 1. Input image, output image, input image gradient, output image gradient
* are all NCHW format. Where N is batch size, C is the number of channels,
* H and W is the height and width of image or image gradient.
*
* 2. The format of the filter data is MCHW, where M is the number of output
* image channels, C is the number of input image channels,
* H and W is height and width of filter.
*
* If `groups` is greater than 1, the filter's data format should be GMCHW,
* where G is the `groups`, and G * M is the number of output image
* channels, G * C is the number of input image channels,
* H and W is height and width of filter.
*/
class ConvFunctionBase : public FunctionBase {
public:
void init(const FuncConfig& config) override {
// function arguments
strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings");
groups_ = config.get<size_t>("groups");
// number of inputs and outputs
numInputs_ = 2;
numOutputs_ = 1;
}
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
// input can be INPUT and INPUT_GRAD
// filter can be FILTER and FILTER_GRAD
// output can be OUTPUT and OUTPUT_GRAD
void check(const TensorShape& input,
const TensorShape& filter,
const TensorShape& output) {
// inputs and outputs arguments should be 4-dimensional.
CHECK_EQ(input.ndims(), (size_t)4);
CHECK_EQ(output.ndims(), (size_t)4);
// The batchSize of the input needs to be equal to
// the batchSize of the output.
CHECK_EQ(input[0], output[0]);
if (filter.ndims() == (size_t)4) {
// If the filter's dimension is 4, groups convolution is not supported.
CHECK_EQ(groups_, (size_t)1);
// The input and output channel dimensions are the second and first
// dimensions of the filter shape.
CHECK_EQ(input[1], filter[1]);
CHECK_EQ(output[1], filter[0]);
} else {
// filter argument should be 5-dimensional.
CHECK_EQ(filter.ndims(), (size_t)5);
// The first dimension of the filter is the size of the group
CHECK_EQ(filter[0], groups_);
// The input and output channel dimensions are the third and second
// dimensions of the filter shape.
CHECK_EQ(input[1], filter[2] * groups_);
CHECK_EQ(output[1], filter[1] * groups_);
}
}
protected:
size_t getFilterHeight(const TensorShape& filter) const {
return filter[filter.ndims() - 2];
}
size_t getFilterWidth(const TensorShape& filter) const {
return filter[filter.ndims() - 1];
}
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels,
/// and the second half only connected to the second half.
size_t groups_;
inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; }
inline int paddingH() const { return paddings_[0]; }
inline int paddingW() const { return paddings_[1]; }
// A temporary memory in convolution calculation.
MemoryHandlePtr memory_;
template <DeviceType Device>
void resizeBuffer(size_t newSize) {
if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
if (Device == DEVICE_TYPE_CPU) {
memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
} else {
memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
}
}
}
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "Function.h"
#include "FunctionTest.h"
namespace paddle {
enum TestType {
kForwardTest = 0,
kBackwardInputTest = 1,
kBackwardFilterTest = 2,
};
template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest {
public:
ConvolutionTest(const std::string& conv1,
const std::string& conv2,
TestType type,
std::string algo = "auto") {
for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) {
for (size_t filterSize : {1, 3, 5}) {
for (size_t inputChannels : {3, 64}) {
for (size_t outputChannels : {3, 64, 128}) {
if (inputChannels < outputChannels) break;
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
if (padding >= filterSize) break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
if (type == kForwardTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == kBackwardInputTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run();
} else if (type == kBackwardFilterTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.run();
}
}
}
}
}
}
}
}
}
};
// Mainly used to test cases where the height and width (input, filter)
// are not equal.
template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest2 {
public:
ConvolutionTest2(const std::string& conv1,
const std::string& conv2,
TestType type,
std::string algo = "auto") {
for (size_t batchSize : {16}) {
for (size_t inputHeight : {7, 31}) {
for (size_t inputWidth : {10, 54}) {
for (size_t filterHeight : {1, 5}) {
for (size_t filterWidth : {3, 7}) {
for (size_t inputChannels : {7}) {
for (size_t outputChannels : {32}) {
size_t stride = 1;
size_t padding = 0;
size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) /
stride;
size_t outputWidth =
(inputWidth - filterWidth + 2 * padding + stride) /
stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputHeight
<< " inputWidth=" << inputWidth
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterHeight
<< " filterWidth=" << filterWidth
<< " outputHeight=" << outputHeight
<< " outputWidth=" << outputWidth
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape input{
batchSize, inputChannels, inputHeight, inputWidth};
TensorShape filter{
outputChannels, inputChannels, filterHeight, filterWidth};
TensorShape output{
batchSize, outputChannels, outputHeight, outputWidth};
if (type == kForwardTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == kBackwardInputTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run();
} else if (type == kBackwardFilterTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.run();
}
}
}
}
}
}
}
}
}
};
TEST(Forward, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test2(
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
}
#ifndef PADDLE_ONLY_CPU
TEST(Forward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-CPU", "GemmConv-GPU", kForwardTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConv-CPU", "GemmConv-GPU", kForwardTest);
}
TEST(BackwardInput, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
}
TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
}
#endif
} // namespace paddle
...@@ -22,7 +22,7 @@ void testCosSimForward(size_t height_x, ...@@ -22,7 +22,7 @@ void testCosSimForward(size_t height_x,
size_t height_y, size_t height_y,
size_t width, size_t width,
real scale) { real scale) {
FunctionCompare test("CosSimForward", FuncConfig().set("scale", scale)); CpuGpuFuncCompare test("CosSimForward", FuncConfig().set("scale", scale));
// prepare input arguments // prepare input arguments
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, width}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_y, width}));
...@@ -36,7 +36,7 @@ void testCosSimBackward(size_t height_x, ...@@ -36,7 +36,7 @@ void testCosSimBackward(size_t height_x,
size_t height_y, size_t height_y,
size_t width, size_t width,
real scale) { real scale) {
FunctionCompare test("CosSimBackward", FuncConfig().set("scale", scale)); CpuGpuFuncCompare test("CosSimBackward", FuncConfig().set("scale", scale));
// prepare input arguments // prepare input arguments
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{height_x, 1}));
......
...@@ -28,7 +28,7 @@ TEST(CrossMapNormal, real) { ...@@ -28,7 +28,7 @@ TEST(CrossMapNormal, real) {
<< " size=" << size; << " size=" << size;
// init Test object // init Test object
FunctionCompare test("CrossMapNormal", CpuGpuFuncCompare test("CrossMapNormal",
FuncConfig() FuncConfig()
.set("size", size) .set("size", size)
.set("scale", (real)1.5) .set("scale", (real)1.5)
...@@ -57,7 +57,7 @@ TEST(CrossMapNormalGrad, real) { ...@@ -57,7 +57,7 @@ TEST(CrossMapNormalGrad, real) {
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW
<< " size=" << size; << " size=" << size;
FunctionCompare test("CrossMapNormalGrad", CpuGpuFuncCompare test("CrossMapNormalGrad",
FuncConfig() FuncConfig()
.set("size", size) .set("size", size)
.set("scale", (real)1.5) .set("scale", (real)1.5)
......
...@@ -22,14 +22,62 @@ namespace paddle { ...@@ -22,14 +22,62 @@ namespace paddle {
typedef std::shared_ptr<BufferArg> BufferArgPtr; typedef std::shared_ptr<BufferArg> BufferArgPtr;
namespace test {
template <DeviceType DType>
struct Allocator;
template <>
struct Allocator<DEVICE_TYPE_CPU> {
using type = CpuMemoryHandle;
};
template <>
struct Allocator<DEVICE_TYPE_GPU> {
using type = GpuMemoryHandle;
};
// Copy argument1 to argument2
template <DeviceType DType1, DeviceType DType2>
class CopyArgument {
public:
void operator()(const BufferArg& arg1, BufferArg& arg2) {
CHECK_EQ(arg1.valueType(), arg2.valueType());
CHECK_LE(arg1.shape().getElements(), arg2.shape().getElements());
if (arg1.valueType() == VALUE_TYPE_INT32) {
IVectorPtr vector1 =
IVector::create((int*)arg1.data(),
arg1.shape().getElements(),
DType1 == DEVICE_TYPE_CPU ? false : true);
IVectorPtr vector2 =
IVector::create((int*)arg2.data(),
arg2.shape().getElements(),
DType2 == DEVICE_TYPE_CPU ? false : true);
vector2->copyFrom(*vector1);
} else {
VectorPtr vector1 =
Vector::create((real*)arg1.data(),
arg1.shape().getElements(),
DType1 == DEVICE_TYPE_CPU ? false : true);
VectorPtr vector2 =
Vector::create((real*)arg2.data(),
arg2.shape().getElements(),
DType2 == DEVICE_TYPE_CPU ? false : true);
vector2->copyFrom(*vector1);
}
}
};
} // namespace test
/** /**
* \brief A class for comparing CPU and GPU implementations of Function. * \brief A class for comparing two Functions of different implementations.
* * For example, can be used to compare the CPU and GPU implementation
* of the function is consistent.
* *
* Use case: * Use case:
* // Initializes a test object, the corresponding cpu and gpu Function * // Initializes a test object, the corresponding cpu and gpu Function
* // are constructed according to FunctionName and FuncConfig. * // are constructed according to FunctionName and FuncConfig.
* FunctionCompare test(FunctionName, FuncConfig); * CpuGpuFuncCompare test(FunctionName, FuncConfig);
* // Prepare inputs and outputs arguments. * // Prepare inputs and outputs arguments.
* // Here the input and output can not contain real data, * // Here the input and output can not contain real data,
* // only contains the argument type and shape. * // only contains the argument type and shape.
...@@ -45,28 +93,38 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr; ...@@ -45,28 +93,38 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
* // Compares CPU and GPU calculation results for consistency. * // Compares CPU and GPU calculation results for consistency.
* test.run(); * test.run();
*/ */
class FunctionCompare { template <DeviceType DType1, DeviceType DType2>
class Compare2Function {
public: public:
FunctionCompare(const std::string& name, const FuncConfig& config) typedef typename test::Allocator<DType1>::type Allocator1;
: cpuFunc_(FunctionBase::funcRegistrar_.createByType(name + "-CPU")), typedef typename test::Allocator<DType2>::type Allocator2;
gpuFunc_(FunctionBase::funcRegistrar_.createByType(name + "-GPU")) { typedef typename Tensor<real, DType1>::Vector Vector1;
cpuFunc_->init(config); typedef typename Tensor<real, DType2>::Vector Vector2;
gpuFunc_->init(config); typedef typename Tensor<real, DType1>::SparseMatrix SparseMatrix1;
typedef typename Tensor<real, DType2>::SparseMatrix SparseMatrix2;
Compare2Function(const std::string& name1,
const std::string& name2,
const FuncConfig& config)
: function1_(FunctionBase::funcRegistrar_.createByType(name1)),
function2_(FunctionBase::funcRegistrar_.createByType(name2)) {
function1_->init(config);
function2_->init(config);
} }
~FunctionCompare() {} ~Compare2Function() {}
// input need only contains shape, do not contains data. // input need only contains shape, do not contains data.
void addInputs(const BufferArg& input) { void addInputs(const BufferArg& input) {
size_t size = size_t size =
input.shape().getElements() * sizeOfValuType(input.valueType()); input.shape().getElements() * sizeOfValuType(input.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); func1Memory_.emplace_back(std::make_shared<Allocator1>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size)); func2Memory_.emplace_back(std::make_shared<Allocator2>(size));
cpuInputs_.emplace_back(std::make_shared<BufferArg>( func1Inputs_.emplace_back(std::make_shared<BufferArg>(
cpuMemory_.back()->getBuf(), input.valueType(), input.shape())); func1Memory_.back()->getBuf(), input.valueType(), input.shape()));
gpuInputs_.emplace_back(std::make_shared<BufferArg>( func2Inputs_.emplace_back(std::make_shared<BufferArg>(
gpuMemory_.back()->getBuf(), input.valueType(), input.shape())); func2Memory_.back()->getBuf(), input.valueType(), input.shape()));
} }
// assume one copy of sequence is shared by different SequenceArgs // assume one copy of sequence is shared by different SequenceArgs
...@@ -75,62 +133,57 @@ public: ...@@ -75,62 +133,57 @@ public:
size_t batchSize = input.shape()[0]; size_t batchSize = input.shape()[0];
size_t numSeqs = batchSize / 10 + 1; size_t numSeqs = batchSize / 10 + 1;
size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32); size_t sizeId = (numSeqs + 1) * sizeOfValuType(VALUE_TYPE_INT32);
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(sizeId)); func1Memory_.emplace_back(std::make_shared<Allocator1>(sizeId));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(sizeId)); func2Memory_.emplace_back(std::make_shared<Allocator2>(sizeId));
cpuSeq_ = std::make_shared<SequenceIdArg>(cpuMemory_.back()->getBuf(), seq1_ = std::make_shared<SequenceIdArg>(func1Memory_.back()->getBuf(),
TensorShape{numSeqs + 1}); TensorShape{numSeqs + 1});
gpuSeq_ = std::make_shared<SequenceIdArg>(gpuMemory_.back()->getBuf(), seq2_ = std::make_shared<SequenceIdArg>(func2Memory_.back()->getBuf(),
TensorShape{numSeqs + 1}); TensorShape{numSeqs + 1});
/// init sequence Id /// init sequence Id
initArg(*cpuSeq_, batchSize); initArg(*seq1_, batchSize);
// todo(tianbing), delete it copyArg_(*seq1_, *seq2_);
CHECK_EQ(cpuSeq_->shape().getElements(), cpuSeq_->numSeqs() + 1);
CpuIVector cpuSeq(cpuSeq_->shape().getElements(), (int*)cpuSeq_->data());
GpuIVector gpuSeq(gpuSeq_->shape().getElements(), (int*)gpuSeq_->data());
gpuSeq.copyFrom(cpuSeq);
} }
void addInputs(const SequenceArg& input) { void addInputs(const SequenceArg& input) {
CHECK_EQ(input.shape().ndims(), 2UL); CHECK_EQ(input.shape().ndims(), 2UL);
size_t batchSize = input.shape()[0]; size_t batchSize = input.shape()[0];
if (!cpuSeq_ || !gpuSeq_) { // sequence not exist if (!seq1_ || !seq2_) { // sequence not exist
addSequence(SequenceIdArg(TensorShape{batchSize})); addSequence(SequenceIdArg(TensorShape{batchSize}));
} }
size_t size = size_t size =
input.shape().getElements() * sizeOfValuType(input.valueType()); input.shape().getElements() * sizeOfValuType(input.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); func1Memory_.emplace_back(std::make_shared<Allocator1>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size)); func2Memory_.emplace_back(std::make_shared<Allocator2>(size));
/// SequenceArg /// SequenceArg
cpuInputs_.emplace_back( func1Inputs_.emplace_back(
std::make_shared<SequenceArg>(cpuMemory_.back()->getBuf(), std::make_shared<SequenceArg>(func1Memory_.back()->getBuf(),
input.valueType(), input.valueType(),
input.shape(), input.shape(),
*cpuSeq_)); *seq1_));
gpuInputs_.emplace_back( func2Inputs_.emplace_back(
std::make_shared<SequenceArg>(gpuMemory_.back()->getBuf(), std::make_shared<SequenceArg>(func2Memory_.back()->getBuf(),
input.valueType(), input.valueType(),
input.shape(), input.shape(),
*gpuSeq_)); *seq2_));
} }
// output need only contains shape, do not contains data. // output need only contains shape, do not contains data.
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) { void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
size_t size = size_t size =
output.shape().getElements() * sizeOfValuType(output.valueType()); output.shape().getElements() * sizeOfValuType(output.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); func1Memory_.emplace_back(std::make_shared<Allocator1>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size)); func2Memory_.emplace_back(std::make_shared<Allocator2>(size));
cpuOutputs_.emplace_back( func1Outputs_.emplace_back(
std::make_shared<BufferArg>(cpuMemory_.back()->getBuf(), std::make_shared<BufferArg>(func1Memory_.back()->getBuf(),
output.valueType(), output.valueType(),
output.shape(), output.shape(),
argType)); argType));
gpuOutputs_.emplace_back( func2Outputs_.emplace_back(
std::make_shared<BufferArg>(gpuMemory_.back()->getBuf(), std::make_shared<BufferArg>(func2Memory_.back()->getBuf(),
output.valueType(), output.valueType(),
output.shape(), output.shape(),
argType)); argType));
...@@ -138,14 +191,14 @@ public: ...@@ -138,14 +191,14 @@ public:
/// add and init output sparse matrix /// add and init output sparse matrix
void addOutputs(const SparseMatrixArg& output, ArgType argType = ASSIGN_TO) { void addOutputs(const SparseMatrixArg& output, ArgType argType = ASSIGN_TO) {
cpuSparse_ = std::make_shared<CpuSparseMatrix>( sparse1_ = std::make_shared<SparseMatrix1>(
output.shape()[0], output.shape()[0],
output.shape()[1], output.shape()[1],
output.nnz(), output.nnz(),
static_cast<SparseValueType>(output.dataType()), static_cast<SparseValueType>(output.dataType()),
static_cast<SparseFormat>(output.dataFormat())); static_cast<SparseFormat>(output.dataFormat()));
gpuSparse_ = std::make_shared<GpuSparseMatrix>( sparse2_ = std::make_shared<SparseMatrix2>(
output.shape()[0], output.shape()[0],
output.shape()[1], output.shape()[1],
output.nnz(), output.nnz(),
...@@ -154,52 +207,52 @@ public: ...@@ -154,52 +207,52 @@ public:
/// init sparse matrix /// init sparse matrix
hl_stream_t stream(HPPL_STREAM_1); hl_stream_t stream(HPPL_STREAM_1);
cpuSparse_->randomizeUniform(); sparse1_->randomizeUniform();
gpuSparse_->copyFrom(*cpuSparse_, stream); sparse2_->copyFrom(*sparse1_, stream);
hl_stream_synchronize(stream); hl_stream_synchronize(stream);
cpuOutputs_.emplace_back( func1Outputs_.emplace_back(
std::make_shared<SparseMatrixArg>(*cpuSparse_, argType)); std::make_shared<SparseMatrixArg>(*sparse1_, argType));
gpuOutputs_.emplace_back( func2Outputs_.emplace_back(
std::make_shared<SparseMatrixArg>(*gpuSparse_, argType)); std::make_shared<SparseMatrixArg>(*sparse2_, argType));
} }
void addOutputs(const SequenceArg& output, ArgType argType = ASSIGN_TO) { void addOutputs(const SequenceArg& output, ArgType argType = ASSIGN_TO) {
CHECK_EQ(output.shape().ndims(), 2UL); CHECK_EQ(output.shape().ndims(), 2UL);
size_t batchSize = output.shape()[0]; size_t batchSize = output.shape()[0];
if (!cpuSeq_ || !gpuSeq_) { // sequence not exist if (!seq1_ || !seq2_) { // sequence not exist
addSequence(SequenceIdArg(TensorShape{batchSize})); addSequence(SequenceIdArg(TensorShape{batchSize}));
} }
size_t size = size_t size =
output.shape().getElements() * sizeOfValuType(output.valueType()); output.shape().getElements() * sizeOfValuType(output.valueType());
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size)); func1Memory_.emplace_back(std::make_shared<Allocator1>(size));
gpuMemory_.emplace_back(std::make_shared<GpuMemoryHandle>(size)); func2Memory_.emplace_back(std::make_shared<Allocator2>(size));
/// SequenceArg /// SequenceArg
cpuOutputs_.emplace_back( func1Outputs_.emplace_back(
std::make_shared<SequenceArg>(cpuMemory_.back()->getBuf(), std::make_shared<SequenceArg>(func1Memory_.back()->getBuf(),
output.valueType(), output.valueType(),
output.shape(), output.shape(),
*cpuSeq_, *seq1_,
argType)); argType));
gpuOutputs_.emplace_back( func2Outputs_.emplace_back(
std::make_shared<SequenceArg>(gpuMemory_.back()->getBuf(), std::make_shared<SequenceArg>(func2Memory_.back()->getBuf(),
output.valueType(), output.valueType(),
output.shape(), output.shape(),
*gpuSeq_, *seq2_,
argType)); argType));
} }
void addInputs(const SparseMatrixArg& input) { void addInputs(const SparseMatrixArg& input) {
cpuSparse_ = std::make_shared<CpuSparseMatrix>( sparse1_ = std::make_shared<SparseMatrix1>(
input.shape()[0], input.shape()[0],
input.shape()[1], input.shape()[1],
input.nnz(), input.nnz(),
static_cast<SparseValueType>(input.dataType()), static_cast<SparseValueType>(input.dataType()),
static_cast<SparseFormat>(input.dataFormat())); static_cast<SparseFormat>(input.dataFormat()));
gpuSparse_ = std::make_shared<GpuSparseMatrix>( sparse2_ = std::make_shared<SparseMatrix2>(
input.shape()[0], input.shape()[0],
input.shape()[1], input.shape()[1],
input.nnz(), input.nnz(),
...@@ -208,12 +261,12 @@ public: ...@@ -208,12 +261,12 @@ public:
/// init sparse matrix /// init sparse matrix
hl_stream_t stream(HPPL_STREAM_1); hl_stream_t stream(HPPL_STREAM_1);
cpuSparse_->randomizeUniform(); sparse1_->randomizeUniform();
gpuSparse_->copyFrom(*cpuSparse_, stream); sparse2_->copyFrom(*sparse1_, stream);
hl_stream_synchronize(stream); hl_stream_synchronize(stream);
cpuInputs_.emplace_back(std::make_shared<SparseMatrixArg>(*cpuSparse_)); func1Inputs_.emplace_back(std::make_shared<SparseMatrixArg>(*sparse1_));
gpuInputs_.emplace_back(std::make_shared<SparseMatrixArg>(*gpuSparse_)); func2Inputs_.emplace_back(std::make_shared<SparseMatrixArg>(*sparse2_));
} }
void run() { void run() {
...@@ -236,27 +289,27 @@ public: ...@@ -236,27 +289,27 @@ public:
function->calc(inArgs, outArgs); function->calc(inArgs, outArgs);
}; };
callFunction(cpuFunc_.get(), cpuInputs_, cpuOutputs_); callFunction(function1_.get(), func1Inputs_, func1Outputs_);
callFunction(gpuFunc_.get(), gpuInputs_, gpuOutputs_); callFunction(function2_.get(), func2Inputs_, func2Outputs_);
// check outputs // check outputs
compareOutputs(); compareOutputs();
} }
std::shared_ptr<FunctionBase> getCpuFunction() const { return cpuFunc_; } std::shared_ptr<FunctionBase> getFunction1() const { return function1_; }
std::shared_ptr<FunctionBase> getGpuFunction() const { return gpuFunc_; } std::shared_ptr<FunctionBase> getFunction2() const { return function2_; }
protected: protected:
// only init cpu argument, gpu argument copy from cpu argument. // only init cpu argument, gpu argument copy from cpu argument.
void initArg(BufferArg& arg) { void initArg(BufferArg& arg) {
CpuVector vector(arg.shape().getElements(), (real*)arg.data()); Vector1 vector(arg.shape().getElements(), (real*)arg.data());
vector.uniform(0.001, 1); vector.uniform(0.001, 1);
} }
void initArg(SequenceArg& arg) { void initArg(SequenceArg& arg) {
/// init only matrix /// init only matrix
CpuVector vector(arg.shape().getElements(), (real*)arg.data()); Vector1 vector(arg.shape().getElements(), (real*)arg.data());
vector.uniform(0.001, 1); vector.uniform(0.001, 1);
} }
...@@ -276,73 +329,72 @@ protected: ...@@ -276,73 +329,72 @@ protected:
} }
void initInputs() { void initInputs() {
for (size_t i = 0; i < cpuInputs_.size(); i++) { for (size_t i = 0; i < func1Inputs_.size(); i++) {
if (cpuInputs_[i]->isSparseArg()) { if (func1Inputs_[i]->isSparseArg()) {
continue; /// sparse matrix already init continue; /// sparse matrix already init
} }
if (cpuInputs_[i]->isSequenceArg()) { if (func1Inputs_[i]->isSequenceArg()) {
initArg(dynamic_cast<SequenceArg&>(*cpuInputs_[i])); initArg(dynamic_cast<SequenceArg&>(*func1Inputs_[i]));
} else { } else {
initArg(*cpuInputs_[i]); initArg(*func1Inputs_[i]);
} }
// TODO: Need a BufferCopy used to copy from one BufferArg to another.
CpuVector cpuVector(cpuInputs_[i]->shape().getElements(),
(real*)cpuInputs_[i]->data());
GpuVector gpuVector(gpuInputs_[i]->shape().getElements(),
(real*)gpuInputs_[i]->data());
gpuVector.copyFrom(cpuVector); copyArg_(*func1Inputs_[i], *func2Inputs_[i]);
} }
} }
void initOutputs() { void initOutputs() {
for (size_t i = 0; i < cpuOutputs_.size(); i++) { for (size_t i = 0; i < func1Outputs_.size(); i++) {
if (cpuOutputs_[i]->isSparseArg()) { if (func1Outputs_[i]->isSparseArg()) {
continue; /// sparse matrix already init continue; /// sparse matrix already init
} }
if (cpuOutputs_[i]->isSequenceArg()) { if (func1Outputs_[i]->isSequenceArg()) {
initArg(dynamic_cast<SequenceArg&>(*cpuOutputs_[i])); initArg(dynamic_cast<SequenceArg&>(*func1Outputs_[i]));
} else { } else {
initArg(*cpuOutputs_[i]); initArg(*func1Outputs_[i]);
} }
// TODO: Need a BufferCopy used to copy from one BufferArg to another. copyArg_(*func1Outputs_[i], *func2Outputs_[i]);
CpuVector cpuVector(cpuOutputs_[i]->shape().getElements(),
(real*)cpuOutputs_[i]->data());
GpuVector gpuVector(gpuOutputs_[i]->shape().getElements(),
(real*)gpuOutputs_[i]->data());
gpuVector.copyFrom(cpuVector);
} }
} }
void compareOutputs() { void compareOutputs() {
for (size_t i = 0; i < cpuOutputs_.size(); i++) { for (size_t i = 0; i < func1Outputs_.size(); i++) {
// TODO, Need a BufferCheck used to compare the two buffers. // TODO, Need a BufferCheck used to compare the two buffers.
const auto cpu = cpuOutputs_[i]; const auto cpu = func1Outputs_[i];
const auto gpu = gpuOutputs_[i]; const auto gpu = func2Outputs_[i];
CHECK_EQ(cpu->numElements(), gpu->numElements()); CHECK_EQ(cpu->numElements(), gpu->numElements());
CpuVector cpuVector(cpu->numElements(), (real*)cpu->data()); Vector1 cpuVector(cpu->numElements(), (real*)cpu->data());
GpuVector gpuVector(gpu->numElements(), (real*)gpu->data()); Vector2 gpuVector(gpu->numElements(), (real*)gpu->data());
autotest::TensorCheckErr(cpuVector, gpuVector); autotest::TensorCheckErr(cpuVector, gpuVector);
} }
} }
protected: protected:
std::shared_ptr<FunctionBase> cpuFunc_; std::shared_ptr<FunctionBase> function1_;
std::shared_ptr<FunctionBase> gpuFunc_; std::shared_ptr<FunctionBase> function2_;
std::vector<CpuMemHandlePtr> cpuMemory_; std::vector<std::shared_ptr<Allocator1>> func1Memory_;
std::vector<GpuMemHandlePtr> gpuMemory_; std::vector<std::shared_ptr<Allocator2>> func2Memory_;
std::vector<BufferArgPtr> cpuInputs_; std::vector<BufferArgPtr> func1Inputs_;
std::vector<BufferArgPtr> cpuOutputs_; std::vector<BufferArgPtr> func1Outputs_;
std::vector<BufferArgPtr> gpuInputs_; std::vector<BufferArgPtr> func2Inputs_;
std::vector<BufferArgPtr> gpuOutputs_; std::vector<BufferArgPtr> func2Outputs_;
std::shared_ptr<CpuSparseMatrix> cpuSparse_; std::shared_ptr<SparseMatrix1> sparse1_;
std::shared_ptr<GpuSparseMatrix> gpuSparse_; std::shared_ptr<SparseMatrix2> sparse2_;
std::shared_ptr<SequenceIdArg> cpuSeq_; std::shared_ptr<SequenceIdArg> seq1_;
std::shared_ptr<SequenceIdArg> gpuSeq_; std::shared_ptr<SequenceIdArg> seq2_;
test::CopyArgument<DType1, DType2> copyArg_;
};
class CpuGpuFuncCompare
: public Compare2Function<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> {
public:
CpuGpuFuncCompare(const std::string& name, const FuncConfig& config)
: Compare2Function(name + "-CPU", name + "-GPU", config) {}
~CpuGpuFuncCompare() {}
}; };
} // namespace paddle } // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "GemmConvOp.h"
#include "GemmFunctor.h"
#include "paddle/math/MemoryHandle.h"
namespace paddle {
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template <class T>
class Im2ColFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* imData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData) {
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[(c * outputHeight + h) * outputWidth + w] = T(0);
} else {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
colData[(c * outputHeight + h) * outputWidth + w] =
imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
}
};
template <class T>
class Col2ImFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData) {
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 &&
(imColIdx - paddingWidth) < inputWidth) {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
imData[imRowIdx * inputWidth + imColIdx] +=
colData[(c * outputHeight + h) * outputWidth + w];
}
}
}
}
}
};
/*
* \brief Forward calculation of convolution.
*/
template <DeviceType Device>
class GemmConvFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
// TODO(hedaoyuan): Need to define some index macros,
// to avoid useing 0 and 1.
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
check(input, filter, output);
real beta;
if (outputs[0].getArgType() == ADD_TO) {
beta = 1.0;
} else {
beta = 0.0;
}
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
K,
colData,
N,
beta,
outputData + g * outputOffset,
N);
}
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
}
};
/*
* \brief Backward input calculation of convolution.
*/
template <DeviceType Device>
class GemmConvGradInputFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
// Since the implementation of Col2ImFunctor is ADD_TO,
// this function only supports ADD_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& output = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& input = outputs[0].shape();
check(input, filter, output);
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Col2ImFunctor<Device, real> col2im;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
int K = outputChannels / groups_;
int N = outputHeight * outputWidth;
int M = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
M,
outputGrad + g * outputOffset,
N,
0.0f,
colData,
N);
col2im(colData,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
inputGrad + g * inputOffset);
}
inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
}
}
};
/*
* \brief Backward filter calculation of convolution.
*/
template <DeviceType Device>
class GemmConvGradFilterFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& output = inputs[0].shape();
const TensorShape& input = inputs[1].shape();
const TensorShape& filter = outputs[0].shape();
check(input, filter, output);
real beta;
if (outputs[0].getArgType() == ADD_TO) {
beta = 1.0;
} else {
beta = 0.0;
}
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
int M = outputChannels / groups_;
int K = outputHeight * outputWidth;
int N = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans,
CblasTrans,
M,
N,
K,
1.0f,
outputGrad + g * outputOffset,
K,
colData,
K,
i == 0 ? beta : 1.0f,
filterGrad + g * filterOffset,
N);
}
inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
}
}
};
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction);
#endif
} // namespace paddle
...@@ -14,31 +14,49 @@ limitations under the License. */ ...@@ -14,31 +14,49 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include "ConvOp.h"
#include "ExpandConvBaseLayer.h"
#include "paddle/math/Matrix.h"
namespace paddle { namespace paddle {
/** /*
* @brief A subclass of convolution layer. * imData = [input_channels, input_height, input_width]
* This layer expands input and use matrix multiplication to * colData = [input_channels, filter_height, filter_width,
* calculate convolution transpose (deconv) operation. * output_height, output_width]
*
* The config file api is img_conv_layer with flag trans=True.
*/ */
class ExpandConvTransLayer : public ExpandConvBaseLayer { template <DeviceType Device, class T>
class Im2ColFunctor {
public: public:
explicit ExpandConvTransLayer(const LayerConfig& config) void operator()(const T* imData,
: ExpandConvBaseLayer(config) {} int inputChannels,
int inputHeight,
~ExpandConvTransLayer() {} int inputWidth,
int filterHeight,
bool init(const LayerMap& layerMap, int filterWidth,
const ParameterMap& parameterMap) override; int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData);
};
void forward(PassType passType) override; template <DeviceType Device, class T>
void backward(const UpdateCallback& callback) override; class Col2ImFunctor {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData);
}; };
} // namespace paddle } // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvOp.h"
#include "GemmConvOp.h"
namespace paddle {
template<class T>
__global__
void im2col(const T* data_im, int numOuts, int height, int width,
int blockH, int blockW,
int strideH, int strideW,
int paddingH, int paddingW,
int height_col, int width_col,
T* data_col) {
int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < numOuts) {
int w_out = index % width_col;
index /= width_col;
int h_out = index % height_col;
int channel_in = index / height_col;
int channel_out = channel_in * blockH * blockW;
int h_in = h_out * strideH;
int w_in = w_out * strideW;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (int i = 0; i < blockH; ++i) {
for (int j = 0; j < blockW; ++j) {
int rIdx = int(h_in+i);
int cIdx = int(w_in+j);
if ((rIdx-(int)paddingH) >= (int)height ||
(rIdx-(int)paddingH) < 0 ||
(cIdx-(int)paddingW) >= (int)width ||
(cIdx-(int)paddingW) < 0) {
*data_col = 0;
} else {
rIdx = rIdx + channel_in*height - paddingH;
cIdx = cIdx - paddingW;
*data_col = data_im[rIdx* width + cIdx];
}
data_col += height_col * width_col;
}
}
}
}
template <class T>
class Im2ColFunctor<DEVICE_TYPE_GPU, T> {
public:
void operator()(const T* imData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData) {
int numKernels = inputChannels * outputHeight * outputWidth;
int blocks = (numKernels + 1024 -1) / 1024;
int blockX = 512;
int blockY = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
im2col<T><<< grid, threads, 0, STREAM_DEFAULT >>>
(imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth,
strideHeight, strideWidth, paddingHeight, paddingWidth,
outputHeight, outputWidth, colData);
CHECK_SYNC("Im2ColFunctor GPU failed");
}
};
template<class T>
__global__
void col2im(size_t n, const T* data_col, size_t height,
size_t width, size_t channels,
size_t blockH, size_t blockW,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t height_col, size_t width_col,
T* data_im) {
size_t index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < n) {
T val = 0;
int w = int(index % width);
int h = int((index / width) % height);
int c = int(index / (width * height));
if ((w - (int)paddingW) >= 0 &&
(w - (int)paddingW) < (width-2 * paddingW) &&
(h - (int)paddingH) >= 0 &&
(h - paddingH) < (height - 2 * paddingH)) {
// compute the start and end of the output
int w_col_start =
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
int w_col_end =
min((int)(w / (int)strideW + 1), (int)(width_col));
int h_col_start =
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
int h_col_end = min(int(h / strideH + 1), int(height_col));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = int(c * blockH* blockW) + \
(h - h_col * (int)strideH) * (int)blockW +
(w - w_col * (int)strideW);
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
}
}
h -= paddingH;
w -= paddingW;
data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
h*(width-2*paddingW) + w] += val;
}
}
}
template <class T>
class Col2ImFunctor<DEVICE_TYPE_GPU, T> {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData) {
size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight)
* (inputWidth + 2*paddingWidth);
size_t blocks = (numKernels + 1024 -1) / 1024;
size_t blockX = 512;
size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im<T><<< grid, threads, 0, STREAM_DEFAULT >>>
(numKernels,
colData,
inputHeight + 2*paddingHeight,
inputWidth + 2*paddingWidth,
inputChannels,
filterHeight,
filterWidth,
strideHeight,
strideWidth,
paddingHeight,
paddingWidth,
outputHeight,
outputWidth,
imData);
CHECK_SYNC("Col2ImFunctor GPU failed");
}
};
template class Im2ColFunctor<DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<DEVICE_TYPE_GPU, double>;
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/math/MathFunctions.h"
namespace paddle {
// TODO(hedaoyuan): Since the hl_matrix_mul interface does not conform to the
// cblas_dgemm interface's parameter format, it is necessary to introduce
// GemmFunctor as a new interface. Later, when considering the implementation
// of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul
// interface.
template <DeviceType Device, class T>
class GemmFunctor {
public:
void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc);
};
template <class T>
class GemmFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc) {
gemm<T>(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
};
template <class T>
class GemmFunctor<DEVICE_TYPE_GPU, T> {
public:
void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc) {
hl_matrix_mul((T*)A,
transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
(T*)B,
TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
C,
M,
N,
K,
alpha,
beta,
lda,
ldb,
ldc);
}
};
} // namespace paddle
...@@ -35,7 +35,7 @@ void testFuncDDDMatrix( ...@@ -35,7 +35,7 @@ void testFuncDDDMatrix(
size_t heightC = dimM; size_t heightC = dimM;
size_t widthC = dimN; size_t widthC = dimN;
// init Test object // init Test object
FunctionCompare test( CpuGpuFuncCompare test(
"MulOp", FuncConfig().set("aTrans", transa).set("bTrans", transb)); "MulOp", FuncConfig().set("aTrans", transa).set("bTrans", transb));
// prepare input arguments // prepare input arguments
/// matrix A : HA * WA /// matrix A : HA * WA
...@@ -81,8 +81,8 @@ void testFuncDSparseDMatrix( ...@@ -81,8 +81,8 @@ void testFuncDSparseDMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real scaleT = 1.0; real scaleT = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", CpuGpuFuncCompare test(
FuncConfig().set("aTrans", false).set("bTrans", false)); "MulOp", FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments // prepare input arguments
/// sparse matrix A : M * K /// sparse matrix A : M * K
test.addInputs(SparseMatrixArg( test.addInputs(SparseMatrixArg(
...@@ -126,8 +126,8 @@ void testFuncDDSparseMatrix( ...@@ -126,8 +126,8 @@ void testFuncDDSparseMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real scaleT = 1.0; real scaleT = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", CpuGpuFuncCompare test(
FuncConfig().set("aTrans", false).set("bTrans", false)); "MulOp", FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments // prepare input arguments
/// matrix A : M * K /// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
...@@ -172,8 +172,8 @@ void testFuncSparseDDMatrix( ...@@ -172,8 +172,8 @@ void testFuncSparseDDMatrix(
size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) { size_t dimM, size_t dimN, size_t dimK, size_t nnz, SparseFormat FORMAT) {
real scaleT = 1.0; real scaleT = 1.0;
// init Test object // init Test object
FunctionCompare test("MulOp", CpuGpuFuncCompare test(
FuncConfig().set("aTrans", false).set("bTrans", false)); "MulOp", FuncConfig().set("aTrans", false).set("bTrans", false));
// prepare input arguments // prepare input arguments
/// matrix A : M * K /// matrix A : M * K
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK})); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{dimM, dimK}));
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvOp.h"
namespace paddle {
/*
* The three arguments are stored in memory in row major order.
* inputData = [batchSize, inputChannels, inputHeight, inputWidth]
* filterData = [outputChannels, inputChannels, filterHeight, filterWidth]
* outputData = [batchSize, outputChannels, outputHeight, outputWidth]
*/
template <class T>
class NaiveConvFunctor {
public:
void operator()(const T* inputData,
size_t batchSize,
size_t inputChannels,
size_t inputHeight,
size_t inputWidth,
const T* filterData,
size_t filterHeight,
size_t filterWidth,
T* outputData,
size_t outputChannels,
size_t outputHeight,
size_t outputWidth,
size_t paddingH,
size_t paddingW,
size_t strideH,
size_t strideW) {
for (size_t batch = 0; batch < batchSize; batch++) {
for (size_t outC = 0; outC < outputChannels; outC++) {
for (size_t outH = 0; outH < outputHeight; outH++) {
for (size_t outW = 0; outW < outputWidth; outW++) {
const int inStartH = (outH * strideH) - paddingH;
const int inStartW = (outW * strideW) - paddingW;
T outValue = (T)0;
for (size_t inC = 0; inC < inputChannels; inC++) {
for (size_t fH = 0; fH < filterHeight; fH++) {
for (size_t fW = 0; fW < filterWidth; fW++) {
T inValue;
const int inH = inStartH + fH;
const int inW = inStartW + fW;
if ((inH >= 0 && inH < inputHeight) &&
(inW >= 0 && inW < inputWidth)) {
size_t offsetInput =
batch * inputChannels * inputHeight * inputWidth +
inC * inputHeight * inputWidth + inH * inputWidth + inW;
inValue = inputData[offsetInput];
} else {
inValue = (T)0;
}
size_t offsetFilter =
outC * inputChannels * filterHeight * filterWidth +
inC * filterHeight * filterWidth + fH * filterWidth + fW;
T filterValue = filterData[offsetFilter];
outValue += (inValue * filterValue);
}
}
}
size_t offset =
batch * outputChannels * outputHeight * outputWidth +
outC * outputHeight * outputWidth + outH * outputWidth + outW;
outputData[offset] = outValue;
}
}
}
}
}
};
template <DeviceType Device>
class NaiveConvFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
check(input, filter, output);
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
size_t batchSize = inputs[0].shape()[0];
size_t inputChannels = inputs[0].shape()[1];
size_t inputHeight = inputs[0].shape()[2];
size_t inputWidth = inputs[0].shape()[3];
size_t filterHeight = inputs[1].shape()[2];
size_t filterWidth = inputs[1].shape()[3];
size_t outputChannels = outputs[0].shape()[1];
size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3];
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
NaiveConvFunctor<real> conv;
conv(inputData,
batchSize,
inputChannels,
inputHeight,
inputWidth,
filterData,
filterHeight,
filterWidth,
outputData,
outputChannels,
outputHeight,
outputWidth,
paddingH(),
paddingW(),
strideH(),
strideW());
}
};
REGISTER_TYPED_FUNC(NaiveConv, CPU, NaiveConvFunction);
} // namespace paddle
...@@ -25,7 +25,7 @@ TEST(Pad, real) { ...@@ -25,7 +25,7 @@ TEST(Pad, real) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
for (bool test_grad : {false, true}) { for (bool test_grad : {false, true}) {
FunctionCompare compare( CpuGpuFuncCompare compare(
test_grad ? "PadGrad" : "Pad", test_grad ? "PadGrad" : "Pad",
FuncConfig() FuncConfig()
.set<std::vector<uint32_t>>("channel", {2, 3}) .set<std::vector<uint32_t>>("channel", {2, 3})
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) { void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) {
FunctionCompare test("RowConv", FuncConfig()); CpuGpuFuncCompare test("RowConv", FuncConfig());
test.addSequence(SequenceIdArg(TensorShape{batchSize})); test.addSequence(SequenceIdArg(TensorShape{batchSize}));
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
...@@ -31,7 +31,7 @@ void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) { ...@@ -31,7 +31,7 @@ void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) {
} }
void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) { void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) {
FunctionCompare test("RowConvGrad", FuncConfig()); CpuGpuFuncCompare test("RowConvGrad", FuncConfig());
test.addSequence(SequenceIdArg(TensorShape{batchSize})); test.addSequence(SequenceIdArg(TensorShape{batchSize}));
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
......
...@@ -118,11 +118,7 @@ size_t ConvBaseLayer::calOutputSize() { ...@@ -118,11 +118,7 @@ size_t ConvBaseLayer::calOutputSize() {
layerSize = outH[0] * outW[0] * size_t(numFilters_); layerSize = outH[0] * outW[0] * size_t(numFilters_);
}; };
if (isDeconv_) {
setLayerSize(outputH_, outputW_, imgSizeH_, imgSizeW_);
} else {
setLayerSize(imgSizeH_, imgSizeW_, outputH_, outputW_); setLayerSize(imgSizeH_, imgSizeW_, outputH_, outputW_);
}
return layerSize; return layerSize;
} }
......
...@@ -70,14 +70,8 @@ void CudnnConvBaseLayer::forward(PassType passType) { ...@@ -70,14 +70,8 @@ void CudnnConvBaseLayer::forward(PassType passType) {
if (biases_) { if (biases_) {
REGISTER_TIMER_INFO("CudnnConvBiasTimer", getName().c_str()); REGISTER_TIMER_INFO("CudnnConvBiasTimer", getName().c_str());
int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int outH, outW; int outH = outputH_[0];
if (isDeconv_) { int outW = outputW_[0];
outH = imgSizeH_[0];
outW = imgSizeW_[0];
} else {
outH = outputH_[0];
outW = outputW_[0];
}
hl_tensor_reshape(outputDesc_, hl_tensor_reshape(outputDesc_,
batchSize, batchSize,
......
...@@ -22,26 +22,8 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap, ...@@ -22,26 +22,8 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
/* Initialize the basic convolutional parent class */ /* Initialize the basic convolutional parent class */
ConvBaseLayer::init(layerMap, parameterMap); ConvBaseLayer::init(layerMap, parameterMap);
/* The class fields channels_ and numFilters_ are the same as in the config
* i.e., channels_ is the for the input and numFilters_ is for the output
*
* But in order for the variables in convTrans having the same semantic
* meaning as in conv, we need to swap channels_ and numFilters here for
* convTrans, and in other functions too.
* */
/* Initialize the projection */
for (auto &inputConfig : config_.inputs()) { for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf(); const ConvConfig &conf = inputConfig.conv_conf();
int numFilters = isDeconv_ ? conf.channels() : numFilters_;
subM_.push_back(numFilters / conf.groups());
subN_.push_back(conf.output_x() *
(conf.has_output_y() ? conf.output_y() : conf.output_x()));
int channel = isDeconv_ ? numFilters_ : conf.channels();
subK_.push_back(
channel * conf.filter_size() *
(conf.has_filter_size_y() ? conf.filter_size_y() : conf.filter_size()) /
conf.groups());
/* Consistent caffe mode for multiple input */ /* Consistent caffe mode for multiple input */
caffeMode_ = conf.caffe_mode(); caffeMode_ = conf.caffe_mode();
} }
...@@ -54,17 +36,9 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap, ...@@ -54,17 +36,9 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
size_t ExpandConvBaseLayer::getOutputSize() { size_t ExpandConvBaseLayer::getOutputSize() {
CHECK_NE(inputLayers_.size(), 0UL); CHECK_NE(inputLayers_.size(), 0UL);
size_t layerSize = ConvBaseLayer::calOutputSize(); size_t layerSize = ConvBaseLayer::calOutputSize();
subN_.clear();
for (size_t i = 0; i < inputLayers_.size(); i++) {
subN_.push_back(outputH_[i] * outputW_[i]);
}
return layerSize; return layerSize;
} }
void ExpandConvBaseLayer::resetExpandInput(size_t height, size_t width) {
Matrix::resizeOrCreate(expandInput_, height, width, false, useGpu_);
}
void ExpandConvBaseLayer::addSharedBias() { void ExpandConvBaseLayer::addSharedBias() {
size_t mapW = getOutputSize() / numFilters_; size_t mapW = getOutputSize() / numFilters_;
size_t mapH = getOutputValue()->getElementCnt() / mapW; size_t mapH = getOutputValue()->getElementCnt() / mapW;
...@@ -101,173 +75,6 @@ void ExpandConvBaseLayer::addUnsharedBias() { ...@@ -101,173 +75,6 @@ void ExpandConvBaseLayer::addUnsharedBias() {
outValue->addBias(*bias, 1.0f); outValue->addBias(*bias, 1.0f);
} }
void ExpandConvBaseLayer::expandOneFrame(MatrixPtr image,
size_t startIdx,
int inIdx) {
int channel = isDeconv_ ? numFilters_ : channels_[inIdx];
resetExpandInput(subK_[inIdx] * groups_[inIdx], subN_[inIdx]);
CHECK_EQ(image->getWidth(),
static_cast<size_t>(imgSizeH_[inIdx] * imgSizeW_[inIdx] * channel));
real *imgData = image->getData() + startIdx * image->getWidth();
MatrixPtr imageTmp =
Matrix::create(imgData,
1,
imgSizeH_[inIdx] * imgSizeW_[inIdx] * channel,
false,
useGpu_);
expandInput_->convExpand(*imageTmp,
imgSizeH_[inIdx],
imgSizeW_[inIdx],
channel,
filterSizeY_[inIdx],
filterSize_[inIdx],
strideY_[inIdx],
stride_[inIdx],
paddingY_[inIdx],
padding_[inIdx],
outputH_[inIdx],
outputW_[inIdx]);
imageTmp->clear();
}
void ExpandConvBaseLayer::expandFwdOnce(MatrixPtr image,
MatrixPtr out,
int inIdx,
int startIdx) {
int subM = subM_[inIdx];
int subN = subN_[inIdx];
int subK = subK_[inIdx];
expandOneFrame(image, startIdx, inIdx);
int numFilters = isDeconv_ ? channels_[inIdx] : numFilters_;
real *outData = out->getData() + startIdx * subN * numFilters;
real *wgtData = weights_[inIdx]->getW()->getData();
real *expInData = expandInput_->getData();
for (int g = 0; g < groups_[inIdx]; ++g) {
MatrixPtr A =
Matrix::create(wgtData, subM, subK, false, useGpu_); // mark transpose
MatrixPtr B = Matrix::create(expInData, subK, subN, false, useGpu_);
MatrixPtr C = Matrix::create(outData, subM, subN, false, useGpu_);
C->mul(*A, *B, 1, 1);
A->clear();
B->clear();
C->clear();
wgtData += subK * subM;
expInData += subK * subN;
outData += subM * subN;
}
}
void ExpandConvBaseLayer::bpropActs(MatrixPtr out,
MatrixPtr image,
int inpIdx) {
int channel = isDeconv_ ? numFilters_ : channels_[inpIdx];
int subM = subM_[inpIdx];
int subN = subN_[inpIdx];
int subK = subK_[inpIdx];
size_t batchSize = image->getHeight();
/* reset the expand-grad memory */
resetExpandInput(subK * groups_[inpIdx], subN);
real *localGradData = out->getData();
real *tgtGradData = image->getData();
for (size_t n = 0; n < batchSize; n++) {
real *wgtData = weights_[inpIdx]->getW()->getData();
real *expandInData = expandInput_->getData();
for (int g = 0; g < groups_[inpIdx]; g++) {
// create temporary matrix
MatrixPtr C = Matrix::create(expandInData, subK, subN, false, useGpu_);
MatrixPtr B = Matrix::create(localGradData, subM, subN, false, useGpu_);
MatrixPtr A = Matrix::create(wgtData, subM, subK, true, useGpu_);
C->mul(*A, *B); // mul
// clear the temporary matrix
A->clear();
B->clear();
C->clear();
expandInData += subK * subN;
localGradData += subM * subN;
wgtData += subK * subM;
}
// shrink one frame outGrad
MatrixPtr oneGradTmp = Matrix::create(
expandInput_->getData(), subK * groups_[inpIdx], subN, false, useGpu_);
MatrixPtr vTmp =
Matrix::create(tgtGradData,
1,
imgSizeH_[inpIdx] * imgSizeW_[inpIdx] * channel,
false,
useGpu_);
vTmp->convShrink(*oneGradTmp,
imgSizeH_[inpIdx],
imgSizeW_[inpIdx],
channel,
filterSizeY_[inpIdx],
filterSize_[inpIdx],
strideY_[inpIdx],
stride_[inpIdx],
paddingY_[inpIdx],
padding_[inpIdx],
outputH_[inpIdx],
outputW_[inpIdx],
1.0f,
1.0f);
vTmp->clear();
oneGradTmp->clear();
// move the data-pointer
tgtGradData += imgSizeH_[inpIdx] * imgSizeW_[inpIdx] * channel;
}
}
void ExpandConvBaseLayer::bpropWeights(MatrixPtr image,
MatrixPtr out,
int inpIdx) {
MatrixPtr weightGrad = weights_[inpIdx]->getWGrad();
int subM = subM_[inpIdx];
int subN = subN_[inpIdx];
int subK = subK_[inpIdx];
size_t batchSize = image->getHeight();
resetExpandInput(subK * groups_[inpIdx], subN);
real *gradData = out->getData();
for (size_t n = 0; n < batchSize; n++) { // frame by frame
// expand
expandOneFrame(image, n, inpIdx);
real *wGradData = weightGrad->getData();
real *expandInData = expandInput_->getData();
// expand-mul one-group by one
for (int g = 0; g < groups_[inpIdx]; g++) {
MatrixPtr A = Matrix::create(expandInData, subK, subN, true, useGpu_);
MatrixPtr B = Matrix::create(gradData, subM, subN, false, useGpu_);
MatrixPtr C = Matrix::create(wGradData, subM, subK, false, useGpu_);
C->mul(*B, *A, 1, 1);
A->clear();
B->clear();
C->clear();
gradData += subM * subN;
wGradData += subK * subM;
expandInData += subK * subN;
}
}
}
void ExpandConvBaseLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) { void ExpandConvBaseLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
size_t mapW = getOutputSize() / numFilters_; size_t mapW = getOutputSize() / numFilters_;
size_t mapH = v->getElementCnt() / mapW; size_t mapH = v->getElementCnt() / mapW;
......
...@@ -26,19 +26,6 @@ namespace paddle { ...@@ -26,19 +26,6 @@ namespace paddle {
*/ */
class ExpandConvBaseLayer : public ConvBaseLayer { class ExpandConvBaseLayer : public ConvBaseLayer {
protected: protected:
/// For expand convolution.
/// subM_ = numFilters_ / groups_.
IntV subM_;
/// subN_ = outputH_ * outputW_.
IntV subN_;
/// subK_ = channels_ * filterPixels_ * groups_.
IntV subK_;
/*The expandInput_ and transOutValue_ are used for CPU expand conv calc
* Expand one sample at a time. shape:
* (numChannels * filterPixels_, outputSizeH * outputSizeW)
* */
MatrixPtr expandInput_;
/// The transpose of output, which is an auxiliary matrix. /// The transpose of output, which is an auxiliary matrix.
MatrixPtr transOutValue_; MatrixPtr transOutValue_;
...@@ -52,10 +39,6 @@ public: ...@@ -52,10 +39,6 @@ public:
const ParameterMap& parameterMap) override; const ParameterMap& parameterMap) override;
size_t getOutputSize(); size_t getOutputSize();
/**
* Create or resize expandInput_.
*/
void resetExpandInput(size_t height, size_t width);
/** /**
* Add shared bias. * Add shared bias.
...@@ -66,20 +49,9 @@ public: ...@@ -66,20 +49,9 @@ public:
* Add unshared bias. * Add unshared bias.
*/ */
void addUnsharedBias(); void addUnsharedBias();
/**
* Expand one input sample.
*/
void expandOneFrame(MatrixPtr image, size_t startIdx, int inIdx);
/**
* Expand one input sample and perform matrix multiplication.
*/
void expandFwdOnce(MatrixPtr image, MatrixPtr out, int inIdx, int startIdx);
void bpropSharedBias(MatrixPtr biases, MatrixPtr v); void bpropSharedBias(MatrixPtr biases, MatrixPtr v);
void bpropBiases(MatrixPtr v); void bpropBiases(MatrixPtr v);
void bpropWeights(MatrixPtr image, MatrixPtr out, int inpIdx);
void bpropActs(MatrixPtr image, MatrixPtr out, int inpIdx);
}; };
} // namespace paddle } // namespace paddle
...@@ -18,32 +18,94 @@ limitations under the License. */ ...@@ -18,32 +18,94 @@ limitations under the License. */
namespace paddle { namespace paddle {
/*
* The calculation of the exconvt(convolution transpose (deconv) operation)
* is a swap of forward and backward of the calculation of exconv.
* */
REGISTER_LAYER(exconv, ExpandConvLayer); REGISTER_LAYER(exconv, ExpandConvLayer);
REGISTER_LAYER(exconvt, ExpandConvLayer);
bool ExpandConvLayer::init(const LayerMap &layerMap, bool ExpandConvLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) { const ParameterMap &parameterMap) {
/* Initialize the basic convolutional parent class */ /* Initialize the basic convolutional parent class */
ExpandConvBaseLayer::init(layerMap, parameterMap); ExpandConvBaseLayer::init(layerMap, parameterMap);
size_t numInputs = config_.inputs_size();
inputShape_.resize(numInputs);
filterShape_.resize(numInputs);
outputShape_.resize(numInputs);
for (int i = 0; i < config_.inputs_size(); i++) {
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
createFunction(forward_,
!isDeconv_ ? "GemmConv" : "GemmConvGradInput",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
!isDeconv_ ? "GemmConvGradInput" : "GemmConv",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
"GemmConvGradFilter",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
}
return true; return true;
} }
// i is the index of input layers
#define BACKWARD_INPUT(i, inputs, outputs) \
backward_[2 * i]->calc(inputs, outputs)
#define BACKWARD_FILTER(i, inputs, outputs) \
backward_[2 * i + 1]->calc(inputs, outputs)
void ExpandConvLayer::forward(PassType passType) { void ExpandConvLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
/* malloc memory for the output_ if necessary */ size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
resetOutput(batchSize, getOutputSize()); resetOutput(batchSize, getOutputSize());
MatrixPtr image = nullptr; // Calculate the shape of the input, output, and filter.
MatrixPtr outV = getOutputValue();
for (size_t i = 0; i < inputLayers_.size(); ++i) { for (size_t i = 0; i < inputLayers_.size(); ++i) {
LayerPtr prevLayer = getPrev(i); inputShape_[i] = TensorShape({(size_t)batchSize,
image = prevLayer->getOutputValue(); (size_t)channels_[i],
for (size_t off = 0; off < image->getHeight(); off++) { (size_t)imgSizeH_[i],
REGISTER_TIMER_INFO("expandFwdOnce", getName().c_str()); (size_t)imgSizeW_[i]});
expandFwdOnce(image, outV, i, off); filterShape_[i] =
TensorShape({(size_t)groups_[i],
!isDeconv_ ? (size_t)numFilters_ / groups_[i]
: (size_t)channels_[i] / groups_[i],
!isDeconv_ ? (size_t)channels_[i] / groups_[i]
: (size_t)numFilters_ / groups_[i],
(size_t)filterSizeY_[i],
(size_t)filterSize_[i]});
outputShape_[i] = TensorShape({(size_t)batchSize,
(size_t)numFilters_,
(size_t)outputH_[i],
(size_t)outputW_[i]});
} }
// Calculate the output value.
for (size_t i = 0; i < inputLayers_.size(); ++i) {
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(i), inputShape_[i]);
inputs.addArg(*weights_[i]->getW(), filterShape_[i]);
outputs.addArg(*getOutputValue(),
outputShape_[i],
!isDeconv_ && i == 0 ? ASSIGN_TO : ADD_TO);
forward_[i]->calc(inputs, outputs);
} }
/* add the bias-vector */ /* add the bias-vector */
if (biases_.get()) { if (biases_.get()) {
if (sharedBiases_) { if (sharedBiases_) {
...@@ -67,14 +129,30 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) { ...@@ -67,14 +129,30 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) {
biases_->getParameterPtr()->incUpdate(callback); biases_->getParameterPtr()->incUpdate(callback);
} }
// Calculate the input grad and filter grad.
for (size_t i = 0; i < inputLayers_.size(); ++i) { for (size_t i = 0; i < inputLayers_.size(); ++i) {
/* First, calculate the input layers error */ if (getInputGrad(i)) {
if (getPrev(i)->getOutputGrad()) { BufferArgs inputs;
bpropActs(outGrad, getPrev(i)->getOutputGrad(), i); BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outputShape_[i]);
inputs.addArg(*weights_[i]->getW(), filterShape_[i]);
outputs.addArg(*getInputGrad(i), inputShape_[i], ADD_TO);
BACKWARD_INPUT(i, inputs, outputs);
} }
if (weights_[i]->getWGrad()) { if (weights_[i]->getWGrad()) {
/* Then, calculate the W-gradient for the current layer */ BufferArgs inputs;
bpropWeights(getPrev(i)->getOutputValue(), outGrad, i); BufferArgs outputs;
if (!isDeconv_) {
inputs.addArg(*getOutputGrad(), outputShape_[i]);
inputs.addArg(*getInputValue(i), inputShape_[i]);
} else {
inputs.addArg(*getInputValue(i), inputShape_[i]);
inputs.addArg(*getOutputGrad(), outputShape_[i]);
}
outputs.addArg(*weights_[i]->getWGrad(), filterShape_[i], ADD_TO);
BACKWARD_FILTER(i, inputs, outputs);
/* Increasing the number of gradient */ /* Increasing the number of gradient */
weights_[i]->getParameterPtr()->incUpdate(callback); weights_[i]->getParameterPtr()->incUpdate(callback);
} }
......
...@@ -40,6 +40,11 @@ public: ...@@ -40,6 +40,11 @@ public:
void forward(PassType passType) override; void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override; void backward(const UpdateCallback& callback) override;
protected:
std::vector<TensorShape> inputShape_;
std::vector<TensorShape> filterShape_;
std::vector<TensorShape> outputShape_;
}; };
} // namespace paddle } // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ExpandConvTransLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
/* The implementation of the convTransLayer is basically a swap of forward and
* backward of the original convLayer.
* The variable naming follows the convention of the convLayer.
* */
namespace paddle {
REGISTER_LAYER(exconvt, ExpandConvTransLayer);
bool ExpandConvTransLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
/* Initialize the basic convolutional parent class */
ExpandConvBaseLayer::init(layerMap, parameterMap);
return true;
}
void ExpandConvTransLayer::forward(PassType passType) {
Layer::forward(passType);
/* malloc memory for the output_ if necessary */
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
resetOutput(batchSize, getOutputSize());
MatrixPtr output = nullptr;
for (size_t i = 0; i < inputLayers_.size(); ++i) {
LayerPtr prevLayer = getPrev(i);
output = prevLayer->getOutputValue();
REGISTER_TIMER_INFO("shrinkFwd", getName().c_str());
bpropActs(output, getOutputValue(), i);
}
/* add the bias-vector */
if (biases_.get()) {
if (sharedBiases_) {
addSharedBias();
} else {
addUnsharedBias();
}
}
/* activation */
forwardActivation();
}
void ExpandConvTransLayer::backward(const UpdateCallback &callback) {
backwardActivation();
MatrixPtr imageGrad = getOutputGrad();
if (biases_ && biases_->getWGrad()) {
bpropBiases(imageGrad);
/* Increasing the number of gradient */
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i < inputLayers_.size(); ++i) {
/* First, calculate the input layers error */
for (size_t off = 0; off < imageGrad->getHeight(); off++) {
if (getPrev(i)->getOutputGrad()) {
expandFwdOnce(imageGrad, getPrev(i)->getOutputGrad(), i, off);
}
}
if (weights_[i]->getWGrad()) {
/* Then, calculate the W-gradient for the current layer */
bpropWeights(imageGrad, getPrev(i)->getOutputValue(), i);
/* Increasing the number of gradient */
weights_[i]->getParameterPtr()->incUpdate(callback);
}
}
}
} // namespace paddle
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "ModelConfig.pb.h" #include "ModelConfig.pb.h"
#include "paddle/gserver/layers/DataLayer.h" #include "paddle/gserver/layers/DataLayer.h"
#include "paddle/gserver/layers/ExpandConvTransLayer.h"
#include "paddle/trainer/Trainer.h" #include "paddle/trainer/Trainer.h"
#include "paddle/utils/GlobalConstants.h" #include "paddle/utils/GlobalConstants.h"
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "ModelConfig.pb.h" #include "ModelConfig.pb.h"
#include "paddle/gserver/layers/DataLayer.h" #include "paddle/gserver/layers/DataLayer.h"
#include "paddle/gserver/layers/ExpandConvTransLayer.h"
#include "paddle/math/MathUtils.h" #include "paddle/math/MathUtils.h"
#include "paddle/trainer/Trainer.h" #include "paddle/trainer/Trainer.h"
#include "paddle/utils/GlobalConstants.h" #include "paddle/utils/GlobalConstants.h"
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "ModelConfig.pb.h" #include "ModelConfig.pb.h"
#include "paddle/gserver/layers/DataLayer.h" #include "paddle/gserver/layers/DataLayer.h"
#include "paddle/gserver/layers/ExpandConvTransLayer.h"
#include "paddle/math/MathUtils.h" #include "paddle/math/MathUtils.h"
#include "paddle/trainer/Trainer.h" #include "paddle/trainer/Trainer.h"
#include "paddle/utils/GlobalConstants.h" #include "paddle/utils/GlobalConstants.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册