提交 3b84584e 编写于 作者: Z zhongpu 提交者: liym27

fix Variable's gradient api in framework.py, test=develop (#21577)

* fix Variable's gradient api in framework.py, test=develop

* remove namescope, test=develop
上级 4b4a9cc8
...@@ -1090,6 +1090,7 @@ class Variable(object): ...@@ -1090,6 +1090,7 @@ class Variable(object):
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
# example1: return ndarray
x = np.ones([2, 2], np.float32) x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
inputs2 = [] inputs2 = []
...@@ -1104,6 +1105,19 @@ class Variable(object): ...@@ -1104,6 +1105,19 @@ class Variable(object):
loss2.backward(backward_strategy) loss2.backward(backward_strategy)
print(loss2.gradient()) print(loss2.gradient())
# example2: return tuple of ndarray
with fluid.dygraph.guard():
embedding = fluid.dygraph.Embedding(
size=[20, 32],
param_attr='emb.w',
is_sparse=True)
x_data = np.arange(12).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, 3, 1))
x = fluid.dygraph.base.to_variable(x_data)
out = embedding(x)
out.backward()
print(embedding.weight.gradient())
""" """
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册