diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc
index 8a1f2d4b38a6d0e16db9c6d6d10ef736fe235be8..96fb0681c8afa86b54eb28ae3003068d26698d90 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc
@@ -70,6 +70,7 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli
 
   attr->global = false;
   attr->roundMode = schema::RoundMode_FLOOR;
+  attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
 
   // calculate pad params
   auto data_index = tflite_op->inputs[0];
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index e444b6dc4bb0881d30681579383ae8a77e5991f3..64e638038916947ea70c8c383ce74497eb009ab9 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -346,6 +346,14 @@ bool IsConvNode(const BaseRef &n) {
   return false;
 }
 
+bool IsPoolingNode(const BaseRef &n) {
+  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
+    auto type = opt::GetCNodeType(n);
+    return type == schema::PrimitiveType_Pooling;
+  }
+  return false;
+}
+
 bool CheckIsAllInputsParam(const AnfNodePtr &node) {
   if (utils::isa<CNode>(node)) {
     auto cnode = node->cast<CNodePtr>();
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h
index 2299779b394bae4752a34c82103346ef56a3796c..066882caca05030b068f7cec81eba53cf2f58f29 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.h
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -56,6 +56,8 @@ bool IsParamNode(const BaseRef &n);
 
 bool IsConvNode(const BaseRef &n);
 
+bool IsPoolingNode(const BaseRef &n);
+
 bool CheckIsAllInputsParam(const AnfNodePtr &node);
 
 size_t GetOutputTensorNum(const AnfNodePtr &node);
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc
index c779465eeb668ca3c89942e0fa202baf3d0172bc..c89c25f0bc21b6df29b64a6953048825b444b7a5 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc
@@ -6,7 +6,7 @@
  * You may obtain a copy of the License at
  *
  * http://www.apache.org/licenses/LICENSE-2.0
- *conv_activation_fusion.h
+ *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h
index 9436240bf1ce9c2b616b2a6955e361b2e79c3282..af7d900bec508b295852f413aba41738bc4a7e05 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h
@@ -6,7 +6,7 @@
  * You may obtain a copy of the License at
  *
  * http://www.apache.org/licenses/LICENSE-2.0
- *conv_activation_fusion.h
+ *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.cc
new file mode 100644
index 0000000000000000000000000000000000000000..acd3579ea2588b8b07ced52db717e2cb220ea9da
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.cc
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/fusion/pooling_activation_fusion.h"
+#include <memory>
+#include "src/ops/primitive_c.h"
+#include "src/ops/pooling.h"
+#include "src/ops/activation.h"
+#include "schema/inner/model_generated.h"
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore::opt {
+namespace {
+constexpr size_t kActivationInputsLength = 2;
+}
+const BaseRef PoolingActivationFusion::DefinePattern() const {
+  auto pooling_var = std::make_shared<CondVar>(IsPoolingNode);
+  auto prim = new schema::PrimitiveT();
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "new primitiveT failed";
+    return nullptr;
+  }
+  prim->value.type = primitive_type;
+  auto prim_value = std::make_shared<lite::PrimitiveC>(prim);
+
+  return VectorRef({prim_value, pooling_var});
+}
+
+const AnfNodePtr PoolingActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                 const EquivPtr &) const {
+  MS_LOG(DEBUG) << "pooling activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
+  CheckIfFuncGraphIsNull(func_graph);
+
+  CheckIfAnfNodeIsNull(node);
+  auto act_node = node->cast<CNodePtr>();
+  CheckIfCNodeIsNull(act_node);
+  CheckInputSize(act_node, kActivationInputsLength);
+
+  auto primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(act_node->input(0));
+  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
+  auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
+  MS_ASSERT(act_primitivec != nullptr);
+  if (act_primitivec->GetType() != activation_type) {
+    return node;
+  }
+  AnfNodePtr pre_node = act_node->input(1);
+  CheckIfAnfNodeIsNull(pre_node);
+  if (pre_node != nullptr && pre_node->isa<CNode>()) {
+    if (IsMultiOutputTensors(func_graph, pre_node)) {
+      return node;
+    }
+    auto pooling_node = pre_node->cast<CNodePtr>();
+    auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(pooling_node->input(0));
+    MS_ASSERT(primitive_c);
+
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c);
+    MS_ASSERT(primc != nullptr);
+    if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
+      primc->SetActivationType(activation_type);
+      return pre_node;
+    }
+  }
+  return node;
+}
+}  // namespace mindspore::opt
diff --git a/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.h b/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.h
new file mode 100644
index 0000000000000000000000000000000000000000..b01206ea82845248926b280e4ac0ae1158d98eec
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/pooling_activation_fusion.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_POOLING_ACTIVATION_FUSION_H_
+#define MINDSPORE_LITE_SRC_PASS_FUSION_POOLING_ACTIVATION_FUSION_H_
+
+#include <string>
+#include "backend/optimizer/common/optimizer.h"
+#include "schema/inner/model_generated.h"
+
+namespace mindspore {
+namespace opt {
+class PoolingActivationFusion : public PatternProcessPass {
+ public:
+  explicit PoolingActivationFusion(bool multigraph = true, const std::string &name = "pooling_activation_fusion",
+                                   schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
+                                   schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
+      : PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
+  ~PoolingActivationFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+  schema::PrimitiveType primitive_type;
+  schema::ActivationType activation_type;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_POOLING_ACTIVATION_FUSION_H_