未验证 提交 bcbb7a97 编写于 作者: Z zyfncg 提交者: GitHub

support selected_rows kernel for multiply in dygraph (#45217)

上级 e31a0a50
......@@ -1860,7 +1860,8 @@
infer_meta :
func : ElementwiseInferMeta
kernel :
func : multiply
func : multiply {dense, dense -> dense},
multiply_sr {selected_rows, dense -> selected_rows}
backward : multiply_grad
- api : nearest_interp
......
......@@ -555,7 +555,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
if need_clip:
clip_input = (clip_var.astype('float16') if g.dtype
== core.VarDesc.VarType.FP16 else clip_var)
new_grad = _C_ops.elementwise_mul(g, clip_input)
new_grad = layers.elementwise_mul(g, clip_input)
params_and_grads.append((p, new_grad))
else:
params_and_grads.append((p, g))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册