diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast2_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5b2a92bb74c9e2a24dda6a97e0dce0f7ecd7b92 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast2_fusion.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef ReplaceBNGradCast2Fusion::DefinePattern() const { + VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_, x_, scale_, mean_, var_}); + VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_}); + VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget}); + return out_cast; +} + +const AnfNodePtr ReplaceBNGradCast2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto tuple = AnfAlgo::GetInputNode(utils::cast(node), 0); + auto index_node = AnfAlgo::GetInputNode(utils::cast(tuple), 1); + MS_EXCEPTION_IF_NULL(index_node); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + if (item_idx != 0) { + return nullptr; + } + auto fbn2g = AnfAlgo::GetInputNode(utils::cast(tuple), 0); + + auto dy_ = AnfAlgo::GetInputNode(utils::cast(fbn2g), 0); + auto x_ = AnfAlgo::GetInputNode(utils::cast(fbn2g), 1); + + auto scale = AnfAlgo::GetInputNode(utils::cast(fbn2g), 2); + auto mean = AnfAlgo::GetInputNode(utils::cast(fbn2g), 3); + auto var = AnfAlgo::GetInputNode(utils::cast(fbn2g), 4); + + MS_EXCEPTION_IF_NULL(fbn2g); + MS_EXCEPTION_IF_NULL(dy_); + MS_EXCEPTION_IF_NULL(scale); + MS_EXCEPTION_IF_NULL(x_); + MS_EXCEPTION_IF_NULL(mean); + MS_EXCEPTION_IF_NULL(var); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(utils::cast(node), utils::cast(tuple)); + std::vector outputs_type; + std::vector> outputs_shape; + auto output_num = AnfAlgo::GetOutputTensorNum(fbn2g); + for (size_t i = 0; i < output_num; i++) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2g, i)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(fbn2g, i)); + } + outputs_type[0] = AnfAlgo::GetPrevNodeOutputInferDataType(fbn2g, 0); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2g.get()); + + outputs_type.clear(); + outputs_shape.clear(); + outputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(fbn2g, 0)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0)); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get()); + + return tuple; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..fcb56be7123ba58acf8cd15349b13a4ebeb566ab --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST2_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST2_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ReplaceBNGradCast2Fusion : public PatternProcessPass { + public: + explicit ReplaceBNGradCast2Fusion(bool multigraph = true) : PatternProcessPass("replace_grad_cast2", multigraph) { + dy_ = std::make_shared(); + x_ = std::make_shared(); + scale_ = std::make_shared(); + mean_ = std::make_shared(); + var_ = std::make_shared(); + dx_ = std::make_shared(); + bn_scale_ = std::make_shared(); + bn_bias_ = std::make_shared(); + index_ = std::make_shared(); + } + ~ReplaceBNGradCast2Fusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr dy_; + VarPtr x_; + VarPtr scale_; + VarPtr mean_; + VarPtr var_; + VarPtr dx_; + VarPtr bn_scale_; + VarPtr bn_bias_; + VarPtr index_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST2_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..9dba16bf860a57f7cb89513b38a0f16b94784bd2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef ReplaceBNGradCastFusion::DefinePattern() const { + VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_}); + VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_}); + VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_}); + VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget}); + return out_cast; +} + +const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + + auto tuple = AnfAlgo::GetInputNode(utils::cast(node), 0); + auto index_node = AnfAlgo::GetInputNode(utils::cast(tuple), 1); + MS_EXCEPTION_IF_NULL(index_node); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + if (item_idx != 0) { + return nullptr; + } + auto fbn2g = AnfAlgo::GetInputNode(utils::cast(tuple), 0); + + auto dy_after = AnfAlgo::GetInputNode(utils::cast(fbn2g), 0); + auto dy_before = AnfAlgo::GetInputNode(utils::cast(dy_after), 0); + auto x_ = AnfAlgo::GetInputNode(utils::cast(fbn2g), 1); + + auto scale = AnfAlgo::GetInputNode(utils::cast(fbn2g), 2); + auto mean = AnfAlgo::GetInputNode(utils::cast(fbn2g), 3); + auto var = AnfAlgo::GetInputNode(utils::cast(fbn2g), 4); + + MS_EXCEPTION_IF_NULL(fbn2g); + MS_EXCEPTION_IF_NULL(dy_after); + MS_EXCEPTION_IF_NULL(dy_before); + MS_EXCEPTION_IF_NULL(scale); + MS_EXCEPTION_IF_NULL(x_); + MS_EXCEPTION_IF_NULL(mean); + MS_EXCEPTION_IF_NULL(var); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(utils::cast(dy_after), utils::cast(dy_before)); + manager->Replace(utils::cast(node), utils::cast(tuple)); + std::vector outputs_type; + std::vector> outputs_shape; + auto output_num = AnfAlgo::GetOutputTensorNum(fbn2g); + for (size_t i = 0; i < output_num; i++) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2g, i)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(fbn2g, i)); + } + outputs_type[0] = kNumberTypeFloat16; + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2g.get()); + outputs_type.clear(); + outputs_shape.clear(); + outputs_type.push_back(kNumberTypeFloat16); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0)); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get()); + return tuple; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..b937aa25bf65089673fb1633c32a11922438dbed --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ReplaceBNGradCastFusion : public PatternProcessPass { + public: + explicit ReplaceBNGradCastFusion(bool multigraph = true) : PatternProcessPass("replace_bn_grad_cast", multigraph) { + dy_ = std::make_shared(); + x_ = std::make_shared(); + scale_ = std::make_shared(); + mean_ = std::make_shared(); + var_ = std::make_shared(); + dx_ = std::make_shared(); + bn_scale_ = std::make_shared(); + bn_bias_ = std::make_shared(); + index_ = std::make_shared(); + } + ~ReplaceBNGradCastFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr dy_; + VarPtr x_; + VarPtr scale_; + VarPtr mean_; + VarPtr var_; + VarPtr dx_; + VarPtr bn_scale_; + VarPtr bn_bias_; + VarPtr index_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_