diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc b/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
index 6734c74222ff82b2168537c57ad73cbc3a0075f0..69edc3d87f97d6762079a37c920f5ece57903cfa 100644
--- a/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include
 #include
 
 #include "paddle/fluid/framework/framework.pb.h"
@@ -22,6 +21,7 @@
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/string/pretty_log.h"
+#include "paddle/fluid/string/printf.h"
 
 namespace paddle {
 namespace framework {
@@ -30,34 +30,57 @@ namespace ir {
 // cpplint complaints (wrong!) for not included header in below line.
 using string::PrettyLogDetail;  // NOLINT
 
+#define CHECK_TRUE(expr, err_msg) \
+  do {                            \
+    int e_ = (expr);              \
+    if (!e_) {                    \
+      VLOG(4) << err_msg;         \
+      return;                     \
+    }                             \
+  } while (0)
+
+#define EXPECT_TRUE(expr, err_msg) \
+  do {                             \
+    int e_ = (expr);               \
+    if (!e_) {                     \
+      VLOG(4) << err_msg;          \
+      return false;                \
+    }                              \
+  } while (0)
+
 namespace {
-void validateReduceOpAttrs(const Node* node, const std::string& name) {
+
+bool validateReduceOpAttrs(const Node* node, const std::string& name) {
   const auto* op = node->Op();
   if (op->HasAttr("dim")) {
     auto dims = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dim"));
-    PADDLE_ENFORCE_EQ(dims.size(), 1, platform::errors::PreconditionNotMet(
-                                          "The LayerNorm fusion ", name,
-                                          " reduction must happen only over "
-                                          "single dimension."));
-    PADDLE_ENFORCE_EQ(dims.front(), -1, platform::errors::PreconditionNotMet(
-                                            "The LayerNorm fusion ", name,
-                                            " reduction must happen over last "
-                                            "dimension."));
+    EXPECT_TRUE(
+        dims.size() == 1,
+        ::paddle::string::Sprintf(
+            "The LayerNorm fusion %s reduction must happen only over single "
+            "dimension.",
+            name));
+    EXPECT_TRUE(dims.front() == -1,
+                ::paddle::string::Sprintf("The LayerNorm fusion %s reduction "
+                                          "must happen over last dimension.",
+                                          name));
   }
   if (op->HasAttr("reduce_all")) {
-    PADDLE_ENFORCE(!BOOST_GET_CONST(bool, op->GetAttr("reduce_all")),
-                   platform::errors::PreconditionNotMet(
-                       "The LayerNorm fusion ", name,
-                       " reduction must have "
-                       "\'reduce_all\' attribute set to false."));
+    EXPECT_TRUE(
+        !BOOST_GET_CONST(bool, op->GetAttr("reduce_all")),
+        ::paddle::string::Sprintf(
+            "The LayerNorm fusion %s "
+            "reduction must have \'reduce_all\' attribute set to false.",
+            name));
   }
   if (op->HasAttr("keep_dim")) {
-    PADDLE_ENFORCE(BOOST_GET_CONST(bool, op->GetAttr("keep_dim")),
-                   platform::errors::PreconditionNotMet(
-                       "The LayerNorm fusion ", name,
-                       " reduction must have "
-                       "\'keep_dim\' attribute set to true."));
+    EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("keep_dim")),
+                ::paddle::string::Sprintf(
+                    "The LayerNorm fusion %s"
+                    " reduction must have \'keep_dim\' attribute set to true.",
+                    name));
   }
+  return true;
 }
 
 void setIntermediateOut(OpDesc* desc, const std::string& out_name,
@@ -129,48 +152,46 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
   auto* eps_tensor = scope->FindVar(eps->Name())->GetMutable<LoDTensor>();
 
   // ------------------ subgraph node's validation ---------------------------
-  PADDLE_ENFORCE_EQ(
-      eps_tensor->numel(), 1,
-      platform::errors::InvalidArgument(
-          "The LayerNorm divisor "
-          "epsilon value must be one-element tensor, but has %s "
-          "elements.",
+  CHECK_TRUE(
+      eps_tensor->numel() == 1,
+      ::paddle::string::Sprintf(
+          "The LayerNorm divisor epsilon value must be one-element tensor, "
+          "but has %s elements.",
           eps_tensor->numel()));
-  PADDLE_ENFORCE_EQ(eps_tensor->type(), proto::VarType::FP32,
-                    platform::errors::InvalidArgument(
-                        "The LayerNorm divisor "
-                        "epsilon value must be of FP32 data type, but is %s.",
-                        eps_tensor->type()));
+  CHECK_TRUE(
+      eps_tensor->type() == proto::VarType::FP32,
+      ::paddle::string::Sprintf("The LayerNorm divisor epsilon value "
+                                "must be of FP32 data type, but is %s.",
+                                eps_tensor->type()));
 
   const auto& gamma_shape = gamma->Var()->GetShape();
   const auto& beta_shape = beta->Var()->GetShape();
   const auto& x_shape = x->Var()->GetShape();
   int64_t x_last_dim = x_shape.back();
 
-  PADDLE_ENFORCE_EQ(gamma_shape.size(), 1,
-                    platform::errors::InvalidArgument(
-                        "The LayerNorm gamma "
-                        "(scale) tensor shape must be one-dimensional, "
-                        "but is %s.",
-                        gamma_shape.size()));
-  PADDLE_ENFORCE_EQ(beta_shape.size(), 1,
-                    platform::errors::InvalidArgument(
-                        "The LayerNorm beta "
-                        "(shift) tensor shape must be one-dimensional, "
-                        "but is %s.",
-                        beta_shape.size()));
-  PADDLE_ENFORCE_EQ(beta_shape, gamma_shape,
-                    platform::errors::InvalidArgument(
-                        "The LayerNorm beta "
-                        "and gamma tensors shapes' must be equal."));
-  PADDLE_ENFORCE_EQ(gamma_shape.front(), x_last_dim,
-                    platform::errors::InvalidArgument(
-                        "The LayerNorm beta "
-                        "and gamma tensors shapes' must be equal to the last "
-                        "input's dimension size."));
-
-  validateReduceOpAttrs(x_mean, "input mean");
-  validateReduceOpAttrs(std_dev, "std_dev mean");
+  CHECK_TRUE(
+      gamma_shape.size() == 1,
+      ::paddle::string::Sprintf("The LayerNorm gamma (scale) tensor "
+                                "shape must be one-dimensional, but is %s.",
+                                gamma_shape.size()));
+  CHECK_TRUE(
+      beta_shape.size() == 1,
+      ::paddle::string::Sprintf("The LayerNorm beta (shift) tensor "
+                                "shape must be one-dimensional, but is %s.",
+                                beta_shape.size()));
+  CHECK_TRUE(beta_shape == gamma_shape,
+             ::paddle::string::Sprintf("The LayerNorm beta and gamma tensors' "
+                                       "shapes must be equal."));
+  CHECK_TRUE(
+      gamma_shape.front() == x_last_dim,
+      ::paddle::string::Sprintf(
+          "The LayerNorm beta and gamma tensors' shapes "
+          "must be equal to the last input's dimension size."));
+
+  CHECK_TRUE(validateReduceOpAttrs(x_mean, "input mean"),
+             "Validation of input mean node failed.");
+  CHECK_TRUE(validateReduceOpAttrs(std_dev, "std_dev mean"),
+             "Validation of standard deviation node failed.");
 
   // ------------------ op creation and placement ---------------------------
 
@@ -213,6 +234,9 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
 }  // namespace framework
 }  // namespace paddle
 
+#undef CHECK_TRUE
+#undef EXPECT_TRUE
+
 REGISTER_PASS(layer_norm_fuse_pass, paddle::framework::ir::LayerNormFusePass);
 REGISTER_PASS_CAPABILITY(layer_norm_fuse_pass)
     .AddCombination(
diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
index c79c9dda8e54f66f3840f0f1f715d04690cd3f5d..bc083e0d0f964e05837b0fd0ddd92c9302441eaf 100644
--- a/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
@@ -13,7 +13,10 @@
 // limitations under the License.
 
 #include <gtest/gtest.h>
+#include <memory>
+#include <vector>
 
+#include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
 #include "paddle/fluid/framework/ir/pass_test_util.h"
@@ -31,100 +34,153 @@ namespace ir {
 
 namespace {
 
-ProgramDesc BuildGraphProgram() {
-  auto prog = test::BuildProgramDesc(
-      {"x", "x_mean_out", "x_sub_mean_out", "x_sub_mean_sqr_out", "std_dev_out",
-       "std_dev_eps_out", "std_dev_eps_sqrt_out", "division_out", "scale_out",
-       "shift_out"},
-      {"sqr_pow", "eps", "gamma", "beta"});
-
-  const auto& block_desc = prog.Block(0);
-  auto* x_var_desc = block_desc.FindVar("x");
-  x_var_desc->SetDataType(proto::VarType::FP32);
-  x_var_desc->SetShape({3, 32, 48});
-
-  auto* eps_var_desc = block_desc.FindVar("eps");
-  eps_var_desc->SetDataType(proto::VarType::FP32);
-  eps_var_desc->SetShape({1});
-
-  auto* gamma_var_desc = block_desc.FindVar("gamma");
-  gamma_var_desc->SetDataType(proto::VarType::FP32);
-  gamma_var_desc->SetShape({48});
-
-  auto* beta_var_desc = block_desc.FindVar("beta");
-  beta_var_desc->SetDataType(proto::VarType::FP32);
-  beta_var_desc->SetShape({48});
-
-  auto* x_mean = test::CreateOp(&prog, "reduce_mean", {{"X", "x"}},
-                                {{"Out", "x_mean_out"}}, false);
-  x_mean->SetAttr("dim", std::vector<int>{-1});
-  x_mean->SetAttr("keep_dim", true);
-  x_mean->SetAttr("reduce_all", false);
-
-  test::CreateOp(&prog, "elementwise_sub", {{"X", "x"}, {"Y", "x_mean_out"}},
-                 {{"Out", "x_sub_mean_out"}}, false);
-  test::CreateOp(&prog, "elementwise_pow",
-                 {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
-                 {{"Out", "x_sub_mean_sqr_out"}}, false);
-  auto* std_dev =
-      test::CreateOp(&prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
-                     {{"Out", "std_dev_out"}}, false);
-  std_dev->SetAttr("dim", std::vector<int>{-1});
-  std_dev->SetAttr("keep_dim", true);
-  std_dev->SetAttr("reduce_all", false);
-
-  test::CreateOp(&prog, "elementwise_add", {{"X", "std_dev_out"}, {"Y", "eps"}},
-                 {{"Out", "std_dev_eps_out"}}, false);
-  test::CreateOp(&prog, "sqrt", {{"X", "std_dev_eps_out"}},
-                 {{"Out", "std_dev_eps_sqrt_out"}}, false);
-  test::CreateOp(&prog, "elementwise_div",
-                 {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
-                 {{"Out", "division_out"}}, false);
-  test::CreateOp(&prog, "elementwise_mul",
-                 {{"X", "division_out"}, {"Y", "gamma"}},
-                 {{"Out", "scale_out"}}, false);
-  test::CreateOp(&prog, "elementwise_add", {{"X", "scale_out"}, {"Y", "beta"}},
-                 {{"Out", "shift_out"}}, false);
-  return prog;
-}
-
-bool CheckFusedSubgraphOpsCount(const Graph& graph) {
-  return test::AssertOpsCount(graph, {{"reduce_mean", 0},
-                                      {"elementwise_sub", 0},
-                                      {"elementwise_pow", 0},
-                                      {"elementwise_add", 0},
-                                      {"sqrt", 0},
-                                      {"elementwise_div", 0},
-                                      {"elementwise_mul", 0},
-                                      {"layer_norm", 1}});
-}
+class LayerNormFuseTest {
+ public:
+  LayerNormFuseTest()
+      : m_prog{test::BuildProgramDesc(
+            {"x", "x_mean_out", "x_sub_mean_out", "x_sub_mean_sqr_out",
+             "std_dev_out", "std_dev_eps_out", "std_dev_eps_sqrt_out",
+             "division_out", "scale_out", "shift_out"},
+            {"sqr_pow", "eps", "gamma", "beta"})},
+        m_place{},
+        m_exe{m_place},
+        m_block_desc{m_prog.Block(0)} {
+    auto* x_var_desc = m_block_desc.FindVar("x");
+    x_var_desc->SetDataType(proto::VarType::FP32);
+    x_var_desc->SetShape({3, 32, 48});
+
+    auto* eps_var_desc = m_block_desc.FindVar("eps");
+    eps_var_desc->SetDataType(proto::VarType::FP32);
+    eps_var_desc->SetShape({1});
+
+    auto* gamma_var_desc = m_block_desc.FindVar("gamma");
+    gamma_var_desc->SetDataType(proto::VarType::FP32);
+    gamma_var_desc->SetShape({48});
+
+    auto* beta_var_desc = m_block_desc.FindVar("beta");
+    beta_var_desc->SetDataType(proto::VarType::FP32);
+    beta_var_desc->SetShape({48});
+
+    auto* x_mean = test::CreateOp(&m_prog, "reduce_mean", {{"X", "x"}},
+                                  {{"Out", "x_mean_out"}}, false);
+    x_mean->SetAttr("dim", std::vector<int>{-1});
+    x_mean->SetAttr("keep_dim", true);
+    x_mean->SetAttr("reduce_all", false);
+
+    test::CreateOp(&m_prog, "elementwise_sub",
+                   {{"X", "x"}, {"Y", "x_mean_out"}},
+                   {{"Out", "x_sub_mean_out"}}, false);
+    test::CreateOp(&m_prog, "elementwise_pow",
+                   {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
+                   {{"Out", "x_sub_mean_sqr_out"}}, false);
+    auto* std_dev =
+        test::CreateOp(&m_prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
+                       {{"Out", "std_dev_out"}}, false);
+    std_dev->SetAttr("dim", std::vector<int>{-1});
+    std_dev->SetAttr("keep_dim", true);
+    std_dev->SetAttr("reduce_all", false);
+
+    test::CreateOp(&m_prog, "elementwise_add",
+                   {{"X", "std_dev_out"}, {"Y", "eps"}},
+                   {{"Out", "std_dev_eps_out"}}, false);
+    test::CreateOp(&m_prog, "sqrt", {{"X", "std_dev_eps_out"}},
+                   {{"Out", "std_dev_eps_sqrt_out"}}, false);
+    test::CreateOp(&m_prog, "elementwise_div",
+                   {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
+                   {{"Out", "division_out"}}, false);
+    test::CreateOp(&m_prog, "elementwise_mul",
+                   {{"X", "division_out"}, {"Y", "gamma"}},
+                   {{"Out", "scale_out"}}, false);
+    test::CreateOp(&m_prog, "elementwise_add",
+                   {{"X", "scale_out"}, {"Y", "beta"}}, {{"Out", "shift_out"}},
+                   false);
+  }
+
+  template <typename Func>
+  LayerNormFuseTest(const Func& func, int removed_nodes = 0,
+                    int added_nodes = 0)
+      : LayerNormFuseTest() {
+    m_removed_nodes = removed_nodes;
+    m_added_nodes = added_nodes;
+    func(m_block_desc);
+  }
+
+  void setupGraph() {
+    auto initFun = [this](const Scope& scope,
+                          const paddle::platform::CPUPlace& place) {
+      this->initEpsTensorValue(scope, place);
+    };
+    setupGraphWithInitFunc(initFun);
+  }
+
+  template <typename Func>
+  void setupGraphWithInitFunc(const Func& func) {
+    m_graph.reset(new Graph(m_prog));
+    // Init scope, as it is used in pass
+    m_exe.CreateVariables(m_prog, 0, true, &m_scope);
+    func(m_scope, m_place);
+    m_graph->SetNotOwned(kParamScopeAttr, &m_scope);
+  }
+
+  void run(bool fusion = false) const {
+    EXPECT_TRUE(test::RunPassAndAssert(m_graph.get(), "layer_norm_fuse_pass",
+                                       "x", "shift_out", m_removed_nodes,
+                                       m_added_nodes));
+    EXPECT_TRUE(CheckSubgraphOpsCount(*m_graph, fusion));
+  }
+
+  const ProgramDesc& getProgramDesc() const { return m_prog; }
+  const Graph* getGraph() const { return m_graph.get(); }
+
+ private:
+  void initEpsTensorValue(const Scope& scope,
+                          const paddle::platform::CPUPlace& place) {
+    float eps_value = 1e-5;
+    test::InitLoDTensorHolder(scope, place, "eps", {1}, &eps_value);
+  }
+
+  bool CheckSubgraphOpsCount(const Graph& graph, bool fusion) const {
+    if (fusion)
+      return test::AssertOpsCount(graph, {{"reduce_mean", 0},
+                                          {"elementwise_sub", 0},
+                                          {"elementwise_pow", 0},
+                                          {"elementwise_add", 0},
+                                          {"sqrt", 0},
+                                          {"elementwise_div", 0},
+                                          {"elementwise_mul", 0},
+                                          {"layer_norm", 1}});
+    else
+      return test::AssertOpsCount(graph, {{"reduce_mean", 2},
+                                          {"elementwise_sub", 1},
+                                          {"elementwise_pow", 1},
+                                          {"elementwise_add", 2},
+                                          {"sqrt", 1},
+                                          {"elementwise_div", 1},
+                                          {"elementwise_mul", 1},
+                                          {"layer_norm", 0}});
+  }
+
+  int m_removed_nodes{19};
+  int m_added_nodes{3};
+  ProgramDesc m_prog;
+  paddle::platform::CPUPlace m_place;
+  NaiveExecutor m_exe;
+  const BlockDesc& m_block_desc;
+  Scope m_scope;
+  std::unique_ptr<Graph> m_graph{nullptr};
+};
 
 }  // namespace
 
 // ------------------------------ Test cases -----------------------------------
 
 TEST(FuseLayerNormPass, TestFuse) {
-  ProgramDesc prog = BuildGraphProgram();
-
-  Graph graph(prog);
-  constexpr int removed_nodes = 19;
-  // LayerNorm + outputs: {Mean, Variance}
-  constexpr int added_nodes = 3;
-
-  auto place = paddle::platform::CPUPlace();
-  NaiveExecutor exe{place};
-  Scope scope;
-  float eps_value = 1e-5f;
-  // Init scope, as it is used in pass
-  exe.CreateVariables(prog, 0, true, &scope);
-  test::InitLoDTensorHolder(&scope, place, "eps", {1}, &eps_value);
-
-  graph.SetNotOwned(kParamScopeAttr, &scope);
-  EXPECT_TRUE(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
-                                     "shift_out", removed_nodes, added_nodes));
-  EXPECT_TRUE(CheckFusedSubgraphOpsCount(graph));
-
-  for (const auto* node : graph.Nodes()) {
+  LayerNormFuseTest lnorm_test;
+  lnorm_test.setupGraph();
+  lnorm_test.run(true);
+
+  // additional attribute checks
+  for (const auto* node : lnorm_test.getGraph()->Nodes()) {
     if (node->IsOp() && node->Op()->Type() == "layer_norm") {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("is_test"));
@@ -136,54 +192,194 @@ TEST(FuseLayerNormPass, TestFuse) {
 }
 
 TEST(FuseLayerNormPass, TestInvalidEpsNumel) {
-  ProgramDesc prog = BuildGraphProgram();
+  const auto editEpsFun = [](const BlockDesc& block_desc) {
+    auto* eps_var_desc = block_desc.FindVar("eps");
+    eps_var_desc->SetDataType(proto::VarType::FP32);
+    eps_var_desc->SetShape({2});
+  };
+  const auto initEpsTensor = [](const Scope& scope,
+                                const paddle::platform::CPUPlace& place) {
+    auto eps_values = std::vector<float>{1e-5f, 1e-5f};
+    test::InitLoDTensorHolder(scope, place, "eps", {2},
+                              eps_values.data());
+  };
+
+  LayerNormFuseTest lnorm_test(editEpsFun);
+  lnorm_test.setupGraphWithInitFunc(initEpsTensor);
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, TestInvalidEpsDataType) {
+  const auto editEpsFun = [](const BlockDesc& block_desc) {
+    auto* eps_var_desc = block_desc.FindVar("eps");
+    eps_var_desc->SetDataType(proto::VarType::FP64);
+    eps_var_desc->SetShape({1});
+  };
+  const auto initEpsTensor = [](const Scope& scope,
+                                const paddle::platform::CPUPlace& place) {
+    double eps_value = 1e-5;
+    test::InitLoDTensorHolder(scope, place, "eps", {1}, &eps_value);
+  };
+
+  LayerNormFuseTest lnorm_test(editEpsFun);
+  lnorm_test.setupGraphWithInitFunc(initEpsTensor);
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, TestInvalidGammaRank) {
+  const auto editGammaFun = [](const BlockDesc& block_desc) {
+    auto* gamma_var_desc = block_desc.FindVar("gamma");
+    gamma_var_desc->SetDataType(proto::VarType::FP32);
+    gamma_var_desc->SetShape({48, 32});
+  };
 
-  auto* eps_var_desc = prog.Block(0).FindVar("eps");
-  eps_var_desc->SetDataType(proto::VarType::FP32);
-  eps_var_desc->SetShape({2});
+  LayerNormFuseTest lnorm_test(editGammaFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, TestInvalidBetaRank) {
+  const auto editBetaFun = [](const BlockDesc& block_desc) {
+    auto* beta_var_desc = block_desc.FindVar("beta");
+    beta_var_desc->SetDataType(proto::VarType::FP32);
+    beta_var_desc->SetShape({48, 32});
+  };
 
-  Graph graph(prog);
-  constexpr int removed_nodes = 19;
-  constexpr int added_nodes = 3;
+  LayerNormFuseTest lnorm_test(editBetaFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
 
-  auto place = paddle::platform::CPUPlace();
-  NaiveExecutor exe{place};
-  Scope scope;
-  auto eps_values = std::vector<float>{1e-5f, 1e-5f};
-  // Init scope, as it is used in pass
-  exe.CreateVariables(prog, 0, true, &scope);
-  test::InitLoDTensorHolder(&scope, place, "eps", {2},
-                            eps_values.data());
+TEST(FuseLayerNormPass, TestUnequalGammaBetaShapes) {
+  const auto editGammaBetaFun = [](const BlockDesc& block_desc) {
+    auto* beta_var_desc = block_desc.FindVar("beta");
+    beta_var_desc->SetDataType(proto::VarType::FP32);
+    beta_var_desc->SetShape({32});
+  };
 
-  graph.SetNotOwned(kParamScopeAttr, &scope);
-  EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
-                                      "shift_out", removed_nodes, added_nodes),
-               paddle::platform::EnforceNotMet);
+  LayerNormFuseTest lnorm_test(editGammaBetaFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
 }
 
-TEST(FuseLayerNormPass, TestInvalidEpsDataType) {
-  ProgramDesc prog = BuildGraphProgram();
-
-  auto* eps_var_desc = prog.Block(0).FindVar("eps");
-  eps_var_desc->SetDataType(proto::VarType::FP64);
-  eps_var_desc->SetShape({1});
-
-  Graph graph(prog);
-  constexpr int removed_nodes = 19;
-  constexpr int added_nodes = 3;
-
-  auto place = paddle::platform::CPUPlace();
-  NaiveExecutor exe{place};
-  Scope scope;
-  double eps_value = 1e-5;
-  // Init scope, as it is used in pass
-  exe.CreateVariables(prog, 0, true, &scope);
-  test::InitLoDTensorHolder(&scope, place, "eps", {1}, &eps_value);
-
-  graph.SetNotOwned(kParamScopeAttr, &scope);
-  EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
-                                      "shift_out", removed_nodes, added_nodes),
-               paddle::platform::EnforceNotMet);
+TEST(FuseLayerNormPass, TestGammaBetaUnequalInputChannelShape) {
+  const auto editGammaBetaFun = [](const BlockDesc& block_desc) {
+    auto* beta_var_desc = block_desc.FindVar("beta");
+    beta_var_desc->SetDataType(proto::VarType::FP32);
+    beta_var_desc->SetShape({32});
+
+    auto* gamma_var_desc = block_desc.FindVar("gamma");
+    gamma_var_desc->SetDataType(proto::VarType::FP32);
+    gamma_var_desc->SetShape({32});
+  };
+
+  LayerNormFuseTest lnorm_test(editGammaBetaFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, NoFusionBadInMeanDimAttrRank) {
+  const auto editFun = [](const BlockDesc& block_desc) {
+    auto* x_mean_desc =
+        test::GetOp(block_desc, "reduce_mean", "Out", "x_mean_out");
+    ASSERT_NE(x_mean_desc, nullptr);
+    x_mean_desc->SetAttr("dim", std::vector<int>{1, 1});
+  };
+
+  LayerNormFuseTest lnorm_test(editFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, NoFusionBadInMeanDimAttr) {
+  const auto editFun = [](const BlockDesc& block_desc) {
+    auto* x_mean_desc =
+        test::GetOp(block_desc, "reduce_mean", "Out", "x_mean_out");
+    ASSERT_NE(x_mean_desc, nullptr);
+    x_mean_desc->SetAttr("dim", std::vector<int>{1});
+  };
+
+  LayerNormFuseTest lnorm_test(editFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, NoFusionBadInMeanKeepDimAttr) {
+  const auto editFun = [](const BlockDesc& block_desc) {
+    auto* x_mean_desc =
+        test::GetOp(block_desc, "reduce_mean", "Out", "x_mean_out");
+    ASSERT_NE(x_mean_desc, nullptr);
+    x_mean_desc->SetAttr("keep_dim", false);
+  };
+
+  LayerNormFuseTest lnorm_test(editFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, NoFusionBadInMeanReduceAllAttr) {
+  const auto editFun = [](const BlockDesc& block_desc) {
+    auto* x_mean_desc =
+        test::GetOp(block_desc, "reduce_mean", "Out", "x_mean_out");
+    ASSERT_NE(x_mean_desc, nullptr);
+    x_mean_desc->SetAttr("reduce_all", true);
+  };
+
+  LayerNormFuseTest lnorm_test(editFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, NoFusionBadStdDevMeanDimAttrRank) {
+  const auto editFun = [](const BlockDesc& block_desc) {
+    auto* std_dev_desc =
+        test::GetOp(block_desc, "reduce_mean", "Out", "std_dev_out");
+    ASSERT_NE(std_dev_desc, nullptr);
+    std_dev_desc->SetAttr("dim", std::vector<int>{1, 1});
+  };
+
+  LayerNormFuseTest lnorm_test(editFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, NoFusionBadStdDevMeanDimAttr) {
+  const auto editFun = [](const BlockDesc& block_desc) {
+    auto* std_dev_desc =
+        test::GetOp(block_desc, "reduce_mean", "Out", "std_dev_out");
+    ASSERT_NE(std_dev_desc, nullptr);
+    std_dev_desc->SetAttr("dim", std::vector<int>{1});
+  };
+
+  LayerNormFuseTest lnorm_test(editFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, NoFusionBadStdDevMeanKeepDimAttr) {
+  const auto editFun = [](const BlockDesc& block_desc) {
+    auto* std_dev_desc =
+        test::GetOp(block_desc, "reduce_mean", "Out", "std_dev_out");
+    ASSERT_NE(std_dev_desc, nullptr);
+    std_dev_desc->SetAttr("keep_dim", false);
+  };
+
+  LayerNormFuseTest lnorm_test(editFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
+}
+
+TEST(FuseLayerNormPass, NoFusionBadStdDevMeanReduceAllAttr) {
+  const auto editFun = [](const BlockDesc& block_desc) {
+    auto* std_dev_desc =
+        test::GetOp(block_desc, "reduce_mean", "Out", "std_dev_out");
+    ASSERT_NE(std_dev_desc, nullptr);
+    std_dev_desc->SetAttr("reduce_all", true);
+  };
+
+  LayerNormFuseTest lnorm_test(editFun);
+  lnorm_test.setupGraph();
+  lnorm_test.run(false);
 }
 
 TEST(FuseLayerNormPass, pass_op_version_check) {
diff --git a/paddle/fluid/framework/ir/pass_test_util.cc b/paddle/fluid/framework/ir/pass_test_util.cc
index c37331dec05b4e67dd5a0aaea8050fe5b7d11278..a98fe8a20719b2e074d9bd2c55ff8b4b71d14650 100644
--- a/paddle/fluid/framework/ir/pass_test_util.cc
+++ b/paddle/fluid/framework/ir/pass_test_util.cc
@@ -175,10 +175,11 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
 }
 
 template <typename T>
-void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
+void InitLoDTensorHolder(const Scope& scope,
+                         const paddle::platform::Place& place,
                          const std::string& var_name,
                          const std::vector<int64_t>& dims, const T* data) {
-  auto var = scope->Var(var_name);
+  auto var = scope.FindLocalVar(var_name);
   auto tensor = var->GetMutable<LoDTensor>();
   auto* tensor_mem_ptr = tensor->mutable_data<T>(make_ddim(dims), place);
   if (data != nullptr) {
@@ -189,14 +190,16 @@ void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
 }
 // Instantiate for below data types.
-template void InitLoDTensorHolder(Scope*, const paddle::platform::Place&,
+template void InitLoDTensorHolder(const Scope&,
+                                  const paddle::platform::Place&,
                                   const std::string&,
                                   const std::vector<int64_t>&,
                                   const float*);
-template void InitLoDTensorHolder(Scope*, const paddle::platform::Place&,
+template void InitLoDTensorHolder(const Scope&,
+                                  const paddle::platform::Place&,
                                   const std::string&,
                                   const std::vector<int64_t>&,
                                   const int*);
-template void InitLoDTensorHolder(Scope*,
+template void InitLoDTensorHolder(const Scope&,
                                   const paddle::platform::Place&,
                                   const std::string&,
                                   const std::vector<int64_t>&,
@@ -205,7 +208,13 @@ template void InitLoDTensorHolder(Scope*,
 OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
               const std::string& output_name,
               const std::string& output_arg_name) {
-  auto all_ops = prog.Block(0).AllOps();
+  return GetOp(prog.Block(0), op_type, output_name, output_arg_name);
+}
+
+OpDesc* GetOp(const BlockDesc& block_desc, const std::string& op_type,
+              const std::string& output_name,
+              const std::string& output_arg_name) {
+  auto all_ops = block_desc.AllOps();
   for (auto* op_desc : all_ops) {
     if (op_desc->Type() == op_type && op_desc->HasOutput(output_name)) {
       const auto& arg_names = op_desc->Outputs().at(output_name);
diff --git a/paddle/fluid/framework/ir/pass_test_util.h b/paddle/fluid/framework/ir/pass_test_util.h
index 519522a932ceb791f80d3e280fc274b469973054..9a75bcd366b39ecc0ed9d94fc78f43446d6f077e 100644
--- a/paddle/fluid/framework/ir/pass_test_util.h
+++ b/paddle/fluid/framework/ir/pass_test_util.h
@@ -128,7 +128,8 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
 /// @tparam T Tensor data type.
 ///
 template <typename T>
-void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
+void InitLoDTensorHolder(const Scope& scope,
+                         const paddle::platform::Place& place,
                          const std::string& var_name,
                          const std::vector<int64_t>& dims,
                          const T* data = nullptr);
@@ -148,6 +149,10 @@ OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
               const std::string& output_name,
               const std::string& output_arg_name);
 
+OpDesc* GetOp(const BlockDesc& block_desc, const std::string& op_type,
+              const std::string& output_name,
+              const std::string& output_arg_name);
+
 }  // namespace test
 }  // namespace ir
 }  // namespace framework