未验证 提交 3ba69809 编写于 作者: A Adam Osewski 提交者: GitHub

Fix LayerNorm tester for gcc4.8 (#30962)

上级 93c1d9e7
......@@ -43,21 +43,21 @@ class LayerNormFuseTest {
"division_out", "scale_out", "shift_out"},
{"sqr_pow", "eps", "gamma", "beta"})},
m_place{},
m_exe{m_place},
m_block_desc{m_prog.Block(0)} {
auto* x_var_desc = m_block_desc.FindVar("x");
m_exe{m_place} {
const BlockDesc& block_desc = m_prog.Block(0);
auto* x_var_desc = block_desc.FindVar("x");
x_var_desc->SetDataType(proto::VarType::FP32);
x_var_desc->SetShape({3, 32, 48});
auto* eps_var_desc = m_block_desc.FindVar("eps");
auto* eps_var_desc = block_desc.FindVar("eps");
eps_var_desc->SetDataType(proto::VarType::FP32);
eps_var_desc->SetShape({1});
auto* gamma_var_desc = m_block_desc.FindVar("gamma");
auto* gamma_var_desc = block_desc.FindVar("gamma");
gamma_var_desc->SetDataType(proto::VarType::FP32);
gamma_var_desc->SetShape({48});
auto* beta_var_desc = m_block_desc.FindVar("beta");
auto* beta_var_desc = block_desc.FindVar("beta");
beta_var_desc->SetDataType(proto::VarType::FP32);
beta_var_desc->SetShape({48});
......@@ -102,7 +102,7 @@ class LayerNormFuseTest {
: LayerNormFuseTest() {
m_removed_nodes = removed_nodes;
m_added_nodes = added_nodes;
func(m_block_desc);
func(m_prog.Block(0));
}
void setupGraph() {
......@@ -165,7 +165,6 @@ class LayerNormFuseTest {
ProgramDesc m_prog;
paddle::platform::CPUPlace m_place;
NaiveExecutor m_exe;
const BlockDesc& m_block_desc;
Scope m_scope;
std::unique_ptr<Graph> m_graph{nullptr};
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册