提交 56b5d147 编写于 作者: G guofei 提交者: Huihuang Zheng

Fix the error of init variable in StaticRNN when stop_gradient=ON (#21118)

上级 3c98ec90
...@@ -635,11 +635,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase { ...@@ -635,11 +635,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
RecurrentBase::kOutputs); RecurrentBase::kOutputs);
// In some case the kInitialStates is empty. // In some case the kInitialStates is empty.
if (ctx->HasInputs(RecurrentBase::kInitialStates)) { if (ctx->HasInputs(RecurrentBase::kInitialStates) &&
PADDLE_ENFORCE_EQ(ctx->HasOutputs(framework::GradVarName( ctx->HasOutputs(
RecurrentBase::kInitialStates)), framework::GradVarName(RecurrentBase::kInitialStates))) {
true, "The output of(%s) should not be empty.",
framework::GradVarName(RecurrentBase::kInitialStates));
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates), ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates),
ctx->GetInputsDim(RecurrentBase::kInitialStates)); ctx->GetInputsDim(RecurrentBase::kInitialStates));
} }
......
...@@ -123,7 +123,8 @@ class RecurrentOpTest1(unittest.TestCase): ...@@ -123,7 +123,8 @@ class RecurrentOpTest1(unittest.TestCase):
def setUp(self): def setUp(self):
self.setup_program() self.setup_program()
self.data_field = {"x", "h_boot"} self.feed_data_field = {"x", "h_boot"}
self.grad_data_field = self.feed_data_field
self.input_shape = (self.sent_len, self.batch_size, self.input_dim) self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim) self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
...@@ -161,7 +162,7 @@ class RecurrentOpTest1(unittest.TestCase): ...@@ -161,7 +162,7 @@ class RecurrentOpTest1(unittest.TestCase):
def forward(self): def forward(self):
self.feed_map = { self.feed_map = {
x: create_tensor(getattr(self.py_rnn, x), self.place) x: create_tensor(getattr(self.py_rnn, x), self.place)
for x in self.data_field for x in self.feed_data_field
} }
exe = Executor(self.place) exe = Executor(self.place)
out = exe.run(self.main_program, out = exe.run(self.main_program,
...@@ -173,11 +174,11 @@ class RecurrentOpTest1(unittest.TestCase): ...@@ -173,11 +174,11 @@ class RecurrentOpTest1(unittest.TestCase):
def backward(self): def backward(self):
self.feed_map = { self.feed_map = {
x: create_tensor(getattr(self.py_rnn, x), self.place) x: create_tensor(getattr(self.py_rnn, x), self.place)
for x in self.data_field for x in self.feed_data_field
} }
fetch_list = [ fetch_list = [
self.main_program.global_block().var(grad_var_name(x)) self.main_program.global_block().var(grad_var_name(x))
for x in self.data_field for x in self.grad_data_field
] ]
exe = Executor(self.place) exe = Executor(self.place)
...@@ -195,7 +196,7 @@ class RecurrentOpTest1(unittest.TestCase): ...@@ -195,7 +196,7 @@ class RecurrentOpTest1(unittest.TestCase):
ana_grad = [np.array(x) for x in self.backward()] ana_grad = [np.array(x) for x in self.backward()]
num_grad = self.get_numerical_gradient() num_grad = self.get_numerical_gradient()
for idx, name in enumerate(self.data_field): for idx, name in enumerate(self.grad_data_field):
self.assertEqual(num_grad[idx].shape, ana_grad[idx].shape) self.assertEqual(num_grad[idx].shape, ana_grad[idx].shape)
self.assertTrue( self.assertTrue(
np.isclose( np.isclose(
...@@ -212,7 +213,7 @@ class RecurrentOpTest1(unittest.TestCase): ...@@ -212,7 +213,7 @@ class RecurrentOpTest1(unittest.TestCase):
def get_numerical_gradient(self, delta=0.005): def get_numerical_gradient(self, delta=0.005):
dloss_dout = 1.0 dloss_dout = 1.0
feed_list = [getattr(self.py_rnn, x) for x in self.data_field] feed_list = [getattr(self.py_rnn, x) for x in self.grad_data_field]
grad_list = [np.zeros_like(x) for x in feed_list] grad_list = [np.zeros_like(x) for x in feed_list]
for feed, grad in zip(feed_list, grad_list): for feed, grad in zip(feed_list, grad_list):
for f, g in np.nditer([feed, grad], op_flags=['readwrite']): for f, g in np.nditer([feed, grad], op_flags=['readwrite']):
...@@ -253,7 +254,8 @@ class RecurrentOpTest2(RecurrentOpTest1): ...@@ -253,7 +254,8 @@ class RecurrentOpTest2(RecurrentOpTest1):
def setUp(self): def setUp(self):
self.setup_program() self.setup_program()
self.data_field = {"x", "h_boot", "W", "U"} self.feed_data_field = {"x", "h_boot", "W", "U"}
self.grad_data_field = self.feed_data_field
self.input_shape = (self.sent_len, self.batch_size, self.input_dim) self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim) self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
...@@ -352,7 +354,8 @@ class RecurrentOpMultipleMemoryTest(RecurrentOpTest1): ...@@ -352,7 +354,8 @@ class RecurrentOpMultipleMemoryTest(RecurrentOpTest1):
def setUp(self): def setUp(self):
self.setup_program() self.setup_program()
self.data_field = {"x", "h_boot1", "h_boot2"} self.feed_data_field = {"x", "h_boot1", "h_boot2"}
self.grad_data_field = self.feed_data_field
self.input_shape = (self.sent_len, self.batch_size, self.input_dim) self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim) self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
...@@ -435,7 +438,8 @@ class RecurrentOpNoMemBootTest(RecurrentOpTest1): ...@@ -435,7 +438,8 @@ class RecurrentOpNoMemBootTest(RecurrentOpTest1):
def setUp(self): def setUp(self):
self.setup_program() self.setup_program()
self.data_field = {"x"} self.feed_data_field = {"x"}
self.grad_data_field = self.feed_data_field
self.input_shape = (self.sent_len, self.batch_size, self.input_dim) self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim) self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
...@@ -535,7 +539,8 @@ class RecurrentOpSubBlockTest(RecurrentOpTest1): ...@@ -535,7 +539,8 @@ class RecurrentOpSubBlockTest(RecurrentOpTest1):
def setUp(self): def setUp(self):
self.setup_program() self.setup_program()
self.data_field = {"x", "emb", "w1", "w2"} self.feed_data_field = {"x", "emb", "w1", "w2"}
self.grad_data_field = self.feed_data_field
self.input_shape = (self.sent_len, self.batch_size, self.input_dim) self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim) self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
...@@ -602,5 +607,76 @@ class RecurrentOpSubBlockTest(RecurrentOpTest1): ...@@ -602,5 +607,76 @@ class RecurrentOpSubBlockTest(RecurrentOpTest1):
return rnn() return rnn()
class RecurrentOpStopGradientTest(RecurrentOpTest1):
"""
Test RNNOp with stop_gradient = True
equation:
h_t = \sigma (W x_t + U h_{t-1})
weights:
- W
- U
vars:
- x
memories:
- h
output:
- h
"""
input_dim = 2
batch_size = 10
sent_len = 2
def setUp(self):
self.setup_program()
self.feed_data_field = {"x", "h_boot", "W", "U"}
self.grad_data_field = {"x", "W", "U"}
self.input_shape = (self.sent_len, self.batch_size, self.input_dim)
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)
with fluid.program_guard(self.main_program, self.startup_program):
self.output = layers.mean(self.create_rnn_op())
def create_rnn_op(self):
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
dtype="float32",
name="x",
append_batch_size=False)
x.stop_gradient = False
h_boot = layers.data(
shape=[self.input_dim], dtype="float32", name="h_boot")
h_boot.stop_gradient = True
rnn = layers.StaticRNN()
with rnn.step():
h_pre = rnn.memory(init=h_boot) # init doesn't have gradient
x_t = rnn.step_input(x)
temp_l = layers.fc(
input=x_t,
size=self.input_dim,
param_attr=ParamAttr(
name="W",
initializer=fluid.initializer.ConstantInitializer(1.0)),
bias_attr=False)
temp_r = layers.fc(
input=h_pre,
size=self.input_dim,
param_attr=ParamAttr(
name="U",
initializer=fluid.initializer.ConstantInitializer(0.0)),
bias_attr=False)
h = layers.sigmoid(x=layers.elementwise_add(temp_l, temp_r))
rnn.update_memory(h_pre, h)
rnn.output(h)
return rnn()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册