diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index fc58d2e1c717304973ad5adb2f7ecaed06d9e179..0dbec7d672967e8051ba9e8bacbd74636b4ce9dd 100644
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -933,6 +933,31 @@ def _append_backward_ops_(block,
             cb(block=target_block, context=grad_to_var)
 
 
+def _is_grad_var_(var_name):
+    return core.grad_var_suffix() in var_name
+
+
+# Find the op that holds sub_block as its "sub_block" attr
+def _find_parent_op_(sub_block):
+    sub_block_id = sub_block.idx
+
+    if sub_block_id == 0:
+        return None
+
+    program = sub_block.program
+    for block_id in six.moves.range(program.num_blocks):
+        block_desc = program.block(block_id).desc
+        for op_idx in six.moves.range(block_desc.op_size()):
+            op = block_desc.op(op_idx)
+            if op.has_attr("sub_block") and op._block_attr_id(
+                    "sub_block") == sub_block_id:
+                return op
+
+    # NOTE(paddle-dev): When optimizer is added in conditional block,
+    # sub_block may not be found.
+    return None
+
+
 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
     """
     Create new variables required by backward pass.
@@ -948,11 +973,73 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
         key(str): forward variable name
         val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable
     """
+    ops_to_remove = []
+    '''
+    NOTE(paddle-dev): while_grad op may hold some inputs which are not found
+    in the parent/forward block, and these inputs are also the outputs of the
+    while_grad op. Such inputs are the recursive outputs inside the while_grad
+    op, so they should be considered as "already created" when scanning the
+    inner ops of while_grad ops.
+    '''
+    parent_op = _find_parent_op_(block)
+    parent_op_vars = []
+    if parent_op is not None:
+        input_args = parent_op.input_arg_names()
+        output_args = parent_op.output_arg_names()
+        for in_arg in input_args:
+            if in_arg in output_args:
+                parent_op_vars.append(in_arg)
+
     for op_idx in range(start_op_idx, block.desc.op_size()):
         op_desc = block.desc.op(op_idx)
         if op_desc.has_attr("sub_block"):
             sub_block = block.program.block(op_desc._block_attr_id("sub_block"))
             _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
+
+        grad_var_ins = [
+            var for var in op_desc.input_arg_names() if _is_grad_var_(var)
+        ]
+        grad_var_outs = [
+            var for var in op_desc.output_arg_names() if _is_grad_var_(var)
+        ]
+
+        inputs = [
+            var for var in op_desc.input_arg_names()
+            if var != core.empty_var_name()
+        ]
+        outputs = [
+            var for var in op_desc.output_arg_names()
+            if var != core.empty_var_name()
+        ]
+
+        # If the outputs of the grad op are empty, just remove it
+        if not outputs:
+            ops_to_remove.append(op_idx)
+            continue
+        else:
+            '''
+            If the outputs are not empty and there are grad inputs, check
+            whether any of those grad inputs exists. If none does, remove the op.
+            '''
+            if grad_var_ins:
+                existing_grad_var_ins = [
+                    var for var in grad_var_ins
+                    if block.desc.has_var_recursive(cpt.to_bytes(var)) or var in
+                    parent_op_vars
+                ]
+                if not existing_grad_var_ins:
+                    '''
+                    FIXME(paddle-dev, zengjinle): rnn_memory_helper_grad is used
+                    in recurrent op. The input of this op does not even exist in
+                    the program! Therefore, any dependency analysis would not
+                    work for this op! If we do not add the following check, this
+                    op would be pruned, and the calculation result would be wrong.
+                    Maybe we should re-design this op later...
+                    '''
+                    if op_desc.type() not in ['rnn_memory_helper_grad']:
+                        ops_to_remove.append(op_idx)
+                        continue
+
         new_vars = set()
         # create new gradient variables
         for grad_var_name in op_desc.output_arg_names():
@@ -972,6 +1059,9 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
             if arg in new_vars:
                 _infer_var_data_type_shape_(arg, block)
 
+    for op_idx in reversed(ops_to_remove):
+        block.desc._remove_op(op_idx, op_idx + 1)
+
 
 def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
     var_map = copy.copy(target_grad_map)
diff --git a/python/paddle/fluid/tests/unittests/test_dynamic_rnn_stop_gradient.py b/python/paddle/fluid/tests/unittests/test_dynamic_rnn_stop_gradient.py
new file mode 100644
index 0000000000000000000000000000000000000000..243ad4c082ab07507c748d75ee7bcf290bd44518
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dynamic_rnn_stop_gradient.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+import unittest
+
+
+def build_and_run_program(place, batch_size, beam_size, stop_gradient=False):
+    fluid.default_startup_program().random_seed = 1
+    fluid.default_main_program().random_seed = 1
+    np.random.seed(2)
+
+    x = layers.assign(
+        np.random.rand(batch_size, beam_size, 32).astype("float32"))
+    indices = fluid.data(shape=[None, beam_size], dtype="int64", name="indices")
+    step_idx = layers.fill_constant(
+        shape=[1], dtype="int64", value=0, force_cpu=True)
+    max_len = layers.fill_constant(
+        shape=[1], dtype="int64", value=10, force_cpu=True)
+    cond = layers.less_than(x=step_idx, y=max_len)
+    while_op = layers.While(cond)
+    scores = layers.array_write(x, step_idx)
+    with while_op.block():
+        bs = layers.cast(layers.shape(x)[0], "int64")
+        for _ in range(20):
+            bs = layers.cast(bs, 'int64')
+        bs.stop_gradient = stop_gradient
+        batch_pos = layers.expand(
+            layers.unsqueeze(
+                layers.range(
+                    0, bs, 1, dtype=bs.dtype), [1]), [1, beam_size])
+        topk_coordinates = layers.stack([batch_pos, indices], axis=2)
+        topk_coordinates.stop_gradient = stop_gradient
+        score = layers.gather_nd(x, topk_coordinates)
+        layers.increment(x=step_idx, value=1.0, in_place=True)
+        layers.array_write(score, i=step_idx, array=scores)
+        length_cond = layers.less_than(x=step_idx, y=max_len)
+        layers.assign(length_cond, cond)
+
+    out = layers.tensor_array_to_tensor(scores, axis=0, use_stack=True)[0]
+    loss = layers.reduce_mean(out)
+    opt = fluid.optimizer.Adam(0.01)
+    opt.minimize(loss)
+    exe = fluid.Executor(place)
+    data = np.random.random_integers(
+        low=0, high=beam_size - 1, size=(batch_size, beam_size)).astype("int64")
+    loss_val, = exe.run(feed={"indices": data}, fetch_list=[loss])
+
+    return loss_val
+
+
+class TestDynRNNStopGradient(unittest.TestCase):
+    def setUp(self):
+        self.batch_size = 20
+        self.beam_size = 64
+
+    def run_main(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            with fluid.scope_guard(fluid.Scope()):
+                value1 = build_and_run_program(place, self.batch_size,
+                                               self.beam_size, False)
+                value2 = build_and_run_program(place, self.batch_size,
+                                               self.beam_size, True)
+
+                self.assertTrue(np.array_equal(value1, value2))
+
+    def test_check_main(self):
+        places = [fluid.CPUPlace()]
+        if fluid.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for p in places:
+            self.run_main(p)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_embedding_id_stop_gradient.py b/python/paddle/fluid/tests/unittests/test_embedding_id_stop_gradient.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a562dc14650a74ee6f76fa3d8c5f207da6475d6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_embedding_id_stop_gradient.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle.fluid as fluid
+import six
+import unittest
+
+
+class TestEmbeddingIdStopGradientBase(unittest.TestCase):
+    def setUp(self):
+        self.reshape_times = 1
+        self.iteration = 10
+
+    def get_places(self):
+        places = [fluid.CPUPlace()]
+        if fluid.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        return places
+
+    def test_check_grad(self):
+        for p in self.get_places():
+            grad_value1 = self.run_program(p, stop_gradient=False)
+            grad_value2 = self.run_program(p, stop_gradient=True)
+            self.assertTrue(np.array_equal(grad_value1, grad_value2))
+
+    def run_program(self, place, stop_gradient=False):
+        startup_program = fluid.Program()
+        main_program = fluid.Program()
+
+        np.random.seed(1)
+        startup_program.random_seed = 1
+        main_program.random_seed = 1
+
+        scope = fluid.Scope()
+        with fluid.program_guard(main_program, startup_program):
+            with fluid.scope_guard(scope):
+                x_1 = fluid.data(name='x1', shape=[4, 1], dtype='int64')
+                x_2 = fluid.data(name='x2', shape=[4, 1], dtype='int64')
+                x = fluid.layers.concat([x_1, x_2], axis=-1)
+
+                for _ in six.moves.range(self.reshape_times):
+                    x = fluid.layers.reshape(x, [-1, 1])
+
+                x.stop_gradient = stop_gradient
+
+                emb = fluid.embedding(x, size=[10, 32], dtype='float32')
+                avg_cost = fluid.layers.mean(emb, name='mean_loss')
+                optim = fluid.optimizer.SGD(learning_rate=0.001)
+                optim.minimize(avg_cost)
+
+                exe = fluid.Executor(place)
+                exe.run(startup_program)
+
+                x1_data = np.random.randint(0, 9, x_1.shape).astype('int64')
+                x2_data = np.random.randint(0, 9, x_2.shape).astype('int64')
+
+                fetch_val = None
+                for _ in six.moves.range(self.iteration):
+                    fetch_val = exe.run(
+                        feed={x_1.name: x1_data,
+                              x_2.name: x2_data},
+                        fetch_list=[emb])[0]
+
+        return fetch_val
+
+
+class TestEmbeddingIdStopGradient2(TestEmbeddingIdStopGradientBase):
+    def setUp(self):
+        self.reshape_times = 100
+        self.iteration = 10
+
+
+if __name__ == '__main__':
+    unittest.main()
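
The pruning rule that the backward.py hunks add to _append_backward_vars_ can be summarized by the standalone sketch below. This is only an illustration of the rule, not Paddle code: the Op record, prune_useless_grad_ops, and the variable sets passed in are hypothetical stand-ins for OpDesc/Block; GRAD_SUFFIX mimics core.grad_var_suffix(). A grad op is dropped when it has no outputs, or when it consumes grad inputs none of which exists in the block (recursively) or among the parent while_grad op's recursive in/out variables, with rnn_memory_helper_grad exempted as in the patch.

from collections import namedtuple

# Illustrative stand-ins for OpDesc/Block; not Paddle APIs.
Op = namedtuple("Op", ["type", "inputs", "outputs"])

GRAD_SUFFIX = "@GRAD"  # stand-in for core.grad_var_suffix()


def prune_useless_grad_ops(ops, block_vars, parent_io_vars):
    """Drop grad ops that can never produce a meaningful result.

    ops            -- list of Op records in program order
    block_vars     -- variable names visible in this block (recursively)
    parent_io_vars -- names that are both input and output of the parent
                      while_grad-like op (treated as already created)
    """
    kept = []
    for op in ops:
        grad_ins = [v for v in op.inputs if GRAD_SUFFIX in v]

        # Rule 1: a grad op with no outputs is useless.
        if not op.outputs:
            continue

        # Rule 2: if it consumes grad inputs but none of them exists anywhere,
        # it can never run, so drop it. rnn_memory_helper_grad is exempted,
        # mirroring the patch, because its inputs only appear at run time.
        if grad_ins and op.type != "rnn_memory_helper_grad":
            if not any(v in block_vars or v in parent_io_vars
                       for v in grad_ins):
                continue

        kept.append(op)
    return kept


if __name__ == "__main__":
    ops = [
        Op("gather_nd_grad", ["x", "ids", "y@GRAD"], ["x@GRAD"]),
        # "bs@GRAD" is never created because bs.stop_gradient was set.
        Op("cast_grad", ["bs@GRAD"], ["bs_int@GRAD"]),
    ]
    kept = prune_useless_grad_ops(ops, {"x", "ids", "y@GRAD"}, set())
    print([op.type for op in kept])  # ['gather_nd_grad']

In the patch itself, the matching op indices are collected in ops_to_remove and deleted via block.desc._remove_op in reverse order, so earlier indices stay valid while ops are removed; the two new unit tests then check that setting stop_gradient on intermediate variables no longer changes the computed result.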