layers.mul 计算 double grad(二阶梯度)时报 memory_size() 相关的错误
Created by: LDOUBLEV
报错环境: CPU/GPU paddle 1.5.2
问题复现代码:
def test_mul_grad_paddle_vs_torch():
    """Reproduce the double-grad failure of ``fluid.layers.mul``.

    Builds ``y = mul(mul(inp, w1), w2)`` in a Paddle fluid program, takes the
    gradients of ``y`` w.r.t. ``w1`` and ``w2`` (passing ``y`` itself as the
    third argument of ``fluid.gradients``), and then calls ``fluid.gradients``
    a second time on those results to obtain a double gradient w.r.t. ``y``.
    The same computation is repeated with ``torch.matmul`` /
    ``torch.autograd.grad``, and the first-order gradients of both frameworks
    are compared with ``np.testing.assert_allclose``.

    Relies on ``fluid`` (paddle.fluid) and ``np`` (numpy) being bound at
    module level; ``torch`` is imported locally.

    Raises:
        AssertionError: if the first-order gradients of Paddle and PyTorch
            differ by more than ``atol=1e-5`` / ``rtol=1e-5``.
    """
    train_program = fluid.Program()
    start_program = fluid.Program()
    # Swap in fluid.CUDAPlace(0) to run the same reproduction on GPU.
    place = fluid.CPUPlace()

    with fluid.program_guard(train_program, start_program):
        # Fixed seed so the Paddle and PyTorch runs see identical inputs.
        rng = np.random.RandomState(0)
        inp_ = rng.uniform(-1, 1, [3, 2]).astype('float32')
        w1_ = rng.uniform(-1, 1, [2, 5]).astype('float32')
        w2_ = rng.uniform(-1, 1, [5, 4]).astype('float32')
        yg_ = rng.uniform(-1, 1, [2, ]).astype('float32')

        inp = fluid.layers.data('inp', [3, 2], append_batch_size=False)
        w1 = fluid.layers.data('w1', [2, 5], append_batch_size=False)
        w2 = fluid.layers.data('w2', [5, 4], append_batch_size=False)
        yg = fluid.layers.data('yg', [2, ], append_batch_size=False)

        # Data layers must take part in differentiation.
        inp.stop_gradient = False
        w1.stop_gradient = False
        w2.stop_gradient = False
        yg.stop_gradient = False

        y = fluid.layers.mul(fluid.layers.mul(inp, w1), w2)
        f = y
        x = [w1, w2]
        dfdx_f1 = fluid.gradients(f, x, f)

        # Double gradient: differentiate the first-order results w.r.t. f.
        dfdx_x_ = fluid.gradients(
            [dfdx_f1[0] * x[0], dfdx_f1[1] * x[1]], f)
        print(dfdx_x_)

    exe = fluid.Executor(place)
    # Run the startup program that was paired with train_program above; the
    # original ran fluid.default_startup_program(), a different program.
    exe.run(program=start_program)
    compiled_prog = fluid.compiler.CompiledProgram(train_program)
    # Fix: the original fetch list was missing its closing bracket and
    # referenced the undefined name `dfdx_x` instead of `dfdx_x_`.
    res = exe.run(
        compiled_prog,
        feed={'inp': inp_, 'w1': w1_, 'w2': w2_, 'yg': yg_},
        fetch_list=[dfdx_f1[0].name, dfdx_f1[1].name, dfdx_x_[0].name])
    print(res[0], res[1], '\n', res[2])

    # Reference computation with torch.matmul.
    import torch

    def n2t(x):
        """Convert an array-like to a torch tensor via numpy."""
        return torch.from_numpy(np.array(x))

    inp_t, w1_t, w2_t, yg_t = n2t(inp_), n2t(w1_), n2t(w2_), n2t(yg_)
    w1_t.requires_grad = True
    w2_t.requires_grad = True
    yt = torch.matmul(torch.matmul(inp_t, w1_t), w2_t)
    x_t = [w1_t, w2_t]

    # Detached copy of the output used as grad_outputs, so the second
    # autograd.grad call can differentiate w.r.t. it.
    g = yt.detach()
    g.requires_grad = True
    dfdxt_g = torch.autograd.grad(yt, x_t, g, create_graph=True)
    print(dfdxt_g[0].detach().numpy(), '\n', dfdxt_g[1].detach().numpy())

    # NOTE(review): the torch double gradient is computed but never compared
    # against res[2]; presumably because the Paddle side fails first - confirm.
    dfdx_x_t = torch.autograd.grad(dfdxt_g, g, x_t, retain_graph=True)

    np.testing.assert_allclose(
        dfdxt_g[0].detach().numpy(), res[0], atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(
        dfdxt_g[1].detach().numpy(), res[1], atol=1e-5, rtol=1e-5)