提交 90bf4f60 编写于 作者: H hedaoyuan

Add stride support 2 for NeonDepthwiseConvTranspose.

上级 840104c9
......@@ -566,6 +566,63 @@ struct Padding<float> {
}
};
// for stride is 2
struct StridePadding {
static void run(const float* input,
float* inputPadding,
int channels,
int inputHeight,
int inputWidth,
int padInputHeight,
int padInputWidth) {
const int paddingHeight = (padInputHeight - (inputHeight * 2 - 1)) / 2;
const int paddingWidth = (padInputWidth - (inputWidth * 2 - 1)) / 2;
for (int c = 0; c < channels; c++) {
if (paddingHeight > 0) {
memset(inputPadding, 0, padInputWidth * paddingHeight * sizeof(float));
inputPadding += padInputWidth * paddingHeight;
}
for (int i = 0; i < inputHeight; i++) {
// padding head
for (int j = 0; j < paddingWidth; j++) {
*inputPadding++ = float(0);
}
int step = inputWidth >> 2;
int remain = inputWidth & 3;
float32x4_t s1 = vdupq_n_f32(0.f);
for (int s = 0; s < step; s++) {
float32x4_t s0 = vld1q_f32(input);
float32x4x2_t v = {s0, s1};
vst2q_f32(inputPadding, v);
input += 4;
inputPadding += 8;
}
for (int r = 0; r < remain; r++) {
*inputPadding++ = *input++;
*inputPadding++ = float(0);
}
inputPadding--;
// padding tail
for (int j = 0; j < paddingWidth; j++) {
*inputPadding++ = float(0);
}
if (i != inputHeight - 1) {
memset(inputPadding, 0, padInputWidth * sizeof(float));
inputPadding += padInputWidth;
}
}
if (paddingHeight > 0) {
memset(inputPadding, 0, padInputWidth * paddingHeight * sizeof(float));
inputPadding += padInputWidth * paddingHeight;
}
}
}
};
#endif
#endif
......
......@@ -74,6 +74,7 @@ public:
int newSize = batchSize * inputChannels * padInputHeight * padInputWidth;
resizeBuffer<Device>(newSize);
inputPadding = reinterpret_cast<float*>(memory_->getBuf());
if (strideH() == 1) {
neon::Padding<float>::run(inputData,
inputPadding,
batchSize * inputChannels,
......@@ -81,6 +82,17 @@ public:
inputWidth,
padInputHeight,
padInputWidth);
} else if (strideH() == 2) {
neon::StridePadding::run(inputData,
inputPadding,
batchSize * inputChannels,
inputHeight,
inputWidth,
padInputHeight,
padInputWidth);
} else {
LOG(FATAL) << "Not supported";
}
}
std::function<void(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册