Unverified commit 8259f141, authored by chengduo and committed by GitHub

Enhance backward process (#18700)

* prune backward ops
test=develop
Parent 25c9b57b
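In short, the new pruning pass treats the candidate backward OpDescs as a computational graph: a grad op whose inputs are produced neither by the forward program nor by earlier backward ops heads a candidate sub-graph, and a breadth-first walk over that sub-graph decides whether it ever connects back to the main backward graph; fully disconnected sub-graphs are dropped. Below is a minimal standalone sketch of that idea, with plain dicts standing in for core.OpDesc; the names (find_disconnected_grad_ops, toy_grad_ops, known_vars) are hypothetical, and the sketch ignores the variable versioning the real pass performs.

from collections import deque

def find_disconnected_grad_ops(grad_ops, known_vars):
    """Return the grad ops whose sub-graph never connects to known_vars.

    grad_ops: list of {"name": str, "inputs": set, "outputs": set}
    known_vars: variable names produced by the forward pass (or fed in).
    """
    produced = set(known_vars)
    heads = []
    for op in grad_ops:
        # An op fed by no already-known variable is a candidate head.
        if op["inputs"] and not (op["inputs"] & produced):
            heads.append(op)
        produced |= op["outputs"]

    not_needed = []
    for head in heads:
        ready = set(head["inputs"])
        seen = [head]
        queue = deque([head])
        disconnected = True
        while queue:
            op = queue.popleft()
            if not op["inputs"] <= ready:
                disconnected = False  # an input comes from the main graph
                break
            ready |= op["outputs"]
            for other in grad_ops:
                if other is not op and (other["inputs"] & op["outputs"]):
                    seen.append(other)
                    queue.append(other)
        if disconnected:
            not_needed.extend(seen)
    return not_needed

# cast_grad consumes label@GRAD, which nothing produces, so its sub-graph
# is pruned; mean_grad stays because loss@GRAD is known.
toy_grad_ops = [
    {"name": "mean_grad", "inputs": {"loss@GRAD"}, "outputs": {"ce@GRAD"}},
    {"name": "cast_grad", "inputs": {"label@GRAD"}, "outputs": {"x@GRAD"}},
]
known_vars = {"loss@GRAD"}
print([op["name"] for op in find_disconnected_grad_ops(toy_grad_ops, known_vars)])
# prints: ['cast_grad']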
......@@ -480,7 +480,6 @@ function assert_api_spec_approvals() {
  API_FILES=("CMakeLists.txt"
             "paddle/fluid/API.spec"
             "paddle/fluid/op_use_default_grad_op_maker.spec"
             "python/paddle/fluid/parallel_executor.py"
             "paddle/fluid/framework/operator.h"
             "paddle/fluid/framework/tensor.h"
             "paddle/fluid/framework/details/op_registry.h"
......@@ -495,8 +494,11 @@ function assert_api_spec_approvals() {
"paddle/fluid/framework/ir/graph.h"
"paddle/fluid/framework/framework.proto"
"python/requirements.txt"
"python/paddle/fluid/compiler.py"
"python/paddle/fluid/__init__.py"
"python/paddle/fluid/compiler.py"
"python/paddle/fluid/parallel_executor.py"
"python/paddle/fluid/framework.py"
"python/paddle/fluid/backward.py"
"paddle/fluid/operators/distributed/send_recv.proto.in")
for API_FILE in ${API_FILES[*]}; do
API_CHANGE=`git diff --name-only upstream/$BRANCH | grep "${API_FILE}" | grep -v "/CMakeLists.txt" || true`
......
......@@ -247,6 +247,125 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
    return op_descs


def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
    """
    Prune the program with a structural analysis of its computational graph.
    The nodes of the computational graph composed of backward ops should all
    be interconnected; if the graph contains unconnected sub-graphs, those
    sub-graphs should be cut off.

    Args:
        grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs.
        forward_ops(list[Operator]): The forward ops.
        input_grad_names_set(set): The names of the gradients generated by
            backward ops; it helps to prune the unnecessary backward ops.

    Return:
        (set[core.OpDesc]): A set of OpDescs which should be pruned.
    """
    class Var(object):
        def __init__(self, var_name):
            self.var_name = var_name
            self.gen_op = None
            self.pending_ops = []

        def set_gen_op(self, gen_op):
            assert isinstance(gen_op, Op)
            assert self.gen_op is None
            self.gen_op = gen_op

        def add_pending_op(self, op):
            assert isinstance(op, Op)
            self.pending_ops.append(op)
    class Op(object):
        def __init__(self, op_desc):
            self.op_desc = op_desc
            self.inputs = []
            self.outputs = []

        def insert_input(self, var):
            assert isinstance(var, Var)
            self.inputs.append(var)

        def insert_output(self, var):
            assert isinstance(var, Var)
            self.outputs.append(var)
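    # var_versions maps a variable name to the list of Var nodes created for
    # it, one per writing op (an SSA-like versioning), so every reader below
    # attaches to the version of a variable that was live when it ran.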
    var_versions = dict()

    def _create_node(name):
        if name not in var_versions:
            var_versions[name] = [Var(name)]
        else:
            var_versions[name].append(Var(name))
        return var_versions[name][-1]

    def _create_or_get_last_version_node(name):
        if name not in var_versions:
            var_versions[name] = [Var(name)]
        return var_versions[name][-1]
    def _create_op_node(op_desc):
        op_node = Op(op_desc)
        for input_name in op_desc.input_arg_names():
            var = _create_or_get_last_version_node(name=input_name)
            var.add_pending_op(op_node)
            op_node.insert_input(var)
        for output_name in op_desc.output_arg_names():
            var = _create_node(name=output_name)
            var.set_gen_op(op_node)
            op_node.insert_output(var)
        return op_node
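    # _create_op_node wires each grad op to the latest version of every input
    # variable and to a fresh version of every output variable, so the
    # def-use chains among the backward ops are reconstructed write by write.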
    # Record the forward vars.
    forward_vars_set = set() if input_grad_names_set is None else set(
        input_grad_names_set)
    for op in forward_ops:
        forward_vars_set.update(op.desc.input_arg_names())
        forward_vars_set.update(op.desc.output_arg_names())

    # Record the vars which are created during the backward pass but are not
    # generated by any op.
    backward_vars_set = set()
    # special_op_nodes holds the candidate sub-graph head nodes.
    special_op_nodes = set()
    for op_desc in grad_op_descs:
        input_set = set(op_desc.input_arg_names())
        # new_vars are created during the backward pass and are not generated
        # by any op.
        new_vars = input_set - forward_vars_set - backward_vars_set
        backward_vars_set.update(op_desc.output_arg_names())

        op_node = _create_op_node(op_desc)
        if len(new_vars) == len(input_set):
            special_op_nodes.add(op_node)
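    # At this point special_op_nodes holds every grad op fed neither by the
    # forward graph nor by earlier backward outputs; each may head a
    # disconnected sub-graph, which the traversal below verifies.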
    not_need_op_descs = []
    # Traverse every candidate sub-graph head to check whether it is connected
    # to the main backward computational graph; if it is not, collect the
    # whole sub-graph into not_need_op_descs.
    for special_op_node in special_op_nodes:
        op_list = [special_op_node]
        ready_vars = set(special_op_node.inputs)
        remove_ops = True
        candidate_ops = [special_op_node]
        while len(candidate_ops) > 0:
            op_node = candidate_ops.pop(0)
            if _all_in_set_(op_node.inputs, ready_vars):
                for out_var in op_node.outputs:
                    candidate_ops.extend(out_var.pending_ops)
                    op_list.extend(out_var.pending_ops)
                ready_vars.update(op_node.outputs)
            else:
                remove_ops = False
                break
        if remove_ops:
            not_need_op_descs.extend([node.op_desc for node in op_list])

    return set(not_need_op_descs)
from .proto import framework_pb2
......@@ -276,7 +395,10 @@ def _append_backward_ops_(block,
        grad_to_var(dict)(output argument):
            key(str): grad variable name
            val(str): corresponding forward variable name
        callback(callable object): a callable object used to decorate new generated grad ops
        callbacks(callable object): a callable object used to decorate newly generated grad ops
        input_grad_names_set(set): the names of the gradients generated by backward
            ops; it helps to prune the unnecessary backward ops.
"""
    if callbacks is not None:
        assert (isinstance(callbacks, list))
......@@ -342,6 +464,10 @@ def _append_backward_ops_(block,
    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                            no_grad_dict[block.idx])

    not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
    grad_op_descs = [
        op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
    ]
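    # grad_op_descs now excludes every op from sub-graphs that are
    # disconnected from the main backward graph.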
    # append op_desc in grad_op_descs to target_block
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
    backward = core.op_proto_and_checker_maker.OpRole.Backward
......
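For context, the pruning runs inside the existing public entry point, fluid.backward.append_backward, so user code does not change. A minimal sketch of exercising it at graph-construction time (the layer sizes and variable names here are illustrative, not from the commit):

import paddle.fluid as fluid

# Build a tiny forward program; append_backward then generates the grad ops,
# now pruning any disconnected backward sub-graphs before appending them.
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.fc(input=x, size=2)
loss = fluid.layers.mean(y)
param_grads = fluid.backward.append_backward(loss)
print([grad.name for _, grad in param_grads])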
......@@ -30,6 +30,18 @@ def simple_net1():
    return loss


def simple_net2():
    x = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    feature = fluid.layers.fc(input=x, size=10, act=None)
    label = fluid.layers.cast(label, dtype="float32")
    label = fluid.layers.cast(label, dtype='int64')
    # Note that the label is not persistable in fluid.layers.cross_entropy, so
    # the two cast ops above should yield backward ops that are disconnected
    # from the loss gradient and get pruned.
    loss = fluid.layers.cross_entropy(input=feature, label=label)
    loss = fluid.layers.mean(loss)
    return loss


class TestBackward(unittest.TestCase):
    def check_backward(self, model):
        place = fluid.CPUPlace()
......@@ -51,6 +63,7 @@ class TestBackward(unittest.TestCase):
    def test_backward(self):
        self.check_backward(simple_net1)
        self.check_backward(simple_net2)

if __name__ == '__main__':
......