Commit f65913d6 authored by chenhaozhe

fix performance of bert

Parent 25b0037b
...
@@ -687,7 +687,7 @@ bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
   MS_EXCEPTION_IF_NULL(equiv1_node);
   auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
   MS_EXCEPTION_IF_NULL(equiv2_node);
-  return equiv1_node == equiv2_node;
+  return *equiv1_node == *equiv2_node;
 }
 
 AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
...
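This change makes IsSameNode compare the resolved nodes by value rather than by pointer: the old `equiv1_node == equiv2_node` is true only when both vars resolved to the very same object, while the new `*equiv1_node == *equiv2_node` dereferences the AnfNodePtrs and uses the node type's own operator==, so structurally equal nodes now also match. A minimal Python sketch of the distinction (the Node class below is hypothetical, not MindSpore code; `is` plays the role of the pointer comparison and `==` the role of the dereferenced comparison):

# Hypothetical Node class; only illustrates identity vs. value equality,
# mirroring the semantics of the C++ change above.
class Node:
    def __init__(self, op, inputs):
        self.op = op
        self.inputs = tuple(inputs)

    def __eq__(self, other):
        # Value equality: same operator and same inputs.
        return isinstance(other, Node) and (self.op, self.inputs) == (other.op, other.inputs)

    def __hash__(self):
        return hash((self.op, self.inputs))

a = Node("Mul", ("x", "y"))
b = Node("Mul", ("x", "y"))

print(a is b)  # False: two distinct objects, like comparing raw pointers
print(a == b)  # True: structurally equal, like comparing dereferenced nodes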
...
@@ -180,7 +180,7 @@ class Lamb(Optimizer):
                  beta2=0.999,
                  eps=1e-6,
                  weight_decay=0.0,
-                 decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
+                 decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
         super(Lamb, self).__init__(start_learning_rate, params)
         if self.is_group:
...
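The decay_filter default decides which parameters the Lamb optimizer exempts from weight decay. The old filter matched 'LayerNorm' and 'bias' case-sensitively, so differently cased names (e.g. 'layernorm.gamma' or 'dense.Bias') slipped through and wrongly received weight decay; lower-casing the name first makes the exemption case-insensitive. A small sketch of the difference, using a hypothetical Param stand-in with only a name attribute:

# Hypothetical stand-in for a framework parameter; only .name matters here.
class Param:
    def __init__(self, name):
        self.name = name

old_filter = lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name
new_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()

params = [Param('bert.layernorm.gamma'), Param('dense.Bias'), Param('dense.weight')]
for p in params:
    # new_filter returns False for any LayerNorm/bias variant, so those
    # parameters are excluded from weight decay regardless of casing.
    print(p.name, old_filter(p), new_filter(p))
# bert.layernorm.gamma  True  False  (old filter wrongly applied decay)
# dense.Bias            True  False  (old filter wrongly applied decay)
# dense.weight          True  True   (weight decay applied, as intended)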
...
@@ -191,8 +191,8 @@ def get_bprop_mul(self):
     mul_func = P.Mul()
 
     def bprop(x, y, out, dout):
-        bc_dx = mul_func(dout, y)
-        bc_dy = mul_func(dout, x)
+        bc_dx = mul_func(y, dout)
+        bc_dy = mul_func(x, dout)
         return binop_grad_common(x, y, bc_dx, bc_dy)
 
     return bprop
...
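Since multiplication is commutative, swapping the operand order in bprop leaves the gradients of z = x * y unchanged: dL/dx = dout * y and dL/dy = dout * x either way. Given the commit title, the reordering is presumably a performance tweak (operand order can influence backend kernel selection or fusion) rather than a semantic fix; that reading is an inference, not stated in the diff. A quick NumPy check, standing in for P.Mul, that both orders produce identical gradients:

import numpy as np

# NumPy elementwise multiply stands in for P.Mul; broadcasting is analogous.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))
y = rng.standard_normal((2, 3))
dout = rng.standard_normal((2, 3))

# Gradient of z = x * y: dz/dx = y and dz/dy = x, each scaled by dout.
bc_dx_old, bc_dy_old = dout * y, dout * x   # operand order before the change
bc_dx_new, bc_dy_new = y * dout, x * dout   # operand order after the change

assert np.allclose(bc_dx_old, bc_dx_new)
assert np.allclose(bc_dy_old, bc_dy_new)
print("gradients identical; only the operand order (a backend detail) changed")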