未验证 提交 3ba69809 编写于 作者: A Adam Osewski 提交者: GitHub

Fix LayerNorm tester for gcc4.8 (#30962)

上级 93c1d9e7
...@@ -43,21 +43,21 @@ class LayerNormFuseTest { ...@@ -43,21 +43,21 @@ class LayerNormFuseTest {
"division_out", "scale_out", "shift_out"}, "division_out", "scale_out", "shift_out"},
{"sqr_pow", "eps", "gamma", "beta"})}, {"sqr_pow", "eps", "gamma", "beta"})},
m_place{}, m_place{},
m_exe{m_place}, m_exe{m_place} {
m_block_desc{m_prog.Block(0)} { const BlockDesc& block_desc = m_prog.Block(0);
auto* x_var_desc = m_block_desc.FindVar("x"); auto* x_var_desc = block_desc.FindVar("x");
x_var_desc->SetDataType(proto::VarType::FP32); x_var_desc->SetDataType(proto::VarType::FP32);
x_var_desc->SetShape({3, 32, 48}); x_var_desc->SetShape({3, 32, 48});
auto* eps_var_desc = m_block_desc.FindVar("eps"); auto* eps_var_desc = block_desc.FindVar("eps");
eps_var_desc->SetDataType(proto::VarType::FP32); eps_var_desc->SetDataType(proto::VarType::FP32);
eps_var_desc->SetShape({1}); eps_var_desc->SetShape({1});
auto* gamma_var_desc = m_block_desc.FindVar("gamma"); auto* gamma_var_desc = block_desc.FindVar("gamma");
gamma_var_desc->SetDataType(proto::VarType::FP32); gamma_var_desc->SetDataType(proto::VarType::FP32);
gamma_var_desc->SetShape({48}); gamma_var_desc->SetShape({48});
auto* beta_var_desc = m_block_desc.FindVar("beta"); auto* beta_var_desc = block_desc.FindVar("beta");
beta_var_desc->SetDataType(proto::VarType::FP32); beta_var_desc->SetDataType(proto::VarType::FP32);
beta_var_desc->SetShape({48}); beta_var_desc->SetShape({48});
...@@ -102,7 +102,7 @@ class LayerNormFuseTest { ...@@ -102,7 +102,7 @@ class LayerNormFuseTest {
: LayerNormFuseTest() { : LayerNormFuseTest() {
m_removed_nodes = removed_nodes; m_removed_nodes = removed_nodes;
m_added_nodes = added_nodes; m_added_nodes = added_nodes;
func(m_block_desc); func(m_prog.Block(0));
} }
void setupGraph() { void setupGraph() {
...@@ -165,7 +165,6 @@ class LayerNormFuseTest { ...@@ -165,7 +165,6 @@ class LayerNormFuseTest {
ProgramDesc m_prog; ProgramDesc m_prog;
paddle::platform::CPUPlace m_place; paddle::platform::CPUPlace m_place;
NaiveExecutor m_exe; NaiveExecutor m_exe;
const BlockDesc& m_block_desc;
Scope m_scope; Scope m_scope;
std::unique_ptr<Graph> m_graph{nullptr}; std::unique_ptr<Graph> m_graph{nullptr};
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册