Commit 56414c7d authored by songyouwei, committed by hong

move private weight fields to public ones (#21982)

* move private weight properties to public ones
test=develop

* revert changes to FC
test=develop

* fix unittest
test=develop

* fix unittest
test=develop

* fix coverage
test=develop

* fix merged dev
test=develop

* bug fix
test=develop
Parent b852ef73
This diff is collapsed.
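The renamed attributes make dygraph parameters part of the public API: where user code previously reached into a layer through the private _w (or _bias_param) field, it now uses the weight (or bias) property, and the test hunks below are the corresponding call-site updates. A minimal before/after sketch, assuming the name-scope-style fluid.dygraph constructors used elsewhere in these tests; the Embedding arguments are illustrative placeholders, not taken from the patch:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # Illustrative embedding layer: size = [vocab_size, embedding_dim].
    emb = fluid.dygraph.Embedding("emb", size=[100, 32])
    ids = fluid.dygraph.to_variable(np.array([[1], [3], [7]], dtype="int64"))
    loss = fluid.layers.mean(emb(ids))
    loss.backward()

    # Before this commit the parameter was only reachable through a
    # private field:           emb._w.gradient()
    # After it, the public property is used instead:
    grad = emb.weight.gradient()

    # Other layers follow the same pattern, e.g. a Conv2D built with
    # bias_attr=False now reports conv2d.bias is None rather than
    # conv2d._bias_param is None (see the TestLayer hunk at the end).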
@@ -153,8 +153,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
             v2 = fluid.dygraph.to_variable(value2)
             loss = case1(v1, v2)
             loss.backward()
-            self.assertTrue(case1.fc2._w._grad_ivar() is not None)
-            self.assertTrue(case1.fc1._w._grad_ivar() is not None)
+            self.assertTrue(case1.fc2.weight._grad_ivar() is not None)
+            self.assertTrue(case1.fc1.weight._grad_ivar() is not None)
     def test_auto_prune2(self):
         with fluid.dygraph.guard():
@@ -166,8 +166,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
             loss = case2(v1, v2)
             loss.backward()
-            self.assertTrue(case2.fc2._w._grad_ivar() is None)
-            self.assertTrue(case2.fc1._w._grad_ivar() is not None)
+            self.assertTrue(case2.fc2.weight._grad_ivar() is None)
+            self.assertTrue(case2.fc1.weight._grad_ivar() is not None)
     def test_auto_prune3(self):
         with fluid.dygraph.guard():
@@ -178,7 +178,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             v2 = fluid.dygraph.to_variable(value2)
             loss, part2 = case3(v1, v2, 1)
             loss.backward()
-            self.assertTrue(case3.fc._w._grad_ivar() is not None)
+            self.assertTrue(case3.fc.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 0).all())
     def test_auto_prune4(self):
@@ -190,7 +190,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             v2 = fluid.dygraph.to_variable(value2)
             loss, part2 = case4(v1, v2, 1)
             part2.backward()
-            self.assertTrue(case4.fc._w._grad_ivar() is not None)
+            self.assertTrue(case4.fc.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 1).all())
     def test_auto_prune5(self):
@@ -202,7 +202,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             v2 = fluid.dygraph.to_variable(value2)
             loss, part1, part2 = case4(v1, v2, 2)
             part1.backward()
-            self.assertTrue(case4.fc._w._grad_ivar() is not None)
+            self.assertTrue(case4.fc.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 0).all())
     def test_auto_prune6(self):
@@ -220,7 +220,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             out1.stop_gradient = True
             out = fluid.layers.concat(input=[out1, out2, c], axis=1)
             out.backward()
-            self.assertTrue((fc._w.gradient() == 0).all())
+            self.assertTrue((fc.weight.gradient() == 0).all())
             self.assertTrue((out1.gradient() == 0).all())
     def test_auto_prune7(self):
@@ -239,7 +239,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             out = fluid.layers.concat(input=[out1, out2, c], axis=1)
             backward_strategy = fluid.dygraph.BackwardStrategy()
             out.backward(backward_strategy)
-            self.assertTrue((fc._w.gradient() == 0).all())
+            self.assertTrue((fc.weight.gradient() == 0).all())
             self.assertTrue((out1.gradient() == 0).all())
     def test_auto_prune8(self):
@@ -253,17 +253,17 @@ class TestImperativeAutoPrune(unittest.TestCase):
             b = fluid.dygraph.to_variable(value1)
             c = fluid.dygraph.to_variable(value2)
             out1 = fc(a)
-            fc_origin = fc._w.numpy()
+            fc_origin = fc.weight.numpy()
             out2 = fc2(out1)
-            fc2_origin = fc2._w.numpy()
-            fc2._w.stop_gradient = True
+            fc2_origin = fc2.weight.numpy()
+            fc2.weight.stop_gradient = True
             out2.backward()
             optimizer = fluid.optimizer.SGD(
                 learning_rate=0.003,
                 parameter_list=(fc.parameters() + fc2.parameters()))
             optimizer.minimize(out2)
-            self.assertTrue(np.array_equal(fc2_origin, fc2._w.numpy()))
-            self.assertFalse(np.array_equal(fc_origin, fc._w.numpy()))
+            self.assertTrue(np.array_equal(fc2_origin, fc2.weight.numpy()))
+            self.assertFalse(np.array_equal(fc_origin, fc.weight.numpy()))
     def test_auto_prune9(self):
         with fluid.dygraph.guard():
@@ -276,19 +276,19 @@ class TestImperativeAutoPrune(unittest.TestCase):
             b = fluid.dygraph.to_variable(value1)
             c = fluid.dygraph.to_variable(value2)
             out1 = fc(a)
-            fc_origin = fc._w.numpy()
+            fc_origin = fc.weight.numpy()
             out2 = fc2(out1)
-            fc2_origin = fc2._w.numpy()
+            fc2_origin = fc2.weight.numpy()
             out2.stop_gradient = True
             out2.backward()
             optimizer = fluid.optimizer.SGD(
                 learning_rate=0.003,
                 parameter_list=(fc.parameters() + fc2.parameters()))
             optimizer.minimize(out2)
-            self.assertTrue(np.array_equal(fc2_origin, fc2._w.numpy()))
-            self.assertTrue(np.array_equal(fc_origin, fc._w.numpy()))
+            self.assertTrue(np.array_equal(fc2_origin, fc2.weight.numpy()))
+            self.assertTrue(np.array_equal(fc_origin, fc.weight.numpy()))
             try:
-                fc2._w.gradient()
+                fc2.weight.gradient()
             except ValueError as e:
                 assert type(e) == ValueError
@@ -309,7 +309,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             backward_strategy = fluid.dygraph.BackwardStrategy()
             backward_strategy.sort_sum_gradient = True
             out.backward(backward_strategy)
-            self.assertTrue((fc._w.gradient() == 0).all())
+            self.assertTrue((fc.weight.gradient() == 0).all())
             self.assertTrue((out1.gradient() == 0).all())
     def test_auto_prune_with_optimizer(self):
@@ -336,10 +336,10 @@ class TestImperativeAutoPrune(unittest.TestCase):
             loss.backward()
             _, params_grads = optimizer.minimize(loss, grad_clip=grad_clip)
             for items in params_grads:
-                assert items[0].name is not model.embed1._w.name
-                assert items[0].name is not model.fc1._w.name
-            assert model.embed1._w._grad_ivar() is None
-            assert model.fc1._w._grad_ivar() is None
+                assert items[0].name is not model.embed1.weight.name
+                assert items[0].name is not model.fc1.weight.name
+            assert model.embed1.weight._grad_ivar() is None
+            assert model.fc1.weight._grad_ivar() is None
         with fluid.dygraph.guard(place):
             model = MyLayer2("mylayer", vocab_size, size)
@@ -355,10 +355,10 @@ class TestImperativeAutoPrune(unittest.TestCase):
             loss.backward()
             optimizer.minimize(loss, grad_clip=grad_clip)
             for items in params_grads:
-                assert items[0].name is not model.embed1._w.name
-                assert items[0].name is not model.fc1._w.name
-            assert model.embed1._w._grad_ivar() is None
-            assert model.fc1._w._grad_ivar() is None
+                assert items[0].name is not model.embed1.weight.name
+                assert items[0].name is not model.fc1.weight.name
+            assert model.embed1.weight._grad_ivar() is None
+            assert model.fc1.weight._grad_ivar() is None
     def test_case2_prune_no_grad_branch(self):
         with fluid.dygraph.guard():
@@ -369,8 +369,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
             case3 = AutoPruneLayer2("l2")
             loss = case3(v1, v2)
             loss.backward()
-            self.assertTrue(case3.fc2._w._grad_ivar() is None)
-            self.assertTrue(case3.fc._w._grad_ivar() is not None)
+            self.assertTrue(case3.fc2.weight._grad_ivar() is None)
+            self.assertTrue(case3.fc.weight._grad_ivar() is not None)
     def test_case2_prune_no_grad_branch(self):
         with fluid.dygraph.guard():
@@ -381,8 +381,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
             case3 = AutoPruneLayer2("l2")
             loss = case3(v1, v2)
             loss.backward()
-            self.assertTrue(case3.fc2._w._grad_ivar() is None)
-            self.assertTrue(case3.fc._w._grad_ivar() is not None)
+            self.assertTrue(case3.fc2.weight._grad_ivar() is None)
+            self.assertTrue(case3.fc.weight._grad_ivar() is not None)
     def test_case3_prune_no_grad_branch2(self):
         with fluid.dygraph.guard():
@@ -395,7 +395,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             out = fluid.layers.one_hot(input=label, depth=100)
             loss = fluid.layers.mean(out)
             loss.backward()
-            self.assertTrue(fc._w._grad_ivar() is None)
+            self.assertTrue(fc.weight._grad_ivar() is None)
     def test_case4_with_no_grad_op_maker(self):
         with fluid.dygraph.guard():
......
@@ -342,7 +342,7 @@ class TestImperative(unittest.TestCase):
             out = mlp(var_inp)
             dy_out = out.numpy()
             out.backward()
-            dy_grad = mlp._fc1._w.gradient()
+            dy_grad = mlp._fc1.weight.gradient()
         with fluid.dygraph.guard():
             var_inp2 = fluid.dygraph.base.to_variable(np_inp)
@@ -352,7 +352,7 @@ class TestImperative(unittest.TestCase):
             backward_strategy = fluid.dygraph.BackwardStrategy()
             backward_strategy.sort_sum_gradient = True
             out2.backward(backward_strategy)
-            dy_grad2 = mlp2._fc1._w.gradient()
+            dy_grad2 = mlp2._fc1.weight.gradient()
         with new_program_scope():
             inp = fluid.layers.data(
@@ -360,7 +360,7 @@ class TestImperative(unittest.TestCase):
             mlp = MLP("mlp")
             out = mlp(inp)
             param_grads = fluid.backward.append_backward(
-                out, parameter_list=[mlp._fc1._w.name])[0]
+                out, parameter_list=[mlp._fc1.weight.name])[0]
             exe = fluid.Executor(fluid.CPUPlace(
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
             exe.run(fluid.default_startup_program())
......
@@ -59,7 +59,7 @@ class SimpleNet(fluid.Layer):
         x_emb = self.embedding(input)
         projection = fluid.layers.matmul(
             x_emb, fluid.layers.transpose(
-                self.embedding._w, perm=[1, 0]))
+                self.embedding.weight, perm=[1, 0]))
         projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
         projection = fluid.layers.reshape(
             projection, shape=[-1, self.vocab_size])
......
@@ -61,7 +61,7 @@ class TestSimpleNet(unittest.TestCase):
                 input_emb, emb = simplenet(input)
                 try:
-                    emb._w.gradient()
+                    emb.weight.gradient()
                 except ValueError as e:
                     assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str(
                         e)
@@ -73,11 +73,11 @@ class TestSimpleNet(unittest.TestCase):
                 input_emb.backward(backward_strategy)
                 adam.minimize(input_emb)  # grad_clip=grad_clip
-                emb._w.gradient()
+                emb.weight.gradient()
                 emb.clear_gradients()
                 try:
-                    emb._w.gradient()
+                    emb.weight.gradient()
                 except ValueError as e:
                     assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str(
                         e)
@@ -108,7 +108,7 @@ class TestSimpleNet(unittest.TestCase):
                 input_emb, emb = simplenet(input)
                 try:
-                    emb._w.gradient()
+                    emb.weight.gradient()
                 except ValueError as e:
                     assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str(
                         e)
@@ -120,11 +120,11 @@ class TestSimpleNet(unittest.TestCase):
                 input_emb.backward(backward_strategy)
                 adam.minimize(input_emb, grad_clip=grad_clip)
-                emb._w.gradient()
+                emb.weight.gradient()
                 emb.clear_gradients()
                 try:
-                    emb._w.gradient()
+                    emb.weight.gradient()
                 except ValueError as e:
                     assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str(
                         e)
......
@@ -66,7 +66,7 @@ class SimpleNet(fluid.Layer):
         fc = fluid.layers.elementwise_add(fc, self.softmax_bias)
         projection = fluid.layers.matmul(
             fc, fluid.layers.transpose(
-                self.embedding._w, perm=[1, 0]))
+                self.embedding.weight, perm=[1, 0]))
         projection = fluid.layers.reshape(
             projection, shape=[-1, self.vocab_size])
         loss = fluid.layers.softmax_with_cross_entropy(
......
@@ -843,7 +843,7 @@ class WrapDecoderLayer(Layer):
         if self._weight_sharing:
             predict = fluid.layers.matmul(
                 x=dec_output_reshape,
-                y=self._prepare_decoder_layer._input_emb._w,
+                y=self._prepare_decoder_layer._input_emb.weight,
                 transpose_y=True)
         else:
             predict = self._fc(dec_output_reshape)
@@ -917,7 +917,7 @@ class TransFormer(Layer):
             is_sparse=is_sparse)
         if weight_sharing:
-            self._wrap_decoder_layer._prepare_decoder_layer._input_emb._w = self._wrap_encoder_layer._prepare_encoder_layer._input_emb._w
+            self._wrap_decoder_layer._prepare_decoder_layer._input_emb.weight = self._wrap_encoder_layer._prepare_encoder_layer._input_emb.weight
     def forward(self, enc_inputs, dec_inputs, label, weights):
         enc_output = self._wrap_encoder_layer(enc_inputs)
......
@@ -368,7 +368,7 @@ class TestLayer(LayerTest):
                 filter_size=[2, 2],
                 bias_attr=False)
             dy_ret = conv2d(base.to_variable(images))
-            self.assertTrue(conv2d._bias_param is None)
+            self.assertTrue(conv2d.bias is None)
         self.assertTrue(np.allclose(static_ret, dy_ret_value))
         self.assertTrue(np.allclose(static_ret, static_ret2))
......