You need to sign in or sign up before continuing.
未验证 提交 039bb505 编写于 作者: Z Zeng Jinle 提交者: GitHub

Polish backward.py to prune more ops (#22246)

* polish backward prune, test=develop

* fix control flow op bug, test=develop

* add some unittests, test=develop

* fix unittest args, test=develop

* follow huihuang's comments, test=develop
上级 7b2c0993
...@@ -933,6 +933,31 @@ def _append_backward_ops_(block, ...@@ -933,6 +933,31 @@ def _append_backward_ops_(block,
cb(block=target_block, context=grad_to_var) cb(block=target_block, context=grad_to_var)
def _is_grad_var_(var_name):
return core.grad_var_suffix() in var_name
# Find the op who holds the sub_block as its "sub_block" attr
def _find_parent_op_(sub_block):
sub_block_id = sub_block.idx
if sub_block_id == 0:
return None
program = sub_block.program
for block_id in six.moves.range(program.num_blocks):
block_desc = program.block(block_id).desc
for op_idx in six.moves.range(block_desc.op_size()):
op = block_desc.op(op_idx)
if op.has_attr("sub_block") and op._block_attr_id(
"sub_block") == sub_block_id:
return op
# NOTE(paddle-dev): When optimizer is added in conditional block,
# sub_block may not be found.
return None
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
""" """
Create new variables required by backward pass. Create new variables required by backward pass.
...@@ -948,11 +973,73 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): ...@@ -948,11 +973,73 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
key(str): forward variable name key(str): forward variable name
val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable
""" """
ops_to_remove = []
'''
NOTE(paddle-dev): while_grad op may hold some inputs which are not found
in the parent/forward block, and they are also the outputs of while_grad
op. These kinds of inputs are the recursive outputs inside while_grad op.
They should be considered as "already created" when scanning the inner
ops of while_grad ops.
'''
parent_op = _find_parent_op_(block)
parent_op_vars = []
if parent_op is not None:
input_args = parent_op.input_arg_names()
output_args = parent_op.output_arg_names()
for in_arg in input_args:
if in_arg in output_args:
parent_op_vars.append(in_arg)
for op_idx in range(start_op_idx, block.desc.op_size()): for op_idx in range(start_op_idx, block.desc.op_size()):
op_desc = block.desc.op(op_idx) op_desc = block.desc.op(op_idx)
if op_desc.has_attr("sub_block"): if op_desc.has_attr("sub_block"):
sub_block = block.program.block(op_desc._block_attr_id("sub_block")) sub_block = block.program.block(op_desc._block_attr_id("sub_block"))
_append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map) _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
grad_var_ins = [
var for var in op_desc.input_arg_names() if _is_grad_var_(var)
]
grad_var_outs = [
var for var in op_desc.output_arg_names() if _is_grad_var_(var)
]
inputs = [
var for var in op_desc.input_arg_names()
if var != core.empty_var_name()
]
outputs = [
var for var in op_desc.output_arg_names()
if var != core.empty_var_name()
]
# If the outputs of grad op is empty, just remove it
if not outputs:
ops_to_remove.append(op_idx)
continue
else:
'''
If the output is not empty and there is any grad input, find
whether there is any existing input. If not, just remove it.
'''
if grad_var_ins:
existing_grad_var_ins = [
var for var in grad_var_ins
if block.desc.has_var_recursive(cpt.to_bytes(var)) or var in
parent_op_vars
]
if not existing_grad_var_ins:
'''
FIXME(paddle-dev, zengjinle): rnn_memory_helper_grad is used
in recurrent op. The input of this op does not even exist in
the program! Therefore, any dependency analysis would not
work to this op! If I do not add the following code, this op
would be pruned, and the calculation result would be wrong.
Maybe we should re-design this op later...
'''
if op_desc.type() not in ['rnn_memory_helper_grad']:
ops_to_remove.append(op_idx)
continue
new_vars = set() new_vars = set()
# create new gradient variables # create new gradient variables
for grad_var_name in op_desc.output_arg_names(): for grad_var_name in op_desc.output_arg_names():
...@@ -972,6 +1059,9 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): ...@@ -972,6 +1059,9 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
if arg in new_vars: if arg in new_vars:
_infer_var_data_type_shape_(arg, block) _infer_var_data_type_shape_(arg, block)
for op_idx in reversed(ops_to_remove):
block.desc._remove_op(op_idx, op_idx + 1)
def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
var_map = copy.copy(target_grad_map) var_map = copy.copy(target_grad_map)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import unittest
def build_and_run_program(place, batch_size, beam_size, stop_gradient=False):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
np.random.seed(2)
x = layers.assign(
np.random.rand(batch_size, beam_size, 32).astype("float32"))
indices = fluid.data(shape=[None, beam_size], dtype="int64", name="indices")
step_idx = layers.fill_constant(
shape=[1], dtype="int64", value=0, force_cpu=True)
max_len = layers.fill_constant(
shape=[1], dtype="int64", value=10, force_cpu=True)
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
scores = layers.array_write(x, step_idx)
with while_op.block():
bs = layers.cast(layers.shape(x)[0], "int64")
for _ in range(20):
bs = layers.cast(bs, 'int64')
bs.stop_gradient = stop_gradient
batch_pos = layers.expand(
layers.unsqueeze(
layers.range(
0, bs, 1, dtype=bs.dtype), [1]), [1, beam_size])
topk_coordinates = layers.stack([batch_pos, indices], axis=2)
topk_coordinates.stop_gradient = stop_gradient
score = layers.gather_nd(x, topk_coordinates)
layers.increment(x=step_idx, value=1.0, in_place=True)
layers.array_write(score, i=step_idx, array=scores)
length_cond = layers.less_than(x=step_idx, y=max_len)
layers.assign(length_cond, cond)
out = layers.tensor_array_to_tensor(scores, axis=0, use_stack=True)[0]
loss = layers.reduce_mean(out)
opt = fluid.optimizer.Adam(0.01)
opt.minimize(loss)
exe = fluid.Executor(place)
data = np.random.random_integers(
low=0, high=beam_size - 1, size=(batch_size, beam_size)).astype("int64")
loss_val, = exe.run(feed={"indices": data}, fetch_list=[loss])
return loss_val
class TestDynRNNStopGradient(unittest.TestCase):
def setUp(self):
self.batch_size = 20
self.beam_size = 64
def run_main(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
value1 = build_and_run_program(place, self.batch_size,
self.beam_size, False)
value2 = build_and_run_program(place, self.batch_size,
self.beam_size, True)
self.assertTrue(np.array_equal(value1, value2))
def test_check_main(self):
places = [fluid.CPUPlace()]
if fluid.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.run_main(p)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import six
import unittest
class TestEmbeddingIdStopGradientBase(unittest.TestCase):
def setUp(self):
self.reshape_times = 1
self.iteration = 10
def get_places(self):
places = [fluid.CPUPlace()]
if fluid.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
return places
def test_check_grad(self):
for p in self.get_places():
grad_value1 = self.run_program(p, stop_gradient=False)
grad_value2 = self.run_program(p, stop_gradient=True)
self.assertTrue(np.array_equal(grad_value1, grad_value2))
def run_program(self, place, stop_gradient=False):
startup_program = fluid.Program()
main_program = fluid.Program()
np.random.seed(1)
startup_program.random_seed = 1
main_program.random_seed = 1
scope = fluid.Scope()
with fluid.program_guard(main_program, startup_program):
with fluid.scope_guard(scope):
x_1 = fluid.data(name='x1', shape=[4, 1], dtype='int64')
x_2 = fluid.data(name='x2', shape=[4, 1], dtype='int64')
x = fluid.layers.concat([x_1, x_2], axis=-1)
for _ in six.moves.range(self.reshape_times):
x = fluid.layers.reshape(x, [-1, 1])
x.stop_gradient = stop_gradient
emb = fluid.embedding(x, size=[10, 32], dtype='float32')
avg_cost = fluid.layers.mean(emb, name='mean_loss')
optim = fluid.optimizer.SGD(learning_rate=0.001)
optim.minimize(avg_cost)
exe = fluid.Executor(place)
exe.run(startup_program)
x1_data = np.random.randint(0, 9, x_1.shape).astype('int64')
x2_data = np.random.randint(0, 9, x_2.shape).astype('int64')
fetch_val = None
for _ in six.moves.range(self.iteration):
fetch_val = exe.run(
feed={x_1.name: x1_data,
x_2.name: x2_data},
fetch_list=[emb])[0]
return fetch_val
class TestEmbeddingIdStopGradient2(TestEmbeddingIdStopGradientBase):
def setUp(self):
self.reshape_times = 100
self.iteration = 10
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册