提交 d10b2479 编写于 作者: E eclipsess

temp fix dw3x3

上级 ca33688e
......@@ -118,13 +118,15 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
......
......@@ -118,14 +118,16 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
......
......@@ -124,13 +124,15 @@ void ConvCompute(const ConvParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
param.Filter()->dims()[2] == 3 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
} else {
......
......@@ -122,14 +122,16 @@ void ConvBNAddReluCompute(const FusionConvBNAddReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
......
......@@ -117,14 +117,16 @@ void ConvBNReluCompute(const FusionConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
......
......@@ -30,13 +30,15 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[3]) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
&Bias, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[3]) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), &Bias, param.Output(), false);
......
......@@ -115,14 +115,16 @@ void DWConvBNReluCompute(const FusionDWConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 &&
param.Input()->dims()[2] == param.Input()->dims()[2]) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册