提交 81016bb0 编写于 作者: S Superjom

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnn_varilen_design

......@@ -17,73 +17,6 @@ limitations under the License. */
#include "hl_base.h"
/**
* @brief Shrink column to feature.
*
* @param[in] dataCol expand data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] blockH filter height.
* @param[in] blockW filter width.
* @param[in] strideH stride height.
* @param[in] strideW stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[in] outputH output height.
* @param[in] outputW output width.
* @param[out] dataIm output image data.
* @param[in] alpha
* @param[in] beta
*/
extern void hl_shrink_col2feature(const real* dataCol,
size_t channels,
size_t height,
size_t width,
size_t blockH,
size_t blockW,
size_t strideH,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t outputH,
size_t outputW,
real* dataIm,
real alpha = 1.0f,
real beta = 0.0f);
/**
* @brief Expand feature to column.
*
* @param[in] dataIm input image data.
* @param[in] channels number of channel.
* @param[in] height image height.
* @param[in] width image width.
* @param[in] blockH filter height.
* @param[in] blockW filter width.
* @param[in] strideH stride height.
* @param[in] strideW stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[in] outputH output height.
* @param[in] outputW output width.
* @param[out] dataCol expand data.
*
*/
extern void hl_expand_feature2col(const real* dataIm,
size_t channels,
size_t height,
size_t width,
size_t blockH,
size_t blockW,
size_t strideH,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t outputH,
size_t outputW,
real* dataCol);
/**
* @brief Maximum pool forward.
*
......
......@@ -17,36 +17,6 @@ limitations under the License. */
#include "hl_cnn.h"
inline void hl_shrink_col2feature(const real* dataCol,
size_t channels,
size_t height,
size_t width,
size_t blockH,
size_t blockW,
size_t strideH,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t outputH,
size_t outputW,
real* dataIm,
real alpha,
real beta) {}
inline void hl_expand_feature2col(const real* dataIm,
size_t channels,
size_t height,
size_t width,
size_t blockH,
size_t blockW,
size_t strideH,
size_t strideW,
size_t paddingH,
size_t paddingW,
size_t outputH,
size_t outputW,
real* dataCol) {}
inline void hl_maxpool_forward(const int frameCnt,
const real* inputData,
const int channels,
......
......@@ -18,134 +18,6 @@ limitations under the License. */
#include "hl_cnn.h"
#include "hl_device_functions.cuh"
__global__ void KeFeature2col(size_t n, size_t height, const real* data_im,
size_t blockH, size_t blockW, size_t width,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t height_col, size_t width_col,
real* data_col) {
size_t index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < n) {
size_t w_out = index % width_col;
index /= width_col;
size_t h_out = index % height_col;
size_t channel_in = index / height_col;
size_t channel_out = channel_in * blockH * blockW;
size_t h_in = h_out * strideH;
size_t w_in = w_out * strideW;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (size_t i = 0; i < blockH; ++i) {
for (size_t j = 0; j < blockW; ++j) {
int rIdx = int(h_in+i);
int cIdx = int(w_in+j);
if ((rIdx-(int)paddingH) >= (int)height ||
(rIdx-(int)paddingH) < 0 ||
(cIdx-(int)paddingW) >= (int)width ||
(cIdx-(int)paddingW) < 0) {
*data_col = 0;
} else {
rIdx = rIdx + channel_in*height - paddingH;
cIdx = cIdx - paddingW;
*data_col = data_im[rIdx* width + cIdx];
}
data_col += height_col * width_col;
}
}
}
}
void hl_expand_feature2col(const real* dataIm, size_t channels,
size_t height, size_t width,
size_t blockH, size_t blockW,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t outputH, size_t outputW,
real* dataCol) {
size_t numKernels = channels * outputH * outputW;
size_t blocks = (numKernels + 1024 -1) / 1024;
size_t blockX = 512;
size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
KeFeature2col<<< grid, threads, 0, STREAM_DEFAULT >>>
(numKernels, height, dataIm, blockH, blockW, width,
strideH, strideW, paddingH, paddingW,
outputH, outputW, dataCol);
CHECK_SYNC("hl_expand_feature2col failed");
}
__global__ void KeCol2Feature(size_t n, const real* data_col, size_t height,
size_t width, size_t channels,
size_t blockH, size_t blockW,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t height_col, size_t width_col,
real* data_im, real alpha, real beta) {
size_t index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < n) {
real val = 0;
int w = int(index % width);
int h = int((index / width) % height);
int c = int(index / (width * height));
if ((w - (int)paddingW) >= 0 &&
(w - (int)paddingW) < (width-2 * paddingW) &&
(h - (int)paddingH) >= 0 &&
(h - paddingH) < (height - 2 * paddingH)) {
// compute the start and end of the output
int w_col_start =
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
int w_col_end =
min((int)(w / (int)strideW + 1), (int)(width_col));
int h_col_start =
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
int h_col_end = min(int(h / strideH + 1), int(height_col));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = int(c * blockH* blockW) + \
(h - h_col * (int)strideH) * (int)blockW +
(w - w_col * (int)strideW);
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
}
}
h -= paddingH;
w -= paddingW;
real tD = data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
h*(width-2*paddingW) + w];
data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
h*(width-2*paddingW) + w] = alpha * val + beta*tD;
}
}
}
void hl_shrink_col2feature(const real * dataCol, size_t channels,
size_t height, size_t width,
size_t blockH, size_t blockW,
size_t strideH, size_t strideW,
size_t paddingH, size_t paddingW,
size_t outputH, size_t outputW,
real* dataIm, real alpha, real beta) {
size_t numKernels = channels * (height + 2*paddingH) * (width + 2*paddingW);
size_t blocks = (numKernels + 1024 -1) / 1024;
size_t blockX = 512;
size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
KeCol2Feature<<< grid, threads, 0, STREAM_DEFAULT >>>
(numKernels, dataCol, height + 2*paddingH, width + 2*paddingW,
channels, blockH, blockW, strideH, strideW, paddingH, paddingW,
outputH, outputW, dataIm, alpha, beta);
CHECK_SYNC("hl_shrink_col2feature failed");
}
__global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
const int channels, const int height,
const int width,
......
......@@ -56,7 +56,9 @@ class Scope {
if (var) {
return var;
} else {
vars_[name] = std::unique_ptr<Variable>(new Variable());
auto ptr = new Variable();
name_to_var_[name] = std::unique_ptr<Variable>(ptr);
var_to_name_[ptr] = name;
return GetVariable(name);
}
}
......@@ -68,8 +70,8 @@ class Scope {
* from it's parent scope. Return nullptr if not found.
*/
Variable* GetVariable(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) {
auto it = name_to_var_.find(name);
if (it != name_to_var_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
......@@ -84,12 +86,21 @@ class Scope {
* Find if there is a Variable in this scope and it's parent scope
*/
bool HasVariable(const std::string& name) const {
return (vars_.find(name) != vars_.end() ||
return (name_to_var_.find(name) != name_to_var_.end() ||
(parent_ && parent_->HasVariable(name)));
}
std::string GetVariableName(Variable* const var) const {
try {
return var_to_name_.at(var);
} catch (...) {
return "";
}
}
private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
std::unordered_map<Variable*, std::string> var_to_name_;
std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
std::shared_ptr<Scope> parent_{nullptr};
};
......
......@@ -40,6 +40,11 @@ TEST(Scope, Create) {
/// already exist.
Variable* var4 = scope->CreateVariable("a");
EXPECT_EQ(var4, var2);
EXPECT_EQ("a", scope->GetVariableName(var4));
Scope scope2;
auto var = scope2.CreateVariable("tmp");
EXPECT_EQ("", scope->GetVariableName(var));
}
TEST(Scope, Parent) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Function.h"
#include "Im2Col.h"
namespace paddle {
/*
* \brief Converts the image data of four dimensions(NCHW) into
* a sequence data of three dimensions(NST) in the forward calculation,
* which is reversed in the backward calculation.
* Where N is batch size, S is the length of the sequence after each
* image is expanded, T is the size of each time step in the sequence.
*
* Arguments in forward function:
* \param inputs[0] Image data of NCHW format.
* \param outputs[0] Sequence data of NST format.
*
* Arguments in backward function:
* \param inputs[0] Sequence data of NST format.
* \param outputs[0] Image data of NCHW format.
*/
class BlockExpandFunction : public FunctionBase {
public:
void init(const FuncConfig& config) override {
// function arguments
strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings");
blocks_ = config.get<std::vector<size_t>>("blocks");
// number of inputs and outputs
numInputs_ = 1;
numOutputs_ = 1;
}
void checkShape(const TensorShape& image, const TensorShape& sequence) const {
// image shape should be 4-dimensional.
CHECK_EQ(image.ndims(), (size_t)4);
// sequence shape should be 3-dimensional.
CHECK_EQ(sequence.ndims(), (size_t)3);
// The batchSize of the image needs to be equal to
// the batchSize of the sequence.
CHECK_EQ(image[0], sequence[0]);
}
// Calculate the shape of colData based on the shape of the image
// and the shape of the sequence.
TensorShape getColShape(const TensorShape& image,
const TensorShape& sequence) const {
size_t inputChannels = image[1];
size_t inputHeight = image[2];
size_t inputWidth = image[3];
size_t seqLength = sequence[1];
size_t stepSize = sequence[2];
size_t outputHeight =
1 +
(inputHeight + 2 * paddingH() - blockH() + strideH() - 1) / strideH();
size_t outputWidth =
1 +
(inputWidth + 2 * paddingW() - blockW() + strideW() - 1) / strideW();
CHECK_EQ(seqLength, outputHeight * outputWidth);
CHECK_EQ(stepSize, inputChannels * blockH() * blockW());
// [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
return TensorShape({outputHeight,
outputWidth,
inputChannels,
(size_t)blockH(),
(size_t)blockW()});
}
protected:
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
std::vector<size_t> blocks_;
inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; }
inline int paddingH() const { return paddings_[0]; }
inline int paddingW() const { return paddings_[1]; }
inline int blockH() const { return blocks_[0]; }
inline int blockW() const { return blocks_[1]; }
};
template <DeviceType Device>
class BlockExpandForward : public BlockExpandFunction {
public:
void init(const FuncConfig& config) override {
BlockExpandFunction::init(config);
}
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& image = inputs[0].shape();
const TensorShape& sequence = outputs[0].shape();
checkShape(image, sequence);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
const TensorShape& image = inputs[0].shape();
const TensorShape& sequence = outputs[0].shape();
TensorShape imShape = TensorShape({image[1], image[2], image[3]});
TensorShape colShape = getColShape(image, sequence);
size_t batchSize = image[0];
real* imageData = inputs[0].data<real>();
real* seqData = outputs[0].data<real>();
Im2ColFunctor<kOCF, Device, real> im2col;
for (size_t i = 0; i < batchSize; i++) {
// The result of im2col is [outputHeight, outputWidth,
// inputChannels, filterHeight, filterWidth], and it is easy to
// reshape into [seqLength, stepSize], where seqLength is equal
// output_height * output_width, stepSize is equal
// input_channels * filter_height * filter_width
im2col(imageData,
imShape,
seqData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
imageData += imShape.getElements();
seqData += colShape.getElements();
}
}
};
template <DeviceType Device>
class BlockExpandBackward : public BlockExpandFunction {
public:
void init(const FuncConfig& config) override {
BlockExpandFunction::init(config);
}
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& image = outputs[0].shape();
const TensorShape& sequence = inputs[0].shape();
checkShape(image, sequence);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
// Since the implementation of Col2ImFunctor is ADD_TO,
// this function only supports ADD_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& image = outputs[0].shape();
const TensorShape& sequence = inputs[0].shape();
TensorShape imShape = TensorShape({image[1], image[2], image[3]});
TensorShape colShape = getColShape(image, sequence);
size_t batchSize = image[0];
real* imageData = outputs[0].data<real>();
real* seqData = inputs[0].data<real>();
Col2ImFunctor<kOCF, Device, real> col2im;
for (size_t i = 0; i < batchSize; i++) {
col2im(imageData,
imShape,
seqData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
imageData += imShape.getElements();
seqData += colShape.getElements();
}
}
};
REGISTER_TYPED_FUNC(BlockExpand, CPU, BlockExpandForward);
REGISTER_TYPED_FUNC(BlockExpandGrad, CPU, BlockExpandBackward);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(BlockExpand, GPU, BlockExpandForward);
REGISTER_TYPED_FUNC(BlockExpandGrad, GPU, BlockExpandBackward);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace paddle {
TEST(BlockExpandForward, real) {
for (size_t batchSize : {5, 32}) {
for (size_t channels : {1, 5, 32}) {
for (size_t inputHeight : {5, 33, 100}) {
for (size_t inputWidth : {5, 32, 96}) {
for (size_t block : {1, 3, 5}) {
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
// init Test object
std::vector<size_t> strides = {stride, stride};
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> blocks = {block, block};
CpuGpuFuncCompare test("BlockExpand",
FuncConfig()
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
size_t outputHeight =
1 +
(inputHeight + 2 * padding - block + stride - 1) / stride;
size_t outputWidth =
1 +
(inputWidth + 2 * padding - block + stride - 1) / stride;
TensorShape inputShape =
TensorShape({batchSize, channels, inputHeight, inputWidth});
TensorShape outputShape =
TensorShape({batchSize,
outputHeight * outputWidth,
channels * block * block});
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, inputShape));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, outputShape));
// run Function
test.run();
}
}
}
}
}
}
}
}
TEST(BlockExpandBackward, real) {
for (size_t batchSize : {5, 32}) {
for (size_t channels : {1, 5, 32}) {
for (size_t inputHeight : {5, 33, 100}) {
for (size_t inputWidth : {5, 32, 96}) {
for (size_t block : {1, 3, 5}) {
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
// init Test object
std::vector<size_t> strides = {stride, stride};
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> blocks = {block, block};
CpuGpuFuncCompare test("BlockExpandGrad",
FuncConfig()
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
size_t outputHeight =
1 +
(inputHeight + 2 * padding - block + stride - 1) / stride;
size_t outputWidth =
1 +
(inputWidth + 2 * padding - block + stride - 1) / stride;
TensorShape inputShape =
TensorShape({batchSize, channels, inputHeight, inputWidth});
TensorShape outputShape =
TensorShape({batchSize,
outputHeight * outputWidth,
channels * block * block});
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, outputShape));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, inputShape),
ADD_TO);
// run Function
test.run();
}
}
}
}
}
}
}
}
} // namespace paddle
......@@ -36,10 +36,12 @@ if(WITH_GPU)
add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest)
add_simple_unittest(RowConvOpTest)
add_simple_unittest(BlockExpandOpTest)
add_simple_unittest(CropOpTest)
endif()
add_simple_unittest(ConvOpTest)
add_simple_unittest(Im2ColTest)
endif()
add_style_check_target(paddle_function ${h_files})
......
......@@ -12,101 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "GemmConvOp.h"
#include "ConvOp.h"
#include "GemmFunctor.h"
#include "Im2Col.h"
#include "paddle/math/MemoryHandle.h"
namespace paddle {
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template <class T>
class Im2ColFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* imData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData) {
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[(c * outputHeight + h) * outputWidth + w] = T(0);
} else {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
colData[(c * outputHeight + h) * outputWidth + w] =
imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
}
};
template <class T>
class Col2ImFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData) {
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 &&
(imColIdx - paddingWidth) < inputWidth) {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
imData[imRowIdx * inputWidth + imColIdx] +=
colData[(c * outputHeight + h) * outputWidth + w];
}
}
}
}
}
};
/*
* \brief Forward calculation of convolution.
*/
......@@ -154,15 +66,20 @@ public:
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col;
Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t inputOffset = imShape.getElements();
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
......@@ -170,18 +87,13 @@ public:
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
paddingW());
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
......@@ -247,15 +159,20 @@ public:
real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Col2ImFunctor<Device, real> col2im;
Col2ImFunctor<kCFO, Device, real> col2im;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t inputOffset = imShape.getElements();
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
......@@ -278,20 +195,14 @@ public:
0.0f,
colData,
N);
col2im(colData,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
col2im(inputGrad + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
inputGrad + g * inputOffset);
paddingW());
}
inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
......@@ -344,33 +255,33 @@ public:
real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col;
Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t inputOffset = imShape.getElements();
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
paddingW());
int M = outputChannels / groups_;
int K = outputHeight * outputWidth;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ConvOp.h"
namespace paddle {
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template <DeviceType Device, class T>
class Im2ColFunctor {
public:
void operator()(const T* imData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData);
};
template <DeviceType Device, class T>
class Col2ImFunctor {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData);
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "TensorShape.h"
#include "TensorType.h"
namespace paddle {
/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */
enum ColFormat { kCFO = 0, kOCF = 1 };
/*
* \brief Converts the image data of three dimensions(CHW) into a colData of
* five dimensions in the Im2ColFunctor calculation,
* And in the Col2ImFunctor calculation, it is reversed.
*
* \param imData Image data.
* \param imShape The shape of imData,
* [inputChannels, inputHeight, inputWidth].
* \param colData Column data.
* \param colShape The shape of colData.
*
* If the template argument Format is kCFO, the shape of colData is:
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
* So, it is easy to reshape into a convolution matrix for convolution
* calculation based on matrix multiplication.
* The shape of convolution matrix is [height, width], where the height is equal
* inputChannels * filterHeight * filterWidth, and the width is equal
* outputHeight * outputWidth.
*
* Reshape:
* shape of colData shape of convolution matrix
* [inputChannels,
* filterHeight,
* filterWidth, ======> [height, width]
* outputHeight,
* outputWidth]
*
* If the template argument Format is kOCF, the shape of colData is:
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
* So, it is easy to reshape into a sequence matrix for rnn calculation.
* The shape of sequence matrix is [seqLength, stepSize], where the seqLength
* is equal outputHeight * outputWidth, and the stepSize is equal
* inputChannels * filterHeight * filterWidth.
*
* Reshape:
* shape of colData shape of sequence matrix
* [outputHeight,
* outputWidth,
* inputChannels, ======> [seqLength, stepSize]
* filterHeight,
* filterWidth]
*
* \note The caller needs to ensure that imShape.inputChannels is equal to
* colShape.inputChannels.
*/
template <ColFormat Format, DeviceType Device, class T>
class Im2ColFunctor {
public:
void operator()(const T* imData,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth);
};
template <ColFormat Format, DeviceType Device, class T>
class Col2ImFunctor {
public:
void operator()(T* imData,
const TensorShape& imShape,
const T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth);
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Im2Col.h"
namespace paddle {
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
*/
template <class T>
class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* imData,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputHeight = colShape[3];
int outputWidth = colShape[4];
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[(c * outputHeight + h) * outputWidth + w] = T(0);
} else {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
colData[(c * outputHeight + h) * outputWidth + w] =
imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
}
};
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
*/
template <class T>
class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> {
public:
void operator()(T* imData,
const TensorShape& imShape,
const T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputHeight = colShape[3];
int outputWidth = colShape[4];
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 &&
(imColIdx - paddingWidth) < inputWidth) {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
imData[imRowIdx * inputWidth + imColIdx] +=
colData[(c * outputHeight + h) * outputWidth + w];
}
}
}
}
}
};
template class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, float>;
template class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, double>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, float>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, double>;
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
*/
template <class T>
class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* imData,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[3];
int filterWidth = colShape[4];
int outputHeight = colShape[0];
int outputWidth = colShape[1];
for (int outputH = 0; outputH < outputHeight; ++outputH) {
for (int outputW = 0; outputW < outputWidth; ++outputW) {
for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset =
outputH * strideHeight + filterH - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels +
channel) *
filterHeight +
filterH) *
filterWidth +
filterW;
if (imRowOffset < 0 || imRowOffset >= inputHeight ||
imColOffset < 0 || imColOffset >= inputWidth) {
colData[colDataOffset] = float(0);
} else {
int imDataOffset =
(channel * inputHeight + imRowOffset) * inputWidth +
imColOffset;
colData[colDataOffset] = imData[imDataOffset];
}
}
}
}
}
}
}
};
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
*/
template <class T>
class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> {
public:
void operator()(T* imData,
const TensorShape& imShape,
const T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[3];
int filterWidth = colShape[4];
int outputHeight = colShape[0];
int outputWidth = colShape[1];
for (int outputH = 0; outputH < outputHeight; ++outputH) {
for (int outputW = 0; outputW < outputWidth; ++outputW) {
for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset =
outputH * strideHeight + filterH - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels +
channel) *
filterHeight +
filterH) *
filterWidth +
filterW;
if (imRowOffset >= 0 && imRowOffset < inputHeight &&
imColOffset >= 0 && imColOffset < inputWidth) {
int imDataOffset =
(channel * inputHeight + imRowOffset) * inputWidth +
imColOffset;
imData[imDataOffset] += colData[colDataOffset];
}
}
}
}
}
}
}
};
template class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, float>;
template class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, double>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, float>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, double>;
} // namespace paddle
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvOp.h"
#include "GemmConvOp.h"
#include "Im2Col.h"
#include "hl_device_functions.cuh"
namespace paddle {
......@@ -57,22 +57,30 @@ void im2col(const T* data_im, int numOuts, int height, int width,
}
}
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
*/
template <class T>
class Im2ColFunctor<DEVICE_TYPE_GPU, T> {
class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, T> {
public:
void operator()(const T* imData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData) {
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputHeight = colShape[3];
int outputWidth = colShape[4];
int numKernels = inputChannels * outputHeight * outputWidth;
int blocks = (numKernels + 1024 -1) / 1024;
int blockX = 512;
......@@ -132,22 +140,30 @@ void col2im(size_t n, const T* data_col, size_t height,
}
}
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
*/
template <class T>
class Col2ImFunctor<DEVICE_TYPE_GPU, T> {
class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, T> {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
void operator()(T* imData,
const TensorShape& imShape,
const T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData) {
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputHeight = colShape[3];
int outputWidth = colShape[4];
size_t numKernels = inputChannels * (inputHeight + 2*paddingHeight)
* (inputWidth + 2*paddingWidth);
......@@ -178,9 +194,188 @@ public:
}
};
template class Im2ColFunctor<DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<DEVICE_TYPE_GPU, double>;
template class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, double>;
template<class T>
__global__
void im2colOCF(const T* imData, T* colData,
int inputChannels,
int inputHeight, int inputWidth,
int filterHeight, int filterWidth,
int strideHeight, int strideWidth,
int paddingHeight, int paddingWidth,
int outputHeight, int outputWidth) {
int swId = blockIdx.x;
int shId = blockIdx.y;
for (int channelId = threadIdx.z;
channelId < inputChannels;
channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth
+ channelId * inputHeight * inputWidth;
int colOffset = idx + idy * filterWidth
+ channelId * filterHeight * filterWidth
+ (shId * outputWidth + swId)
* (inputChannels * filterHeight * filterWidth);
if (heightOffset >= inputHeight || heightOffset < 0 ||
widthOffset >= inputWidth || widthOffset < 0) {
colData[colOffset] = T(0);
} else {
colData[colOffset] = imData[imOffset];
}
}
}
}
}
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
*/
template <class T>
class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, T> {
public:
void operator()(const T* imData,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[3];
int filterWidth = colShape[4];
int outputHeight = colShape[0];
int outputWidth = colShape[1];
int blockDimX = 0;
int blockDimY = 0;
if (filterHeight <= 4 && filterWidth <= 4) {
blockDimX = 4;
blockDimY = 4;
} else if (filterHeight <= 8 && filterWidth <= 8) {
blockDimX = 8;
blockDimY = 8;
} else if (filterHeight <= 16 && filterWidth <= 16) {
blockDimX = 16;
blockDimY = 16;
} else {
blockDimX = 32;
blockDimY = 32;
}
int blockDimZ = 1024 / blockDimX / blockDimY;
dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
dim3 grid(outputWidth, outputHeight);
im2colOCF<T><<< grid, threads, 0, STREAM_DEFAULT >>>
(imData, colData, inputChannels, inputHeight, inputWidth,
filterHeight, filterWidth, strideHeight, strideWidth,
paddingHeight, paddingWidth, outputHeight, outputWidth);
CHECK_SYNC("Im2ColFunctor GPU failed");
}
};
template<class T>
__global__
void col2imOCF(T* imData, const T* colData,
int inputChannels,
int inputHeight, int inputWidth,
int filterHeight, int filterWidth,
int strideHeight, int strideWidth,
int paddingHeight, int paddingWidth,
int outputHeight, int outputWidth) {
int swId = blockIdx.x;
int shId = blockIdx.y;
for (int channelId = threadIdx.z;
channelId < inputChannels;
channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth
+ channelId * inputHeight * inputWidth;
int colOffset = idx + idy * filterWidth
+ channelId * filterHeight * filterWidth
+ (shId * outputWidth + swId)
* (inputChannels * filterHeight * filterWidth);
if (heightOffset >= 0 && heightOffset < inputHeight &&
widthOffset >= 0 && widthOffset < inputWidth) {
paddle::paddleAtomicAdd(imData + imOffset, colData[colOffset]);
}
}
}
}
}
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
*/
template <class T>
class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, T> {
public:
void operator()(T* imData,
const TensorShape& imShape,
const T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[3];
int filterWidth = colShape[4];
int outputHeight = colShape[0];
int outputWidth = colShape[1];
int blockDimX = 0;
int blockDimY = 0;
if (filterHeight <= 4 && filterWidth <= 4) {
blockDimX = 4;
blockDimY = 4;
} else if (filterHeight <= 8 && filterWidth <= 8) {
blockDimX = 8;
blockDimY = 8;
} else if (filterHeight <= 16 && filterWidth <= 16) {
blockDimX = 16;
blockDimY = 16;
} else {
blockDimX = 32;
blockDimY = 32;
}
int blockDimZ = 1024 / blockDimX / blockDimY;
dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
dim3 grid(outputWidth, outputHeight);
col2imOCF<T><<< grid, threads, 0, STREAM_DEFAULT >>>
(imData, colData, inputChannels, inputHeight, inputWidth,
filterHeight, filterWidth, strideHeight, strideWidth,
paddingHeight, paddingWidth, outputHeight, outputWidth);
CHECK_SYNC("Col2ImFunctor GPU failed");
}
};
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, double>;
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Im2Col.h"
#include <gtest/gtest.h>
#include "Function.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/tests/TensorCheck.h"
namespace paddle {
template <DeviceType Device, class T>
void TestIm2ColFunctor() {
for (size_t channels : {1, 5, 32}) {
for (size_t inputHeight : {5, 33, 100}) {
for (size_t inputWidth : {5, 32, 96}) {
for (size_t filterHeight : {1, 5}) {
for (size_t filterWidth : {3, 7}) {
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
if (inputHeight <= filterHeight || inputWidth <= filterWidth)
break;
if (padding >= filterHeight || padding >= filterWidth) break;
size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) /
stride;
size_t outputWidth =
(inputWidth - filterWidth + 2 * padding + stride) / stride;
TensorShape imShape =
TensorShape({channels, inputHeight, inputWidth});
TensorShape colShape1 = TensorShape({channels,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
TensorShape colShape2 = TensorShape({outputHeight,
outputWidth,
channels,
filterHeight,
filterWidth});
size_t height = channels * filterHeight * filterWidth;
size_t width = outputHeight * outputWidth;
VectorPtr input1 = Vector::create(imShape.getElements(), false);
VectorPtr input2 = Vector::create(imShape.getElements(), false);
MatrixPtr output1 = Matrix::create(height, width, false, false);
MatrixPtr output2 = Matrix::create(width, height, false, false);
input1->uniform(0.001, 1);
input2->copyFrom(*input1);
Im2ColFunctor<kCFO, Device, T> im2Col1;
Im2ColFunctor<kOCF, Device, T> im2Col2;
im2Col1(input1->getData(),
imShape,
output1->getData(),
colShape1,
stride,
stride,
padding,
padding);
im2Col2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding);
// The transposition of the result of ColFormat == kCFO
// is equal to the result of ColFormat == kOCF.
MatrixPtr test;
output2->transpose(test, true);
autotest::TensorCheckErr(*output1, *test);
Col2ImFunctor<kCFO, Device, T> col2Im1;
Col2ImFunctor<kOCF, Device, T> col2Im2;
col2Im1(input1->getData(),
imShape,
output1->getData(),
colShape1,
stride,
stride,
padding,
padding);
col2Im2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding);
autotest::TensorCheckErr(*input1, *input2);
}
}
}
}
}
}
}
}
TEST(Im2ColFunctor, CPU) { TestIm2ColFunctor<DEVICE_TYPE_CPU, float>(); }
#ifndef PADDLE_ONLY_CPU
TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); }
#endif
} // namespace paddle
......@@ -37,6 +37,22 @@ bool BlockExpandLayer::init(const LayerMap& layerMap,
imgSizeH_ = blockConf.img_size_y();
imgSizeW_ = blockConf.img_size_x();
std::vector<size_t> strides = {(size_t)strideH_, (size_t)strideW_};
std::vector<size_t> paddings = {(size_t)paddingH_, (size_t)paddingW_};
std::vector<size_t> blocks = {(size_t)blockH_, (size_t)blockW_};
createFunction(forward_,
"BlockExpand",
FuncConfig()
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
createFunction(backward_,
"BlockExpandGrad",
FuncConfig()
.set("strides", strides)
.set("paddings", paddings)
.set("blocks", blocks));
return true;
}
......@@ -63,48 +79,27 @@ void BlockExpandLayer::forward(PassType passType) {
Layer::forward(passType);
size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight();
size_t blockNum = getBlockNum();
size_t blockSize = blockH_ * blockW_ * channels_;
resetOutput(blockNum * batchSize, blockSize);
Argument& out = getOutput();
MatrixPtr outV = getOutputValue();
MatrixPtr input = getPrev(0)->getOutputValue();
Matrix::resizeOrCreate(outVTrans_, blockSize, blockNum, false, useGpu_);
// calculate output_.value
inputShape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_});
outputShape_ = TensorShape({batchSize, blockNum, blockSize});
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), inputShape_);
outputs.addArg(*getOutputValue(), outputShape_, ASSIGN_TO);
forward_[0]->calc(inputs, outputs);
// calculate output_.sequenceStartPositions and output_.cpuSequenceDims
Argument& out = getOutput();
ICpuGpuVector::resizeOrCreate(
out.sequenceStartPositions, batchSize + 1, false);
IVector::resizeOrCreate(out.cpuSequenceDims, 2 * batchSize, false);
int* start = out.sequenceStartPositions->getMutableData(false);
int* dims = out.cpuSequenceDims->getData();
for (size_t i = 0; i < batchSize; i++) {
outVTrans_->zeroMem();
/* expand each block as one row */
MatrixPtr inputTmp =
Matrix::create(input->getData() + i * input->getWidth(),
1,
input->getWidth(),
false,
useGpu_);
outVTrans_->convExpand(*inputTmp,
imgSizeH_,
imgSizeW_,
channels_,
blockH_,
blockW_,
strideH_,
strideW_,
paddingH_,
paddingW_,
outputH_,
outputW_);
MatrixPtr outVTmp =
Matrix::create(outV->getData() + i * blockNum * blockSize,
blockNum,
blockSize,
false,
useGpu_);
outVTrans_->transpose(outVTmp, false);
start[i] = i * blockNum;
dims[2 * i] = outputH_;
dims[2 * i + 1] = outputW_;
......@@ -113,48 +108,13 @@ void BlockExpandLayer::forward(PassType passType) {
}
void BlockExpandLayer::backward(const UpdateCallback& callback) {
size_t blockNum = outputH_ * outputW_;
size_t blockSize = blockH_ * blockW_ * channels_;
/* Calculate the input layers error */
MatrixPtr preGrad = inputLayers_[0]->getOutputGrad();
if (!preGrad) {
return;
}
MatrixPtr grad = getOutputGrad();
MatrixPtr gradTrans = Matrix::create(blockSize, blockNum, false, useGpu_);
size_t batchSize = preGrad->getHeight();
CHECK_EQ(batchSize * blockNum, grad->getHeight());
CHECK_EQ(blockSize, grad->getWidth());
for (size_t i = 0; i < batchSize; i++) {
MatrixPtr gradTmp =
Matrix::create(grad->getData() + i * blockNum * blockSize,
blockNum,
blockSize,
false,
useGpu_);
gradTmp->transpose(gradTrans, false);
MatrixPtr preGradTmp =
Matrix::create(preGrad->getData() + i * preGrad->getWidth(),
1,
preGrad->getWidth(),
false,
useGpu_);
preGradTmp->convShrink(*gradTrans,
imgSizeH_,
imgSizeW_,
channels_,
blockH_,
blockW_,
strideH_,
strideW_,
paddingH_,
paddingW_,
outputH_,
outputW_,
1.0,
1.0);
if (getInputGrad(0)) {
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outputShape_);
outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
}
......
......@@ -50,8 +50,8 @@ protected:
size_t blockH_, blockW_, strideH_, strideW_, paddingH_, paddingW_;
size_t imgSizeH_, imgSizeW_, outputH_, outputW_, channels_;
/// auxiliary variable, which saves the transposed output value.
MatrixPtr outVTrans_;
TensorShape inputShape_;
TensorShape outputShape_;
public:
explicit BlockExpandLayer(const LayerConfig& config) : Layer(config) {}
......
......@@ -1016,81 +1016,6 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) {
LOG(INFO) << "the diffCnt is " << diffCnt;
}
void GpuMatrix::convExpand(Matrix& feature,
int feaImgHeight,
int feaImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW) {
CHECK(feature.useGpu_ == true) << "Matrix type are not equal";
CHECK_EQ(size_t(feaImgHeight * feaImgWidth * channels),
feature.getHeight() * feature.getWidth())
<< "Matrix dimensions are not equal";
size_t elemCnt = outputH * outputW * blockH * blockW * channels;
CHECK_EQ(elemCnt, height_ * width_) << "Matrix dimensions are not equal";
hl_expand_feature2col(feature.getData(),
channels,
feaImgHeight,
feaImgWidth,
blockH,
blockW,
strideH,
strideW,
paddingH,
paddingW,
outputH,
outputW,
getData());
}
void GpuMatrix::convShrink(Matrix& expandFeat,
int thisImgHeight,
int thisImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW,
real alpha,
real beta) {
CHECK(expandFeat.useGpu_ == true) << "Matrix type are not equal";
CHECK_EQ(size_t(thisImgHeight * thisImgWidth * channels),
getHeight() * getWidth())
<< "Matrix dimensions are not equal";
size_t elemCnt = outputH * outputW * blockW * blockH * channels;
CHECK(elemCnt == expandFeat.getHeight() * expandFeat.getWidth())
<< "Matrix dimensions are not equal";
hl_shrink_col2feature(expandFeat.getData(),
channels,
thisImgHeight,
thisImgWidth,
blockH,
blockW,
strideH,
strideW,
paddingH,
paddingW,
outputH,
outputW,
getData(),
alpha,
beta);
}
void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
......@@ -1777,103 +1702,6 @@ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
CHECK_EQ(info, 0);
}
void CpuMatrix::convExpand(Matrix& feature,
int feaImgHeight,
int feaImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW) {
CHECK(feature.useGpu_ == false) << "Matrix type are not equal";
CHECK_EQ(size_t(feaImgHeight * feaImgWidth * channels),
feature.getHeight() * feature.getWidth())
<< "Matrix dimensions are not equal";
size_t elemCnt = outputH * outputW * blockH * blockW * channels;
CHECK_EQ(elemCnt, height_ * width_) << "Matrix dimensions are not equal";
int channelsCol = channels * blockH * blockW;
real* srcData = feature.getData();
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % blockW;
int hOffset = (c / blockW) % blockH;
int c_im = c / blockH / blockW;
for (int h = 0; h < outputH; ++h) {
for (int w = 0; w < outputW; ++w) {
// no c_im*height to Exclude the channel number
int imgRowIdx = h * strideH + hOffset;
int imgColIdx = w * strideW + wOffset;
if ((imgRowIdx - paddingH) < 0 ||
(imgRowIdx - paddingH) >= feaImgHeight ||
(imgColIdx - paddingW) < 0 ||
(imgColIdx - paddingW) >= feaImgWidth) {
data_[(c * outputH + h) * outputW + w] = 0;
} else {
imgRowIdx += c_im * feaImgHeight - paddingH;
imgColIdx -= paddingW;
data_[(c * outputH + h) * outputW + w] =
srcData[imgRowIdx * feaImgWidth + imgColIdx];
}
}
}
}
}
void CpuMatrix::convShrink(Matrix& expandFeat,
int thisImgHeight,
int thisImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW,
real alpha,
real beta) {
CHECK(expandFeat.useGpu_ == false) << "Matrix type are not equal";
CHECK_EQ(size_t(thisImgHeight * thisImgWidth * channels),
getHeight() * getWidth())
<< "Matrix dimensions are not equal";
size_t elemCnt = outputH * outputW * blockH * blockW * channels;
CHECK(elemCnt == expandFeat.getHeight() * expandFeat.getWidth())
<< "Matrix dimensions are not equal";
real* expandData = expandFeat.getData();
int channelsCol = channels * blockH * blockW;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % blockW;
int hOffset = (c / blockW) % blockH;
int c_im = c / blockW / blockH;
for (int h = 0; h < outputH; ++h) {
for (int w = 0; w < outputW; ++w) {
int imRowIdx = h * strideH + hOffset;
int imColIdx = w * strideW + wOffset;
if ((imRowIdx - paddingH) >= 0 &&
(imRowIdx - paddingH) < thisImgHeight &&
(imColIdx - paddingW) >= 0 &&
(imColIdx - paddingW) < thisImgWidth) {
imRowIdx += c_im * thisImgHeight - paddingH;
imColIdx -= paddingW;
data_[imRowIdx * thisImgWidth + imColIdx] =
alpha * expandData[(c * outputH + h) * outputW + w] +
beta * data_[imRowIdx * thisImgWidth + imColIdx];
}
}
}
}
}
void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
......
......@@ -859,49 +859,6 @@ public:
LOG(FATAL) << "Not implemented";
}
/**
* This function is used to calculate the convolution:
*
* It will expand a feature matrix according to the
* convolution filters
*/
virtual void convExpand(Matrix& feature,
int feaImgHeight,
int feaImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW) {
LOG(FATAL) << "Not implemeted";
}
/**
* This function is the reverse implementation of convExpand:
*
* Its function is to restore a expanded-matrix into a feature matrix
*/
virtual void convShrink(Matrix& expandColMat,
int thisImgHeight,
int thisImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW,
real alpha = 1.0f,
real beta = 0.0f) {
LOG(FATAL) << "Not implemeted";
}
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value
......@@ -1335,34 +1292,6 @@ public:
void classificationError(Matrix& output, IVector& label, size_t topkSize = 1);
void convExpand(Matrix& feature,
int feaImgHeight,
int feaImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW);
void convShrink(Matrix& expandColMat,
int thisImgHeight,
int thisImgWidth,
int channels,
int blockH,
int blochW,
int strideH,
int strideW,
int paddingH,
int paddingWreal,
int outputH,
int outputW,
real alpha = 1.0f,
real beta = 0.0f);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
......@@ -1522,34 +1451,6 @@ public:
MatrixPtr clone(size_t height, size_t width, bool useGpu = false);
void convExpand(Matrix& feature,
int feaImgHeight,
int feaImgWidth,
int channels,
int blcokH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW);
void convShrink(Matrix& expandFeat,
int thisImgHeight,
int thisImgWidth,
int channels,
int blockH,
int blockW,
int strideH,
int strideW,
int paddingH,
int paddingW,
int outputH,
int outputW,
real alpha = 1.0f,
real beta = 0.0f);
void maxPoolForward(Matrix& inputMat,
size_t imgSizeH,
size_t imgSizeW,
......
......@@ -48,6 +48,11 @@ void ExposeOperator(ClassType& m) {
.def("__str__", &ClassType::type::DebugString);
}
static size_t UniqueIntegerGenerator() {
static std::atomic<size_t> generator;
return generator.fetch_add(1);
}
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");
......@@ -98,7 +103,8 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference)
.def("create_var",
&pd::Scope::CreateVariable,
py::return_value_policy::reference);
py::return_value_policy::reference)
.def("get_var_name", &pd::Scope::GetVariableName);
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
......@@ -141,23 +147,24 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(operator_base);
using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
py::class_<pd::PlainNet, PlainNetPtr> plain_net(m, "PlainNet");
plain_net
.def_static("create",
[]() -> std::shared_ptr<pd::PlainNet> {
auto retv = std::make_shared<pd::PlainNet>();
retv->type_ = "plain_net";
return retv;
})
py::class_<pd::PlainNet, PlainNetPtr> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<pd::PlainNet> {
auto retv = std::make_shared<pd::PlainNet>();
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &pd::PlainNet::AddOp)
.def("add_op",
[](PlainNetPtr& self, const PlainNetPtr& plain_net) -> void {
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(plain_net));
[](PlainNetPtr& self, const PlainNetPtr& net) -> void {
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
})
.def("complete_add_op", &pd::PlainNet::CompleteAddOp)
.def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
ExposeOperator(plain_net);
ExposeOperator(net);
m.def("unique_integer", UniqueIntegerGenerator);
return m.ptr();
}
......@@ -220,6 +220,9 @@ def create_op_creation_method(op_proto):
__impl__.all_input_args = [var.name for var in op_proto.inputs]
__impl__.all_output_args = [var.name for var in op_proto.outputs]
__impl__.all_attr_args = [attr.name for attr in op_proto.attrs]
__impl__.all_not_temp_output_args = [
var.name for var in op_proto.outputs if not var.temporary
]
return __impl__
......
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import create_var, get_var, get_cur_scope
__all__ = ['Network'] # Only expose Network
class NetworkFunctor(object):
"""
Network Op Creation Function. Used internally in this module.
It convert string input to Variable. If it is not created before, just
create in scope.
It is a functor object. means the instances are callable.
:param func: The op creation function which generated in Python.
:param net: The Network instance.
"""
def __init__(self, func, net):
self.func = func
self.net = net
def __call__(self, *args, **kwargs):
if len(args) != 0:
raise ValueError("Paddle must use keyword argument")
inputs = self.func.all_input_args
for ipt in inputs:
if ipt in kwargs:
var = kwargs[ipt]
if isinstance(var, basestring):
var = create_var(var)
if not isinstance(var, core.Variable):
raise TypeError(
"Input of op creation must be string or variable")
kwargs[ipt] = get_cur_scope().get_var_name(var)
notemp_outputs = self.func.all_not_temp_output_args
for name in notemp_outputs:
if name not in kwargs:
kwargs[
name] = self.func.__name__ + "@OUT@%d" % core.unique_integer(
)
outputs = self.func.all_output_args
for opt in outputs:
if opt in kwargs:
var = kwargs[opt]
if isinstance(var, basestring):
var = create_var(var)
if not isinstance(var, core.Variable):
raise TypeError(
"Output of op creation must be string or variable")
kwargs[opt] = get_cur_scope().get_var_name(var)
op = self.func(**kwargs)
self.net.net.add_op(op)
lst = [get_var(kwargs[opt]) for opt in notemp_outputs]
if len(lst) == 1:
return lst[0]
elif len(lst) == 0:
return None
else:
return lst
class Network(object):
"""
The network concept. It avoid user to manually create operator, create
variable, and combine them into a Net. Just use Network.xxx can create the
operator, create variables in default scope, and add them into `self.net`.
For example:
.. code-block: python
net = Network()
out = net.add_two(X="a", Y="b")
fc_out = net.fc(X="out", W="fc.w")
net.run(...)
"""
def __init__(self):
self.net = core.Net.create()
funcs = (func_name for func_name in dir(op_creations)
if not func_name.startswith("__"))
# TODO(yuyang18): This code can work, but do not generate a good
# docstring, try to give a better way generate function in runtime
# later.
for func_name in funcs:
func = getattr(op_creations, func_name)
impl = NetworkFunctor(func, self)
setattr(self, func_name, impl.__call__)
self.__complete_add_op__ = False
def infer_shape(self):
self.complete_add_op()
self.net.infer_shape(get_cur_scope())
def run(self, device_context):
self.complete_add_op()
self.net.run(get_cur_scope(), device_context)
def __str__(self):
return str(self.net)
def complete_add_op(self):
if not self.__complete_add_op__:
self.net.complete_add_op()
self.__complete_add_op__ = True
if __name__ == '__main__':
net = Network()
out = net.add_two(X="a", Y="b")
fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax")
net.complete_add_op()
print net
......@@ -3,7 +3,7 @@ add_python_test(test_framework
test_scope.py
test_default_scope_funcs.py
test_op_creation_methods.py
test_plain_net.py
test_net.py
test_tensor.py
test_fc_op.py
test_add_two_op.py
......@@ -12,4 +12,5 @@ add_python_test(test_framework
test_mul_op.py
test_sigmoid_op.py
test_softmax_op.py
test_rowwise_add_op.py)
test_rowwise_add_op.py
test_network.py)
......@@ -5,11 +5,11 @@ import unittest
class TestNet(unittest.TestCase):
def test_net_all(self):
net = core.PlainNet.create()
net = core.Net.create()
op1 = op_creations.add_two(X="X", Y="Y", Out="Out")
net.add_op(op1)
net2 = core.PlainNet.create()
net2 = core.Net.create()
net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
net2.complete_add_op(True)
net.add_op(net2)
......
from paddle.v2.framework.network import Network
import paddle.v2.framework.core as core
import unittest
class TestNet(unittest.TestCase):
def test_net_all(self):
net = Network()
out = net.add_two(X="X", Y="Y")
fc_out = net.fc(X=out, W="w")
net.complete_add_op()
self.assertTrue(isinstance(fc_out, core.Variable))
self.assertEqual(
'''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
''', str(net))
net2 = Network()
tmp = net2.add_two(X="X", Y="Y")
self.assertTrue(isinstance(tmp, core.Variable))
net2.complete_add_op()
self.assertEqual(
'''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
''', str(net2))
if __name__ == '__main__':
unittest.main()
import ctypes
import os
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
lib = ctypes.cdll.LoadLibrary(path)
__lib__ = None
def get_c_lib():
global __lib__
if __lib__ is None:
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
__lib__ = ctypes.cdll.LoadLibrary(path)
return __lib__
class client(object):
......@@ -11,8 +18,8 @@ class client(object):
"""
def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout_sec,
buf_size)
self.c = get_c_lib().paddle_new_etcd_master_client(
etcd_endpoints, timeout_sec, buf_size)
def request_save_model(self, trainer_id, block_ms):
"""request to save model
......@@ -32,10 +39,11 @@ class client(object):
saving the model, -1 if error happened.
"""
return lib.paddle_request_save_model(self.c, trainer_id, block_ms)
return get_c_lib().paddle_request_save_model(self.c, trainer_id,
block_ms)
def release(self):
lib.paddle_release_master_client(self.c)
get_c_lib().paddle_release_master_client(self.c)
self.c = None
def set_dataset(self, paths):
......@@ -45,7 +53,7 @@ class client(object):
for idx, path in enumerate(paths):
c_ptr = ctypes.c_char_p(path)
holder[idx] = c_ptr
lib.paddle_set_dataset(self.c, holder, len(paths))
get_c_lib().paddle_set_dataset(self.c, holder, len(paths))
def next_record(self):
"""gets next record for training
......@@ -56,7 +64,7 @@ class client(object):
"""
p = ctypes.c_char_p()
ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret)
size = get_c_lib().paddle_next_record(self.c, ret)
if size < 0:
# Error
return None, size
......@@ -67,5 +75,5 @@ class client(object):
record = ret.contents.value[:size]
# Memory created from C should be freed.
lib.mem_free(ret.contents)
get_c_lib().mem_free(ret.contents)
return record, 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册