提交 cf87218f 编写于 作者: H huanghui

place layernormgrad split pass before kernel select

上级 e7b7abc5
......@@ -145,7 +145,6 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm");
data_layout_pm->AddPass(std::make_shared<LayerNormGradSplit>());
data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>());
data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
......@@ -182,7 +181,6 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
data_layout_pm->AddPass(std::make_shared<LayerNormGradSplit>());
data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
......@@ -238,6 +236,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
} else {
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
}
......@@ -281,6 +280,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
......
......@@ -32,7 +32,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(layer_norm_grad);
MS_EXCEPTION_IF_NULL(kernel_select_);
auto prim = std::make_shared<Primitive>(kLayerNormXBackpropOpName);
std::vector<AnfNodePtr> layer_norm_x_backprop_inputs = {NewValueNode(prim)};
for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) {
......@@ -46,7 +45,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_x_backprop.get());
kernel_select_->SelectKernel(layer_norm_x_backprop);
(*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop);
}
......@@ -55,7 +53,6 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
std::vector<AnfNodePtr> *layer_norm_beta_gamma_backprop_outputs) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(layer_norm_grad);
MS_EXCEPTION_IF_NULL(kernel_select_);
auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropOpName);
std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)};
for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) {
......@@ -73,10 +70,9 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_beta_gamma_backprop.get());
// get device shape of LayerNormGrad's 5th Input, and convert it to attr
std::vector<size_t> shape_gamma = AnfAlgo::GetInputDeviceShape(layer_norm_grad, 4);
std::vector<size_t> shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4);
AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Int(shape_gamma)), layer_norm_beta_gamma_backprop);
kernel_select_->SelectKernel(layer_norm_beta_gamma_backprop);
CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum,
layer_norm_beta_gamma_backprop_outputs);
}
......
......@@ -26,8 +26,7 @@ namespace mindspore {
namespace opt {
class LayerNormGradSplit : public PatternProcessPass {
public:
explicit LayerNormGradSplit(bool multigraph = true)
: PatternProcessPass("layer_norm_grad_split", multigraph), kernel_select_(std::make_shared<KernelSelect>()) {}
explicit LayerNormGradSplit(bool multigraph = true) : PatternProcessPass("layer_norm_grad_split", multigraph) {}
~LayerNormGradSplit() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
......@@ -37,7 +36,6 @@ class LayerNormGradSplit : public PatternProcessPass {
std::vector<AnfNodePtr> *layer_norm_grad_outputs) const;
void CreateOutputsOfLayerNormBetaGammaBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
std::vector<AnfNodePtr> *layer_norm_beta_gamma_outputs) const;
KernelSelectPtr kernel_select_;
};
} // namespace opt
} // namespace mindspore
......
......@@ -39,36 +39,6 @@ class TestHWLayerNormGradSplit : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_;
};
class MockLayerNormGradSplitKernelSelect : public KernelSelect {
public:
MockLayerNormGradSplitKernelSelect() = default;
~MockLayerNormGradSplitKernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override {
auto name = AnfAlgo::GetCNodeName(cnode);
if (name == kLayerNormXBackpropOpName) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat(
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder.SetInputsDeviceType(
{kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kNumberTypeFloat16});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
return;
}
if (name == kLayerNormBetaGammaBackpropOpName) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
return;
}
}
}; // namespace opt
TEST_F(TestHWLayerNormGradSplit, test_layer_norm_grad_split) {
get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_layer_norm_grad_split", "before");
......@@ -81,49 +51,9 @@ TEST_F(TestHWLayerNormGradSplit, test_layer_norm_grad_split) {
auto kernel_graph = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kernel_graph, nullptr);
// get LayerNormGrad
CNodePtr ret = kernel_graph->get_return();
EXPECT_NE(ret, nullptr);
EXPECT_NE(ret->input(1), nullptr);
EXPECT_TRUE(ret->input(1)->isa<CNode>());
auto make_tuple1 = ret->input(1)->cast<CNodePtr>();
EXPECT_NE(make_tuple1->input(1), nullptr);
EXPECT_TRUE(make_tuple1->input(1)->isa<CNode>());
auto make_tuple2 = make_tuple1->input(1)->cast<CNodePtr>();
EXPECT_NE(make_tuple2->input(1), nullptr);
EXPECT_TRUE(make_tuple2->input(1)->isa<CNode>());
auto tuple_getitem = make_tuple2->input(1)->cast<CNodePtr>();
EXPECT_NE(tuple_getitem->input(1), nullptr);
EXPECT_TRUE(tuple_getitem->input(1)->isa<CNode>());
auto layer_norm_grad = tuple_getitem->input(1)->cast<CNodePtr>();
// set kernel for LayerNormGrad
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
builder1.SetInputsFormat(
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder1.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder1.SetInputsDeviceType(
{kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
builder1.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16});
builder1.SetKernelType(TBE_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), layer_norm_grad.get());
// get param5
EXPECT_NE(layer_norm_grad->input(5), nullptr);
auto param = layer_norm_grad->input(5);
// set kernel for param5
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder2;
builder2.SetOutputsFormat({kOpFormat_NC1HWC0});
builder2.SetOutputsDeviceType({kNumberTypeFloat16});
AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), param.get());
// do layer_norm_grad_split pass
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::LayerNormGradSplit>();
auto kernel_select = std::make_shared<MockLayerNormGradSplitKernelSelect>();
pass->kernel_select_ = kernel_select;
pm->AddPass(pass);
optimizer->AddPassManager(pm);
auto new_graph = optimizer->Optimize(kernel_graph);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册