“develop”上不存在“9_convolutional_bn”
提交 66520af9 编写于 作者: X xzl

Accelerate the input-backward kernel of depthwise convolution by removing the `if` checks from its inner loop

上级 dbb65880
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include "DepthwiseConvOp.h" #include "DepthwiseConvOp.h"
#include "GemmFunctor.h" #include "GemmFunctor.h"
#include "paddle/math/BaseMatrix.h" #include "paddle/math/BaseMatrix.h"
...@@ -93,29 +94,32 @@ void ConvolutionDepthwiseInputBackward(const int nthreads, ...@@ -93,29 +94,32 @@ void ConvolutionDepthwiseInputBackward(const int nthreads,
const int c_in = (index / inputHeight / inputWidth) % inputChannels; const int c_in = (index / inputHeight / inputWidth) % inputChannels;
const int h_in = (index / inputWidth) % inputHeight; const int h_in = (index / inputWidth) % inputHeight;
const int w_in = index % inputWidth; const int w_in = index % inputWidth;
const int c_out_start = c_in * filterMultiplier; const int c_out_start = c_in * filterMultiplier;
int h_out_start = (h_in - filterHeight + paddingH + strideH)/strideH;
h_out_start = 0 > h_out_start ? 0 : h_out_start;
int h_out_end = (h_in + paddingH)/strideH;
h_out_end = outputHeight - 1 < h_out_end? outputHeight - 1 : h_out_end;
int w_out_start = (w_in - filterWidth + paddingW + strideW)/strideW;
w_out_start = 0 > w_out_start ? 0 : w_out_start;
int w_out_end = (w_in + paddingW)/strideW;
w_out_end = outputWidth - 1 < w_out_end? outputWidth - 1 : w_out_end;
T value = 0; T value = 0;
for (int c_out = c_out_start; for (int c_out = c_out_start;
c_out < c_out_start + filterMultiplier; c_out ++) { c_out < c_out_start + filterMultiplier; c_out ++) {
const T* weight = weight_data + c_out * filterHeight * filterWidth; for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) {
for (int kh = 0; kh < filterHeight; ++kh) { const int filter_h = h_in + paddingH - h_out * strideH;
for (int kw = 0; kw < filterWidth; ++kw) { for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) {
const int h_out_s = h_in + paddingH - kh; const int filter_w = w_in + paddingW - w_out * strideW;
const int w_out_s = w_in + paddingW - kw; const int filter_offset = c_out * filterHeight * filterWidth
if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) { + filter_h * filterWidth + filter_w;
const int h_out = h_out_s / strideH; const int top_diff_offset = ((batch * outputChannels + c_out) *
const int w_out = w_out_s / strideW; outputHeight + h_out)* outputWidth + w_out;
// TODO(zhaolong) : the 'if' affect the effectiveness, value += top_diff[top_diff_offset] * weight_data[filter_offset];
// it needs to optimize }
if ((h_out >= 0) && (h_out < outputHeight)
&& (w_out >= 0) && (w_out < outputWidth)) {
const int offset = ((batch * outputChannels + c_out)
* outputHeight + h_out) * outputWidth + w_out;
value += (*weight) * top_diff[offset];
}
}
++weight;
}
} }
} }
bottom_diff[index] += value; bottom_diff[index] += value;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册