Commit faa45610 authored by xiebaiyuan

use gemm to s1p0 instead of s1p1

Parent d0063b37
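Each hunk below adds the same `param.paddings_[0] == 1` guard to the depthwise-conv dispatch, so a 3x3, stride-1 convolution with padding 0 (s1p0) no longer takes the DepthwiseConv3x3s1p1 kernel and instead falls through to the generic GEMM-based path. A minimal sketch of that dispatch rule follows, using hypothetical names (ConvParams, RunDepthwise3x3S1P1, RunGemmConv) rather than the real types in this repo.

// Minimal, hypothetical sketch of the dispatch this commit tightens.
// ConvParams, RunDepthwise3x3S1P1 and RunGemmConv are stand-in names,
// not the actual kernel entry points.
#include <iostream>

struct ConvParams {
  int groups, in_channels, out_channels;
  int filter_h, filter_w;
  int stride, padding;
};

void RunDepthwise3x3S1P1(const ConvParams&) {
  std::cout << "depthwise 3x3 s1p1 fast path\n";
}

void RunGemmConv(const ConvParams&) {
  std::cout << "generic GEMM path\n";
}

void Dispatch(const ConvParams& p) {
  // Before this commit the `p.padding == 1` check was missing, so a
  // stride-1 / padding-0 (s1p0) depthwise conv was routed to the s1p1
  // kernel; with the extra check it falls through to the GEMM path.
  if (p.groups == p.in_channels && p.in_channels == p.out_channels &&
      p.filter_h == p.filter_w && p.filter_h == 3 && p.stride == 1 &&
      p.padding == 1) {
    RunDepthwise3x3S1P1(p);
  } else {
    RunGemmConv(p);
  }
}

int main() {
  Dispatch({8, 8, 8, 3, 3, 1, 1});  // s1p1: fast path
  Dispatch({8, 8, 8, 3, 3, 1, 0});  // s1p0: GEMM path (new behaviour)
  return 0;
}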
@@ -121,7 +121,8 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
   if (param.Groups() == param.Input()->dims()[1] &&
       param.Input()->dims()[1] == param.Output()->dims()[1] &&
       param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
+      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
+      param.paddings_[0] == 1) {
     math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
                                param.Bias(), true, false);
   } else if (param.Groups() == param.Input()->dims()[1] &&
...
@@ -125,7 +125,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
   if (param.Groups() == param.Input()->dims()[1] &&
       param.Input()->dims()[1] == param.Output()->dims()[1] &&
       param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
+      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
+      param.paddings_[0] == 1) {
     math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
                                param.Bias(), true, true);
   } else if (param.Groups() == param.Input()->dims()[1] &&
...
@@ -122,7 +122,8 @@ void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
   if (param.Groups() == param.Input()->dims()[1] &&
       param.Input()->dims()[1] == param.Output()->dims()[1] &&
       param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
+      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
+      param.paddings_[0] == 1) {
     math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
                                         param.Output(), param.NewScale(),
                                         param.NewBias(), true);
...
@@ -120,7 +120,8 @@ void ConvBNReluCompute(const FusionConvBNReluParam<CPU> &param) {
   if (param.Groups() == param.Input()->dims()[1] &&
       param.Input()->dims()[1] == param.Output()->dims()[1] &&
       param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
+      param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
+      param.paddings_[0] == 1) {
     math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
                                         param.Output(), param.NewScale(),
                                         param.NewBias(), true);
...