未验证 提交 bcbb7a97 编写于 作者: Z zyfncg 提交者: GitHub

support selected_rows kernel for multiply in dygraph (#45217)

上级 e31a0a50
...@@ -1860,7 +1860,8 @@ ...@@ -1860,7 +1860,8 @@
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
kernel : kernel :
func : multiply func : multiply {dense, dense -> dense},
multiply_sr {selected_rows, dense -> selected_rows}
backward : multiply_grad backward : multiply_grad
- api : nearest_interp - api : nearest_interp
......
...@@ -555,7 +555,7 @@ class ClipGradByGlobalNorm(ClipGradBase): ...@@ -555,7 +555,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
if need_clip: if need_clip:
clip_input = (clip_var.astype('float16') if g.dtype clip_input = (clip_var.astype('float16') if g.dtype
== core.VarDesc.VarType.FP16 else clip_var) == core.VarDesc.VarType.FP16 else clip_var)
new_grad = _C_ops.elementwise_mul(g, clip_input) new_grad = layers.elementwise_mul(g, clip_input)
params_and_grads.append((p, new_grad)) params_and_grads.append((p, new_grad))
else: else:
params_and_grads.append((p, g)) params_and_grads.append((p, g))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册