提交 c51d90d8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1767 Move LayerNormGrad split pass ahead of kernel select

Merge pull request !1767 from huanghui/LayerNormGrad-split-pass
...@@ -145,7 +145,6 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g ...@@ -145,7 +145,6 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>(); auto optimizer = std::make_shared<GraphOptimizer>();
auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm"); auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm");
data_layout_pm->AddPass(std::make_shared<LayerNormGradSplit>());
data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>()); data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>());
data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
...@@ -182,7 +181,6 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) ...@@ -182,7 +181,6 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>(); auto optimizer = std::make_shared<GraphOptimizer>();
auto data_layout_pm = std::make_shared<PassManager>("transop_pm"); auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
data_layout_pm->AddPass(std::make_shared<LayerNormGradSplit>());
data_layout_pm->AddPass(std::make_shared<InsertTransOp>()); data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
...@@ -238,6 +236,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap ...@@ -238,6 +236,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
} else { } else {
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>()); ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>()); ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>()); ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>()); ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
...@@ -282,6 +281,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne ...@@ -282,6 +281,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
auto optimizer = std::make_shared<GraphOptimizer>(); auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
......
...@@ -32,7 +32,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop( ...@@ -32,7 +32,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const { std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(layer_norm_grad); MS_EXCEPTION_IF_NULL(layer_norm_grad);
MS_EXCEPTION_IF_NULL(kernel_select_);
auto prim = std::make_shared<Primitive>(kLayerNormXBackpropOpName); auto prim = std::make_shared<Primitive>(kLayerNormXBackpropOpName);
std::vector<AnfNodePtr> layer_norm_x_backprop_inputs = {NewValueNode(prim)}; std::vector<AnfNodePtr> layer_norm_x_backprop_inputs = {NewValueNode(prim)};
for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) { for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) {
...@@ -46,7 +45,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop( ...@@ -46,7 +45,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 0)}; auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_x_backprop.get()); AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_x_backprop.get());
kernel_select_->SelectKernel(layer_norm_x_backprop);
(*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop); (*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop);
} }
...@@ -55,7 +53,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop( ...@@ -55,7 +53,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
std::vector<AnfNodePtr> *layer_norm_beta_gamma_backprop_outputs) const { std::vector<AnfNodePtr> *layer_norm_beta_gamma_backprop_outputs) const {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(layer_norm_grad); MS_EXCEPTION_IF_NULL(layer_norm_grad);
MS_EXCEPTION_IF_NULL(kernel_select_);
auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropOpName); auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropOpName);
std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)}; std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)};
for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) { for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) {
...@@ -73,10 +70,9 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop( ...@@ -73,10 +70,9 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_beta_gamma_backprop.get()); AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_beta_gamma_backprop.get());
// get device shape of LayerNormGrad's 5th Input, and convert it to attr // get device shape of LayerNormGrad's 5th Input, and convert it to attr
std::vector<size_t> shape_gamma = AnfAlgo::GetInputDeviceShape(layer_norm_grad, 4); std::vector<size_t> shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4);
AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Int(shape_gamma)), layer_norm_beta_gamma_backprop); AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Int(shape_gamma)), layer_norm_beta_gamma_backprop);
kernel_select_->SelectKernel(layer_norm_beta_gamma_backprop);
CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum, CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum,
layer_norm_beta_gamma_backprop_outputs); layer_norm_beta_gamma_backprop_outputs);
} }
......
...@@ -26,8 +26,7 @@ namespace mindspore { ...@@ -26,8 +26,7 @@ namespace mindspore {
namespace opt { namespace opt {
class LayerNormGradSplit : public PatternProcessPass { class LayerNormGradSplit : public PatternProcessPass {
public: public:
explicit LayerNormGradSplit(bool multigraph = true) explicit LayerNormGradSplit(bool multigraph = true) : PatternProcessPass("layer_norm_grad_split", multigraph) {}
: PatternProcessPass("layer_norm_grad_split", multigraph), kernel_select_(std::make_shared<KernelSelect>()) {}
~LayerNormGradSplit() override = default; ~LayerNormGradSplit() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
...@@ -37,7 +36,6 @@ class LayerNormGradSplit : public PatternProcessPass { ...@@ -37,7 +36,6 @@ class LayerNormGradSplit : public PatternProcessPass {
std::vector<AnfNodePtr> *layer_norm_grad_outputs) const; std::vector<AnfNodePtr> *layer_norm_grad_outputs) const;
void CreateOutputsOfLayerNormBetaGammaBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, void CreateOutputsOfLayerNormBetaGammaBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
std::vector<AnfNodePtr> *layer_norm_beta_gamma_outputs) const; std::vector<AnfNodePtr> *layer_norm_beta_gamma_outputs) const;
KernelSelectPtr kernel_select_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
...@@ -39,36 +39,6 @@ class TestHWLayerNormGradSplit : public BackendCommon { ...@@ -39,36 +39,6 @@ class TestHWLayerNormGradSplit : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_; UT::PyFuncGraphFetcher get_py_fun_;
}; };
class MockLayerNormGradSplitKernelSelect : public KernelSelect {
public:
MockLayerNormGradSplitKernelSelect() = default;
~MockLayerNormGradSplitKernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override {
auto name = AnfAlgo::GetCNodeName(cnode);
if (name == kLayerNormXBackpropOpName) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat(
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder.SetInputsDeviceType(
{kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kNumberTypeFloat16});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
return;
}
if (name == kLayerNormBetaGammaBackpropOpName) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
return;
}
}
}; // namespace opt
TEST_F(TestHWLayerNormGradSplit, test_layer_norm_grad_split) { TEST_F(TestHWLayerNormGradSplit, test_layer_norm_grad_split) {
get_py_fun_.SetDoResolve(true); get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_layer_norm_grad_split", "before"); FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_layer_norm_grad_split", "before");
...@@ -81,49 +51,9 @@ TEST_F(TestHWLayerNormGradSplit, test_layer_norm_grad_split) { ...@@ -81,49 +51,9 @@ TEST_F(TestHWLayerNormGradSplit, test_layer_norm_grad_split) {
auto kernel_graph = GetKernelGraph(g, args_spec_list); auto kernel_graph = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kernel_graph, nullptr); EXPECT_NE(kernel_graph, nullptr);
// get LayerNormGrad
CNodePtr ret = kernel_graph->get_return();
EXPECT_NE(ret, nullptr);
EXPECT_NE(ret->input(1), nullptr);
EXPECT_TRUE(ret->input(1)->isa<CNode>());
auto make_tuple1 = ret->input(1)->cast<CNodePtr>();
EXPECT_NE(make_tuple1->input(1), nullptr);
EXPECT_TRUE(make_tuple1->input(1)->isa<CNode>());
auto make_tuple2 = make_tuple1->input(1)->cast<CNodePtr>();
EXPECT_NE(make_tuple2->input(1), nullptr);
EXPECT_TRUE(make_tuple2->input(1)->isa<CNode>());
auto tuple_getitem = make_tuple2->input(1)->cast<CNodePtr>();
EXPECT_NE(tuple_getitem->input(1), nullptr);
EXPECT_TRUE(tuple_getitem->input(1)->isa<CNode>());
auto layer_norm_grad = tuple_getitem->input(1)->cast<CNodePtr>();
// set kernel for LayerNormGrad
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
builder1.SetInputsFormat(
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder1.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder1.SetInputsDeviceType(
{kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
builder1.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
builder1.SetKernelType(TBE_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), layer_norm_grad.get());
// get param5
EXPECT_NE(layer_norm_grad->input(5), nullptr);
auto param = layer_norm_grad->input(5);
// set kernel for param5
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder2;
builder2.SetOutputsFormat({kOpFormat_NC1HWC0});
builder2.SetOutputsDeviceType({kNumberTypeFloat16});
AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), param.get());
// do layer_norm_grad_split pass
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>(); auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::LayerNormGradSplit>(); auto pass = std::make_shared<opt::LayerNormGradSplit>();
auto kernel_select = std::make_shared<MockLayerNormGradSplitKernelSelect>();
pass->kernel_select_ = kernel_select;
pm->AddPass(pass); pm->AddPass(pass);
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);
auto new_graph = optimizer->Optimize(kernel_graph); auto new_graph = optimizer->Optimize(kernel_graph);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册