提交 92332bba 编写于 作者: E eclipsess

temp fix dw3x3

上级 2e9ae70a
...@@ -118,13 +118,15 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) { ...@@ -118,13 +118,15 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true); param.Bias(), true);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(), // math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(), // param.Paddings(),
// param.Filter(), param.Bias(), // param.Filter(), param.Bias(),
......
...@@ -118,14 +118,16 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) { ...@@ -118,14 +118,16 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(), param.Output(), param.NewScale(),
param.NewBias(), true); param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(), // param.Output(), param.NewScale(),
// param.NewBias(), 1); // param.NewBias(), 1);
......
...@@ -124,13 +124,15 @@ void ConvCompute(const ConvParam<CPU> &param) { ...@@ -124,13 +124,15 @@ void ConvCompute(const ConvParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false); nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) { param.Filter()->dims()[2] == 3 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false); param.Filter(), nullptr, param.Output(), false);
} else { } else {
......
...@@ -122,14 +122,16 @@ void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) { ...@@ -122,14 +122,16 @@ void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(), param.Output(), param.NewScale(),
param.NewBias(), true); param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(), // param.Output(), param.NewScale(),
// param.NewBias(), 1); // param.NewBias(), 1);
......
...@@ -117,14 +117,16 @@ void ConvBNReluCompute(const FusionConvBNReluParam<CPU> &param) { ...@@ -117,14 +117,16 @@ void ConvBNReluCompute(const FusionConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(), param.Output(), param.NewScale(),
param.NewBias(), true); param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(), // param.Output(), param.NewScale(),
// param.NewBias(), 1); // param.NewBias(), 1);
......
...@@ -30,13 +30,15 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) { ...@@ -30,13 +30,15 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Bias.mutable_data<float>({param.Groups()}); Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[3]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
&Bias, false); &Bias, false);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[3]) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(), // math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(), // param.Paddings(),
// param.Filter(), &Bias, param.Output(), false); // param.Filter(), &Bias, param.Output(), false);
......
...@@ -115,14 +115,16 @@ void DWConvBNReluCompute(const FusionDWConvBNReluParam<CPU> &param) { ...@@ -115,14 +115,16 @@ void DWConvBNReluCompute(const FusionDWConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(), param.Output(), param.NewScale(),
param.NewBias(), true); param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(), // param.Output(), param.NewScale(),
// param.NewBias(), 1); // param.NewBias(), 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册