Unverified commit 092a2b14, authored by Adam Osewski, committed by GitHub

More UT for LayerNormFuse pass (#30891)

* Additionally, change the pass so that validation errors are not thrown from inside the pass.
Parent a80fe67f
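The key behavioral change in the pass itself: subgraph validation failures are no longer reported by throwing `PADDLE_ENFORCE*` exceptions from inside `ApplyImpl`; they are logged with `VLOG(4)` and the pass simply skips the fusion, leaving the graph untouched. Below is a minimal, self-contained sketch of that early-return validation pattern; it is not the Paddle code itself, and `Subgraph`, `Validate`, and `ApplyFusion` are hypothetical stand-ins for the real pass internals.

```cpp
#include <iostream>
#include <vector>

// Hypothetical stand-in for the matched LayerNorm subgraph.
struct Subgraph {
  std::vector<int> reduce_dims;
};

// Returns false (and logs) instead of throwing when the subgraph does not
// match the expected structure -- analogous to the new EXPECT_TRUE macro.
static bool Validate(const Subgraph& sg) {
  if (sg.reduce_dims.size() != 1 || sg.reduce_dims.front() != -1) {
    std::clog << "LayerNorm fusion: reduction must happen only over the last "
                 "dimension; skipping fusion.\n";
    return false;
  }
  return true;
}

// Analogous to ApplyImpl with the new CHECK_TRUE macro: on validation
// failure it returns early instead of raising an exception.
static void ApplyFusion(const Subgraph& sg) {
  if (!Validate(sg)) return;  // skip fusion, leave the graph unchanged
  // ... build the fused layer_norm op here ...
}

int main() {
  ApplyFusion(Subgraph{{1, 2}});  // logs and returns; nothing is thrown
  ApplyFusion(Subgraph{{-1}});    // passes validation, fusion would proceed
  return 0;
}
```

This is why the updated unit tests below assert that no fusion happened (`run(false)`) rather than expecting `paddle::platform::EnforceNotMet` to be thrown.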
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
@@ -22,6 +21,7 @@
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace framework {
@@ -30,34 +30,57 @@ namespace ir {
// cpplint complaints (wrong!) for not included <string> header in below line.
using string::PrettyLogDetail;  // NOLINT
#define CHECK_TRUE(expr, err_msg) \
do { \
int e_ = (expr); \
if (!e_) { \
VLOG(4) << err_msg; \
return; \
} \
} while (0)
#define EXPECT_TRUE(expr, err_msg) \
do { \
int e_ = (expr); \
if (!e_) { \
VLOG(4) << err_msg; \
return false; \
} \
} while (0)
namespace {
void validateReduceOpAttrs(const Node* node, const std::string& name) {
bool validateReduceOpAttrs(const Node* node, const std::string& name) {
  const auto* op = node->Op();
  if (op->HasAttr("dim")) {
    auto dims = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dim"));
    PADDLE_ENFORCE_EQ(dims.size(), 1, platform::errors::PreconditionNotMet(
                                          "The LayerNorm fusion ", name,
                                          " reduction must happen only over "
                                          "single dimension."));
    PADDLE_ENFORCE_EQ(dims.front(), -1, platform::errors::PreconditionNotMet(
                                            "The LayerNorm fusion ", name,
                                            " reduction must happen over last "
                                            "dimension."));
    EXPECT_TRUE(
        dims.size() == 1,
        ::paddle::string::Sprintf(
            "The LayerNorm fusion %s reduction must happen only over single "
            "dimension.",
            name));
    EXPECT_TRUE(dims.front() == -1,
                ::paddle::string::Sprintf("The LayerNorm fusion %s reduction "
                                          "must happen over last dimension.",
                                          name));
  }
  if (op->HasAttr("reduce_all")) {
    PADDLE_ENFORCE(!BOOST_GET_CONST(bool, op->GetAttr("reduce_all")),
                   platform::errors::PreconditionNotMet(
                       "The LayerNorm fusion ", name,
                       " reduction must have "
                       "\'reduce_all\' attribute set to false."));
    EXPECT_TRUE(
        !BOOST_GET_CONST(bool, op->GetAttr("reduce_all")),
        ::paddle::string::Sprintf(
            "The LayerNorm fusion %s"
            "reduction must have \'reduce_all\' attribute set to false.",
            name));
  }
  if (op->HasAttr("keep_dim")) {
    PADDLE_ENFORCE(BOOST_GET_CONST(bool, op->GetAttr("keep_dim")),
                   platform::errors::PreconditionNotMet(
                       "The LayerNorm fusion ", name,
                       " reduction must have "
                       "\'keep_dim\' attribute set to true."));
    EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("keep_dim")),
                ::paddle::string::Sprintf(
                    "The LayerNorm fusion %s"
                    " reduction must have \'keep_dim\' attribute set to true.",
                    name));
  }
  return true;
}

void setIntermediateOut(OpDesc* desc, const std::string& out_name,
@@ -129,17 +152,16 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
  auto* eps_tensor = scope->FindVar(eps->Name())->GetMutable<LoDTensor>();
  // ------------------ subgraph node's validation ---------------------------
  PADDLE_ENFORCE_EQ(
      eps_tensor->numel(), 1,
      platform::errors::InvalidArgument(
          "The LayerNorm divisor "
          "epsilon value must be one-element tensor, but has %s "
          "elements.",
          eps_tensor->numel()));
  PADDLE_ENFORCE_EQ(eps_tensor->type(), proto::VarType::FP32,
                    platform::errors::InvalidArgument(
                        "The LayerNorm divisor "
                        "epsilon value must be of FP32 data type, but is %s.",
                        eps_tensor->type()));
  CHECK_TRUE(
      eps_tensor->numel() == 1,
      ::paddle::string::Sprintf(
          "The LayerNorm divisor epsilon value must be one-element tensor, "
          "but has %s elements.",
          eps_tensor->numel()));
  CHECK_TRUE(
      eps_tensor->type() == proto::VarType::FP32,
      ::paddle::string::Sprintf("The LayerNorm divisor epsilon value "
                                "must be of FP32 data type, but is %s.",
                                eps_tensor->type()));
  const auto& gamma_shape = gamma->Var()->GetShape();
@@ -147,30 +169,29 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
  const auto& x_shape = x->Var()->GetShape();
  int64_t x_last_dim = x_shape.back();
  PADDLE_ENFORCE_EQ(gamma_shape.size(), 1,
                    platform::errors::InvalidArgument(
                        "The LayerNorm gamma "
                        "(scale) tensor shape must be one-dimensional, "
                        "but is %s.",
                        gamma_shape.size()));
  PADDLE_ENFORCE_EQ(beta_shape.size(), 1,
                    platform::errors::InvalidArgument(
                        "The LayerNorm beta "
                        "(shift) tensor shape must be one-dimensional, "
                        "but is %s.",
                        beta_shape.size()));
  PADDLE_ENFORCE_EQ(beta_shape, gamma_shape,
                    platform::errors::InvalidArgument(
                        "The LayerNorm beta "
                        "and gamma tensors shapes' must be equal."));
  PADDLE_ENFORCE_EQ(gamma_shape.front(), x_last_dim,
                    platform::errors::InvalidArgument(
                        "The LayerNorm beta "
                        "and gamma tensors shapes' must be equal to the last "
                        "input's dimension size."));
  validateReduceOpAttrs(x_mean, "input mean");
  validateReduceOpAttrs(std_dev, "std_dev mean");
  CHECK_TRUE(
      gamma_shape.size() == 1,
      ::paddle::string::Sprintf("The LayerNorm gamma (scale) tensor "
                                "shape must be one-dimensional, but is %s.",
                                gamma_shape.size()));
  CHECK_TRUE(
      beta_shape.size() == 1,
      ::paddle::string::Sprintf("The LayerNorm beta (shift) tensor "
                                "shape must be one-dimensional, but is %s.",
                                beta_shape.size()));
  CHECK_TRUE(beta_shape == gamma_shape,
             ::paddle::string::Sprintf("The LayerNorm beta and gamma tensors "
                                       "shapes' must be equal."));
  CHECK_TRUE(
      gamma_shape.front() == x_last_dim,
      ::paddle::string::Sprintf(
          "The LayerNorm beta and gamma tensors "
          "shapes' must be equal to the last input's dimension size."));
  CHECK_TRUE(validateReduceOpAttrs(x_mean, "input mean"),
             "Validation of input mean node failed.");
  CHECK_TRUE(validateReduceOpAttrs(std_dev, "std_dev mean"),
             "Validation of standard deviation node failed.");
  // ------------------ op creation and placement ---------------------------
@@ -213,6 +234,9 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
}  // namespace framework
}  // namespace paddle

#undef CHECK_TRUE
#undef EXPECT_TRUE

REGISTER_PASS(layer_norm_fuse_pass, paddle::framework::ir::LayerNormFusePass);
REGISTER_PASS_CAPABILITY(layer_norm_fuse_pass)
    .AddCombination(
......
@@ -13,7 +13,10 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
@@ -31,64 +34,113 @@ namespace ir {
namespace {
ProgramDesc BuildGraphProgram() {
  auto prog = test::BuildProgramDesc(
      {"x", "x_mean_out", "x_sub_mean_out", "x_sub_mean_sqr_out", "std_dev_out",
       "std_dev_eps_out", "std_dev_eps_sqrt_out", "division_out", "scale_out",
       "shift_out"},
      {"sqr_pow", "eps", "gamma", "beta"});
  const auto& block_desc = prog.Block(0);
  auto* x_var_desc = block_desc.FindVar("x");
  x_var_desc->SetDataType(proto::VarType::FP32);
  x_var_desc->SetShape({3, 32, 48});
  auto* eps_var_desc = block_desc.FindVar("eps");
  eps_var_desc->SetDataType(proto::VarType::FP32);
  eps_var_desc->SetShape({1});
  auto* gamma_var_desc = block_desc.FindVar("gamma");
  gamma_var_desc->SetDataType(proto::VarType::FP32);
  gamma_var_desc->SetShape({48});
  auto* beta_var_desc = block_desc.FindVar("beta");
  beta_var_desc->SetDataType(proto::VarType::FP32);
  beta_var_desc->SetShape({48});
  auto* x_mean = test::CreateOp(&prog, "reduce_mean", {{"X", "x"}},
                                {{"Out", "x_mean_out"}}, false);
  x_mean->SetAttr("dim", std::vector<int>{-1});
  x_mean->SetAttr("keep_dim", true);
  x_mean->SetAttr("reduce_all", false);
  test::CreateOp(&prog, "elementwise_sub", {{"X", "x"}, {"Y", "x_mean_out"}},
                 {{"Out", "x_sub_mean_out"}}, false);
  test::CreateOp(&prog, "elementwise_pow",
                 {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
                 {{"Out", "x_sub_mean_sqr_out"}}, false);
  auto* std_dev =
      test::CreateOp(&prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
                     {{"Out", "std_dev_out"}}, false);
  std_dev->SetAttr("dim", std::vector<int>{-1});
  std_dev->SetAttr("keep_dim", true);
  std_dev->SetAttr("reduce_all", false);
  test::CreateOp(&prog, "elementwise_add", {{"X", "std_dev_out"}, {"Y", "eps"}},
                 {{"Out", "std_dev_eps_out"}}, false);
  test::CreateOp(&prog, "sqrt", {{"X", "std_dev_eps_out"}},
                 {{"Out", "std_dev_eps_sqrt_out"}}, false);
  test::CreateOp(&prog, "elementwise_div",
                 {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
                 {{"Out", "division_out"}}, false);
  test::CreateOp(&prog, "elementwise_mul",
                 {{"X", "division_out"}, {"Y", "gamma"}},
                 {{"Out", "scale_out"}}, false);
  test::CreateOp(&prog, "elementwise_add", {{"X", "scale_out"}, {"Y", "beta"}},
                 {{"Out", "shift_out"}}, false);
  return prog;
}
class LayerNormFuseTest {
 public:
  LayerNormFuseTest()
      : m_prog{test::BuildProgramDesc(
            {"x", "x_mean_out", "x_sub_mean_out", "x_sub_mean_sqr_out",
             "std_dev_out", "std_dev_eps_out", "std_dev_eps_sqrt_out",
             "division_out", "scale_out", "shift_out"},
            {"sqr_pow", "eps", "gamma", "beta"})},
        m_place{},
        m_exe{m_place},
        m_block_desc{m_prog.Block(0)} {
    auto* x_var_desc = m_block_desc.FindVar("x");
    x_var_desc->SetDataType(proto::VarType::FP32);
    x_var_desc->SetShape({3, 32, 48});
    auto* eps_var_desc = m_block_desc.FindVar("eps");
    eps_var_desc->SetDataType(proto::VarType::FP32);
    eps_var_desc->SetShape({1});
    auto* gamma_var_desc = m_block_desc.FindVar("gamma");
    gamma_var_desc->SetDataType(proto::VarType::FP32);
    gamma_var_desc->SetShape({48});
    auto* beta_var_desc = m_block_desc.FindVar("beta");
    beta_var_desc->SetDataType(proto::VarType::FP32);
    beta_var_desc->SetShape({48});
    auto* x_mean = test::CreateOp(&m_prog, "reduce_mean", {{"X", "x"}},
                                  {{"Out", "x_mean_out"}}, false);
    x_mean->SetAttr("dim", std::vector<int>{-1});
    x_mean->SetAttr("keep_dim", true);
    x_mean->SetAttr("reduce_all", false);
    test::CreateOp(&m_prog, "elementwise_sub",
                   {{"X", "x"}, {"Y", "x_mean_out"}},
                   {{"Out", "x_sub_mean_out"}}, false);
    test::CreateOp(&m_prog, "elementwise_pow",
                   {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
                   {{"Out", "x_sub_mean_sqr_out"}}, false);
    auto* std_dev =
        test::CreateOp(&m_prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
                       {{"Out", "std_dev_out"}}, false);
    std_dev->SetAttr("dim", std::vector<int>{-1});
    std_dev->SetAttr("keep_dim", true);
    std_dev->SetAttr("reduce_all", false);
    test::CreateOp(&m_prog, "elementwise_add",
                   {{"X", "std_dev_out"}, {"Y", "eps"}},
                   {{"Out", "std_dev_eps_out"}}, false);
    test::CreateOp(&m_prog, "sqrt", {{"X", "std_dev_eps_out"}},
                   {{"Out", "std_dev_eps_sqrt_out"}}, false);
    test::CreateOp(&m_prog, "elementwise_div",
                   {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
                   {{"Out", "division_out"}}, false);
    test::CreateOp(&m_prog, "elementwise_mul",
                   {{"X", "division_out"}, {"Y", "gamma"}},
                   {{"Out", "scale_out"}}, false);
    test::CreateOp(&m_prog, "elementwise_add",
                   {{"X", "scale_out"}, {"Y", "beta"}}, {{"Out", "shift_out"}},
                   false);
  }
template <typename Func>
LayerNormFuseTest(const Func& func, int removed_nodes = 0,
int added_nodes = 0)
: LayerNormFuseTest() {
m_removed_nodes = removed_nodes;
m_added_nodes = added_nodes;
func(m_block_desc);
}
void setupGraph() {
auto initFun = [this](const Scope& scope,
const paddle::platform::CPUPlace& place) {
this->initEpsTensorValue(scope, place);
};
setupGraphWithInitFunc(initFun);
}
template <typename Func>
void setupGraphWithInitFunc(const Func& func) {
m_graph.reset(new Graph(m_prog));
// Init scope, as it is used in pass
m_exe.CreateVariables(m_prog, 0, true, &m_scope);
func(m_scope, m_place);
m_graph->SetNotOwned(kParamScopeAttr, &m_scope);
}
void run(bool fusion = false) const {
EXPECT_TRUE(test::RunPassAndAssert(m_graph.get(), "layer_norm_fuse_pass",
"x", "shift_out", m_removed_nodes,
m_added_nodes));
EXPECT_TRUE(CheckSubgraphOpsCount(*m_graph, fusion));
}
const ProgramDesc& getProgramDesc() const { return m_prog; }
const Graph* getGraph() const { return m_graph.get(); }
private:
void initEpsTensorValue(const Scope& scope,
const paddle::platform::CPUPlace& place) {
float eps_value = 1e-5;
test::InitLoDTensorHolder<float>(scope, place, "eps", {1}, &eps_value);
}
bool CheckFusedSubgraphOpsCount(const Graph& graph) {
  bool CheckSubgraphOpsCount(const Graph& graph, bool fusion) const {
    if (fusion)
      return test::AssertOpsCount(graph, {{"reduce_mean", 0},
                                          {"elementwise_sub", 0},
                                          {"elementwise_pow", 0},
@@ -97,34 +149,38 @@ bool CheckFusedSubgraphOpsCount(const Graph& graph) {
                                          {"elementwise_div", 0},
                                          {"elementwise_mul", 0},
                                          {"layer_norm", 1}});
}
    else
return test::AssertOpsCount(graph, {{"reduce_mean", 2},
{"elementwise_sub", 1},
{"elementwise_pow", 1},
{"elementwise_add", 2},
{"sqrt", 1},
{"elementwise_div", 1},
{"elementwise_mul", 1},
{"layer_norm", 0}});
}
int m_removed_nodes{19};
int m_added_nodes{3};
ProgramDesc m_prog;
paddle::platform::CPUPlace m_place;
NaiveExecutor m_exe;
const BlockDesc& m_block_desc;
Scope m_scope;
std::unique_ptr<Graph> m_graph{nullptr};
};
}  // namespace

// ------------------------------ Test cases -----------------------------------

TEST(FuseLayerNormPass, TestFuse) {
  ProgramDesc prog = BuildGraphProgram();
  Graph graph(prog);
  constexpr int removed_nodes = 19;
  // LayerNorm + outputs: {Mean, Variance}
  constexpr int added_nodes = 3;
  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  float eps_value = 1e-5f;
  // Init scope, as it is used in pass
  exe.CreateVariables(prog, 0, true, &scope);
  test::InitLoDTensorHolder<float>(&scope, place, "eps", {1}, &eps_value);
  graph.SetNotOwned(kParamScopeAttr, &scope);
  EXPECT_TRUE(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
                                     "shift_out", removed_nodes, added_nodes));
  EXPECT_TRUE(CheckFusedSubgraphOpsCount(graph));
  for (const auto* node : graph.Nodes()) {
  LayerNormFuseTest lnorm_test;
  lnorm_test.setupGraph();
  lnorm_test.run(true);
  // additional attribute checks
  for (const auto* node : lnorm_test.getGraph()->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "layer_norm") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("is_test"));
@@ -136,54 +192,194 @@ TEST(FuseLayerNormPass, TestFuse) {
}

TEST(FuseLayerNormPass, TestInvalidEpsNumel) {
  ProgramDesc prog = BuildGraphProgram();
  auto* eps_var_desc = prog.Block(0).FindVar("eps");
  eps_var_desc->SetDataType(proto::VarType::FP32);
  eps_var_desc->SetShape({2});
  Graph graph(prog);
  constexpr int removed_nodes = 19;
  constexpr int added_nodes = 3;
  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  auto eps_values = std::vector<float>{1e-5f, 1e-5f};
  // Init scope, as it is used in pass
  exe.CreateVariables(prog, 0, true, &scope);
  test::InitLoDTensorHolder<float>(&scope, place, "eps", {2},
                                   eps_values.data());
  graph.SetNotOwned(kParamScopeAttr, &scope);
  EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
                                      "shift_out", removed_nodes, added_nodes),
               paddle::platform::EnforceNotMet);
  const auto editEpsFun = [](const BlockDesc& block_desc) {
    auto* eps_var_desc = block_desc.FindVar("eps");
    eps_var_desc->SetDataType(proto::VarType::FP32);
    eps_var_desc->SetShape({2});
  };
  const auto initEpsTensor = [](const Scope& scope,
                                const paddle::platform::CPUPlace& place) {
    auto eps_values = std::vector<float>{1e-5f, 1e-5f};
    test::InitLoDTensorHolder<float>(scope, place, "eps", {2},
                                     eps_values.data());
  };
  LayerNormFuseTest lnorm_test(editEpsFun);
  lnorm_test.setupGraphWithInitFunc(initEpsTensor);
  lnorm_test.run(false);
}

TEST(FuseLayerNormPass, TestInvalidEpsDataType) {
  ProgramDesc prog = BuildGraphProgram();
  auto* eps_var_desc = prog.Block(0).FindVar("eps");
  eps_var_desc->SetDataType(proto::VarType::FP64);
  eps_var_desc->SetShape({1});
  Graph graph(prog);
  constexpr int removed_nodes = 19;
  constexpr int added_nodes = 3;
  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  double eps_value = 1e-5;
  // Init scope, as it is used in pass
  exe.CreateVariables(prog, 0, true, &scope);
  test::InitLoDTensorHolder<double>(&scope, place, "eps", {1}, &eps_value);
  graph.SetNotOwned(kParamScopeAttr, &scope);
  EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
                                      "shift_out", removed_nodes, added_nodes),
               paddle::platform::EnforceNotMet);
  const auto editEpsFun = [](const BlockDesc& block_desc) {
    auto* eps_var_desc = block_desc.FindVar("eps");
    eps_var_desc->SetDataType(proto::VarType::FP64);
    eps_var_desc->SetShape({1});
  };
  const auto initEpsTensor = [](const Scope& scope,
                                const paddle::platform::CPUPlace& place) {
    double eps_value = 1e-5;
    test::InitLoDTensorHolder<double>(scope, place, "eps", {1}, &eps_value);
  };
  LayerNormFuseTest lnorm_test(editEpsFun);
  lnorm_test.setupGraphWithInitFunc(initEpsTensor);
  lnorm_test.run(false);
}

TEST(FuseLayerNormPass, TestInvalidGammaRank) {
  const auto editGammaFun = [](const BlockDesc& block_desc) {
    auto* gamma_var_desc = block_desc.FindVar("gamma");
    gamma_var_desc->SetDataType(proto::VarType::FP32);
    gamma_var_desc->SetShape({48, 32});
  };

  LayerNormFuseTest lnorm_test(editGammaFun);
  lnorm_test.setupGraph();
  lnorm_test.run(false);
}
TEST(FuseLayerNormPass, TestInvalidBetaRank) {
const auto editBetaFun = [](const BlockDesc& block_desc) {
auto* beta_var_desc = block_desc.FindVar("beta");
beta_var_desc->SetDataType(proto::VarType::FP32);
beta_var_desc->SetShape({48, 32});
};
LayerNormFuseTest lnorm_test(editBetaFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, TestUnequalGammaBetaShapes) {
const auto editGammaBetaFun = [](const BlockDesc& block_desc) {
auto* beta_var_desc = block_desc.FindVar("beta");
beta_var_desc->SetDataType(proto::VarType::FP32);
beta_var_desc->SetShape({32});
};
LayerNormFuseTest lnorm_test(editGammaBetaFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, TestGammaBetaUnequalInputChannelShape) {
const auto editGammaBetaFun = [](const BlockDesc& block_desc) {
auto* beta_var_desc = block_desc.FindVar("beta");
beta_var_desc->SetDataType(proto::VarType::FP32);
beta_var_desc->SetShape({32});
auto* gamma_var_desc = block_desc.FindVar("gamma");
gamma_var_desc->SetDataType(proto::VarType::FP32);
gamma_var_desc->SetShape({32});
};
LayerNormFuseTest lnorm_test(editGammaBetaFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, NoFusionBadInMeanDimAttrRank) {
const auto editFun = [](const BlockDesc& block_desc) {
auto* x_mean_desc =
test::GetOp(block_desc, "reduce_mean", "Out", "x_mean_out");
ASSERT_NE(x_mean_desc, nullptr);
x_mean_desc->SetAttr("dim", std::vector<int>{1, 1});
};
LayerNormFuseTest lnorm_test(editFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, NoFusionBadInMeanDimAttr) {
const auto editFun = [](const BlockDesc& block_desc) {
auto* x_mean_desc =
test::GetOp(block_desc, "reduce_mean", "Out", "x_mean_out");
ASSERT_NE(x_mean_desc, nullptr);
x_mean_desc->SetAttr("dim", std::vector<int>{1});
};
LayerNormFuseTest lnorm_test(editFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, NoFusionBadInMeanKeepDimAttr) {
const auto editFun = [](const BlockDesc& block_desc) {
auto* x_mean_desc =
test::GetOp(block_desc, "reduce_mean", "Out", "x_mean_out");
ASSERT_NE(x_mean_desc, nullptr);
x_mean_desc->SetAttr("keep_dim", false);
};
LayerNormFuseTest lnorm_test(editFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, NoFusionBadInMeanReduceAllAttr) {
const auto editFun = [](const BlockDesc& block_desc) {
auto* x_mean_desc =
test::GetOp(block_desc, "reduce_mean", "Out", "x_mean_out");
ASSERT_NE(x_mean_desc, nullptr);
x_mean_desc->SetAttr("reduce_all", true);
};
LayerNormFuseTest lnorm_test(editFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, NoFusionBadStdDevMeanDimAttrRank) {
const auto editFun = [](const BlockDesc& block_desc) {
auto* std_dev_desc =
test::GetOp(block_desc, "reduce_mean", "Out", "std_dev_out");
ASSERT_NE(std_dev_desc, nullptr);
std_dev_desc->SetAttr("dim", std::vector<int>{1, 1});
};
LayerNormFuseTest lnorm_test(editFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, NoFusionBadStdDevMeanDimAttr) {
const auto editFun = [](const BlockDesc& block_desc) {
auto* std_dev_desc =
test::GetOp(block_desc, "reduce_mean", "Out", "std_dev_out");
ASSERT_NE(std_dev_desc, nullptr);
std_dev_desc->SetAttr("dim", std::vector<int>{1});
};
LayerNormFuseTest lnorm_test(editFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}
TEST(FuseLayerNormPass, NoFusionBadStdDevMeanKeepDimAttr) {
const auto editFun = [](const BlockDesc& block_desc) {
auto* std_dev_desc =
test::GetOp(block_desc, "reduce_mean", "Out", "std_dev_out");
ASSERT_NE(std_dev_desc, nullptr);
std_dev_desc->SetAttr("keep_dim", false);
};
LayerNormFuseTest lnorm_test(editFun);
lnorm_test.setupGraph();
lnorm_test.run(false);
}

TEST(FuseLayerNormPass, NoFusionBadStdDevMeanReduceAllAttr) {
  const auto editFun = [](const BlockDesc& block_desc) {
    auto* std_dev_desc =
        test::GetOp(block_desc, "reduce_mean", "Out", "std_dev_out");
    ASSERT_NE(std_dev_desc, nullptr);
    std_dev_desc->SetAttr("reduce_all", true);
  };

  LayerNormFuseTest lnorm_test(editFun);
  lnorm_test.setupGraph();
  lnorm_test.run(false);
}

TEST(FuseLayerNormPass, pass_op_version_check) {
......
@@ -175,10 +175,11 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
}

template <typename T>
void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
void InitLoDTensorHolder(const Scope& scope,
                         const paddle::platform::Place& place,
                         const std::string& var_name,
                         const std::vector<int64_t>& dims, const T* data) {
  auto var = scope->Var(var_name);
  auto var = scope.FindLocalVar(var_name);
  auto tensor = var->GetMutable<LoDTensor>();
  auto* tensor_mem_ptr = tensor->mutable_data<T>(make_ddim(dims), place);
  if (data != nullptr) {
@@ -189,14 +190,16 @@ void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
}

// Instantiate for below data types.
template void InitLoDTensorHolder<float>(Scope*, const paddle::platform::Place&,
template void InitLoDTensorHolder<float>(const Scope&,
                                         const paddle::platform::Place&,
                                         const std::string&,
                                         const std::vector<int64_t>&,
                                         const float*);
template void InitLoDTensorHolder<int>(Scope*, const paddle::platform::Place&,
template void InitLoDTensorHolder<int>(const Scope&,
                                       const paddle::platform::Place&,
                                       const std::string&,
                                       const std::vector<int64_t>&, const int*);
template void InitLoDTensorHolder<double>(Scope*,
template void InitLoDTensorHolder<double>(const Scope&,
                                          const paddle::platform::Place&,
                                          const std::string&,
                                          const std::vector<int64_t>&,
@@ -205,7 +208,13 @@ template void InitLoDTensorHolder<double>(Scope*,
OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
              const std::string& output_name,
              const std::string& output_arg_name) {
  auto all_ops = prog.Block(0).AllOps();
  return GetOp(prog.Block(0), op_type, output_name, output_arg_name);
}

OpDesc* GetOp(const BlockDesc& block_desc, const std::string& op_type,
              const std::string& output_name,
              const std::string& output_arg_name) {
  auto all_ops = block_desc.AllOps();
  for (auto* op_desc : all_ops) {
    if (op_desc->Type() == op_type && op_desc->HasOutput(output_name)) {
      const auto& arg_names = op_desc->Outputs().at(output_name);
......
@@ -128,7 +128,8 @@ bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
/// @tparam T Tensor data type.
///
template <typename T>
void InitLoDTensorHolder(Scope* scope, const paddle::platform::Place& place,
void InitLoDTensorHolder(const Scope& scope,
                         const paddle::platform::Place& place,
                         const std::string& var_name,
                         const std::vector<int64_t>& dims,
                         const T* data = nullptr);
@@ -148,6 +149,10 @@ OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
              const std::string& output_name,
              const std::string& output_arg_name);

OpDesc* GetOp(const BlockDesc& block_desc, const std::string& op_type,
              const std::string& output_name,
              const std::string& output_arg_name);

}  // namespace test
}  // namespace ir
}  // namespace framework
......