提交 dbf1d75f 编写于 作者: H hedaoyuan

Add a GemmConvMobileFunction.

上级 f66c17b6
......@@ -134,6 +134,154 @@ public:
}
};
/*
* \brief Forward calculation of convolution, optimized for mobile.
*/
template <DeviceType Device>
class GemmConvMobileFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
}
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
// TODO(hedaoyuan): Need to define some index macros,
// to avoid useing 0 and 1.
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
real beta;
if (outputs[0].getArgType() == ADD_TO) {
beta = 1.0;
} else {
beta = 0.0;
}
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
bool needIm2col = isNeedIm2col(filter);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape;
real* colData = NULL;
size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth;
size_t colWidth = outputHeight * outputWidth;
// Max col matrix height 256, Max col matrix width 1024
size_t stepColHeight = std::min(colHeight, (size_t)256);
size_t stepColWidth = std::min(colWidth, (size_t)2048);
if (needIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(stepColHeight * stepColWidth * sizeof(real));
colData = reinterpret_cast<real*>(memory_->getBuf());
}
Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements();
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
int nStride = colWidth;
int kStride = colHeight;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
if (needIm2col) {
real beta_ = beta;
for (size_t colHeightStart = 0; colHeightStart < colHeight;
colHeightStart += stepColHeight) {
for (size_t colWidthStart = 0; colWidthStart < colWidth;
colWidthStart += stepColWidth) {
int N = std::min(colWidth - colWidthStart, stepColWidth);
int K = std::min(colHeight - colHeightStart, stepColHeight);
// im2col
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW(),
colHeightStart,
K,
colWidthStart,
N);
// gemm
int M = outputChannels / groups_;
gemm(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset + colHeightStart,
kStride,
colData,
N,
beta_,
outputData + g * outputOffset + colWidthStart,
nStride);
}
beta_ = 1.0;
}
} else {
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
K,
inputData + g * inputOffset,
N,
beta,
outputData + g * outputOffset,
N);
}
}
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
}
};
/*
* \brief Backward input calculation of convolution.
*/
......@@ -348,7 +496,11 @@ public:
}
};
#ifdef PADDLE_MOBILE_INFERENCE
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction);
#else
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
#endif
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
#ifdef PADDLE_WITH_CUDA
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册