提交 2b061c84 编写于 作者: Y yujianfeng

Add batch_norm_grad infer fisson

上级 597933b0
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
#include <vector>
#include "pre_activate/common/helper.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kBatchNormGradInferOutputNum = 3;
bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (manager->node_users().find(node) == manager->node_users().end()) {
MS_LOG(DEBUG) << "The node " << node->DebugString() << " should have some outputs";
return false;
}
for (const auto &node_index : manager->node_users()[node]) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
auto tuple_getiterm_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(index_node);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int index = GetValue<int>(value_node->value());
if (index == kBatchNormGradInferOutputNum || index == kBatchNormGradInferOutputNum + 1) {
MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change";
return false;
}
}
return true;
}
} // namespace
AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn_grad);
MS_EXCEPTION_IF_NULL(equiv);
// Set inputs
auto iter_input0 = (*equiv).find(input0_var_);
if (iter_input0 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched.";
}
auto iter_input2 = (*equiv).find(input2_var_);
if (iter_input2 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input2 var after matched.";
}
auto iter_input4 = (*equiv).find(input4_var_);
if (iter_input4 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched.";
}
std::vector<AnfNodePtr> bn_infer_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNInferGradOpName)), utils::cast<AnfNodePtr>(iter_input0->second),
utils::cast<AnfNodePtr>(iter_input2->second), utils::cast<AnfNodePtr>(iter_input4->second)};
auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs);
MS_EXCEPTION_IF_NULL(bn_infer_grad);
// Set abstract, the output of new node is taking the place of the 0th output of bn_grad.
auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());
MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple);
if (bn_grad_abstract_tuple->elements().empty()) {
MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be empty";
}
bn_infer_grad->set_abstract(bn_grad_abstract_tuple->elements()[0]);
AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_infer_grad);
bn_infer_grad->set_scope(bn_grad->scope());
return bn_infer_grad;
}
AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph,
const AnfNodePtr &bn_grad,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn_grad);
MS_EXCEPTION_IF_NULL(equiv);
// Set inputs
auto iter_input0 = (*equiv).find(input0_var_);
if (iter_input0 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched.";
}
auto iter_input1 = (*equiv).find(input1_var_);
if (iter_input1 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input1 var after matched.";
}
auto iter_input3 = (*equiv).find(input3_var_);
if (iter_input3 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input3 var after matched.";
}
auto iter_input4 = (*equiv).find(input4_var_);
if (iter_input4 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched.";
}
std::vector<AnfNodePtr> bn_training_update_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)),
utils::cast<AnfNodePtr>(iter_input0->second), utils::cast<AnfNodePtr>(iter_input1->second),
utils::cast<AnfNodePtr>(iter_input3->second), utils::cast<AnfNodePtr>(iter_input4->second)};
auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs);
MS_EXCEPTION_IF_NULL(bn_training_update_grad);
// Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad.
auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());
MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple);
if (bn_grad_abstract_tuple->elements().size() < kBatchNormGradInferOutputNum) {
MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3";
}
std::vector<AbstractBasePtr> abstract_list{bn_grad_abstract_tuple->elements()[1],
bn_grad_abstract_tuple->elements()[2]};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
bn_training_update_grad->set_abstract(abstract_tuple);
AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad);
bn_training_update_grad->set_scope(bn_grad->scope());
return bn_training_update_grad;
}
const BaseRef BatchNormGradInferFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimBatchNormGrad, input0_var_, input1_var_, input2_var_, input3_var_, input4_var_, Xs});
}
const AnfNodePtr BatchNormGradInferFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, node->cast<CNodePtr>())) {
MS_LOG(DEBUG) << "The BatchNormGrad " << node->DebugString() << " has no is_training attr, should not be changed";
return nullptr;
}
if (AnfAlgo::GetNodeAttr<bool>(node, kAttrIsTraining)) {
MS_LOG(DEBUG) << "The is_training attr value of " << node->DebugString() << " is true, no need change";
return nullptr;
}
if (!CheckOutputsIndex(func_graph, node)) {
MS_LOG(DEBUG) << "The output 3 or 4 of BatchNormGrad is not null, no need change";
return nullptr;
}
AnfNodePtr bn_infer_grad = CreateBNInferGrad(func_graph, node, equiv);
AnfNodePtr bn_training_update_grad = CreateBNTrainingUpdateGrad(func_graph, node, equiv);
std::vector<AnfNodePtr> bn_training_update_grad_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_grad, kBNTrainingUpdateGradOutputNum,
&bn_training_update_grad_outputs);
if (bn_training_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
MS_LOG(EXCEPTION) << "The output size of " << bn_training_update_grad << " should be "
<< kBNTrainingUpdateGradOutputNum << ", but it is " << bn_training_update_grad_outputs.size();
}
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_infer_grad,
bn_training_update_grad_outputs[0], bn_training_update_grad_outputs[1]};
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple);
return make_tuple;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
#include <memory>
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class BatchNormGradInferFission : public PatternProcessPass {
public:
explicit BatchNormGradInferFission(bool multigraph = true)
: PatternProcessPass("batch_norm_grad_infer_fission", multigraph),
input0_var_(std::make_shared<Var>()),
input1_var_(std::make_shared<Var>()),
input2_var_(std::make_shared<Var>()),
input3_var_(std::make_shared<Var>()),
input4_var_(std::make_shared<Var>()) {}
~BatchNormGradInferFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, const EquivPtr &equiv) const;
AnfNodePtr CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad,
const EquivPtr &equiv) const;
VarPtr input0_var_;
VarPtr input1_var_;
VarPtr input2_var_;
VarPtr input3_var_;
VarPtr input4_var_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
......@@ -139,6 +139,7 @@ constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2D
constexpr auto kLabelSetOpName = "LabelSet";
constexpr auto kLabelSwitchOpName = "LabelSwitch";
constexpr auto kLabelGotoOpName = "LabelGoto";
constexpr auto kBNInferGradOpName = "BNInferGrad";
// attr key name
constexpr auto kAttrInputNames = "input_names";
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
namespace mindspore {
namespace opt {
class TestHWBatchNormGradInferFission : public BackendCommon {
public:
TestHWBatchNormGradInferFission()
: get_py_fun_("gtest_input.pre_activate.batch_norm_grad_infer_fission_test", true) {}
~TestHWBatchNormGradInferFission() override = default;
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_fission) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shp_x{32, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 5; ++i) {
args_spec_list.push_back(x_abstract);
}
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission1) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before_is_training");
EXPECT_NE(g, nullptr);
std::vector<int> shp_x{32, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 5; ++i) {
args_spec_list.push_back(x_abstract);
}
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
}
TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission2) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before_output3_not_null");
EXPECT_NE(g, nullptr);
std::vector<int> shp_x{32, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 5; ++i) {
args_spec_list.push_back(x_abstract);
}
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
}
} // namespace opt
} // namespace mindspore
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import Primitive
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
BatchNormGradTraining = G.BatchNormGrad(is_training=True)
BatchNormGradInfer = G.BatchNormGrad(is_training=False)
BNInferGrad = Primitive('BNInferGrad')
BNTrainingUpdateGrad = Primitive('BNTrainingUpdateGrad')
class FnDict:
def __init__(self):
self.fnDict = {}
def __call__(self, fn):
self.fnDict[fn.__name__] = fn
def __getitem__(self, name):
return self.fnDict[name]
def test_batch_norm_grad_infer_fission(tag):
fns = FnDict()
@fns
def before(input0, input1, input2, input3, input4):
batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4)
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
output = tuple_getitem(outputs, 0)
return output
@fns
def before_is_training(input0, input1, input2, input3, input4):
batch_norm = BatchNormGradTraining(input0, input1, input2, input3, input4)
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
output = tuple_getitem(outputs, 0)
return output
@fns
def before_output3_not_null(input0, input1, input2, input3, input4):
batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4)
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 3))
output = tuple_getitem(outputs, 0)
return output
@fns
def after(input0, input1, input2, input3, input4):
bn_infer_grad = BNInferGrad(input0, input2, input4)
bn_training_update_grad = BNTrainingUpdateGrad(input0, input1, input3, input4)
outputs = make_tuple(bn_infer_grad, tuple_getitem(bn_training_update_grad, 0),
tuple_getitem(bn_training_update_grad, 1))
new_outputs = make_tuple(tuple_getitem(outputs, 0), tuple_getitem(outputs, 1), tuple_getitem(outputs, 2))
output = tuple_getitem(new_outputs, 0)
return make_tuple(output)
return fns[tag]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册