diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 6fe04cef2daf67f61e4af09eecd2ce125ea4a693..9ba0985dcb00d58ca57ec191b8a125c82cdb0f83 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -650,7 +650,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( "The number(%d) of samples of " "current batch is less than the count(%d) of " "devices(%s), currently, it is not allowed. ", - lod_tensors.size(), lod_tensors.size(), + lod_tensors.size(), member_->places_.size(), (is_cpu_place ? "CPU" : "GPU")); if (is_cpu_place) { error_info += diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec index a2355d2deee5784f85a65ba32bf1440a55fb6bed..4ec0a35b2900a17f55428bb0e2cea3c9aa69c620 100644 --- a/paddle/fluid/op_use_default_grad_op_maker.spec +++ b/paddle/fluid/op_use_default_grad_op_maker.spec @@ -15,7 +15,6 @@ fusion_seqexpand_concat_fc fusion_seqpool_concat fusion_squared_mat_sub gru -hierarchical_sigmoid lrn lstm_unit max_pool2d_with_index diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 479b839e473591ba57945b496b83b0e76f620534..2b3e2e5c484a1f04c03f0c2482072f0452382aa1 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -86,6 +86,10 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { } }; +/* + * Inputs: X, W, Label, PathTable, PathCode, Bias + * Outputs: Out, PreOut, W_out + */ template class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -162,6 +166,37 @@ Hierarchical Probabilistic Neural Network Language Model." } }; +/* + * Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD + * Outputs: X@GRAD, W@GRAD, Bias@GRAD + */ +class HierarchicalSigmoidGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType(this->ForwardOpType() + "_grad"); + // Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD + op->SetInput("X", Input("X")); + op->SetInput("W", Input("W")); + op->SetInput("Bias", Input("Bias")); + op->SetInput("Label", Input("Label")); + op->SetInput("PathTable", Input("PathTable")); + op->SetInput("PathCode", Input("PathCode")); + op->SetInput("PreOut", Output("PreOut")); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + + // Outputs: X@GRAD, W@GRAD, Bias@GRAD + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("W"), InputGrad("W")); + op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias")); + op->SetAttrMap(Attrs()); + + return std::unique_ptr(op); + } +}; + class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -209,17 +244,17 @@ class HierarchicalSigmoidGradOpGradVarTypeInference auto attr = ctx->GetAttr("is_sparse"); bool is_sparse = boost::get(attr); if (is_sparse) { - VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") - << " is set to SelectedRows"; + VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to SelectedRows"; ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS); } else { - VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") - << " is set to LoDTensor"; + VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") + << " is set to LoDTensor"; ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR); } if (hasBias) { - VLOG(30) << "hierarchical_sigmoid_grad op " - << framework::GradVarName("Bias") << " is set to LoDTensor"; + VLOG(3) << "hierarchical_sigmoid_grad op " + << framework::GradVarName("Bias") << " is set to LoDTensor"; ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR); } ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0])); @@ -232,7 +267,7 @@ class HierarchicalSigmoidGradOpGradVarTypeInference namespace ops = paddle::operators; REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, ops::HierarchicalSigmoidOpMaker, - paddle::framework::DefaultGradOpDescMaker); + ops::HierarchicalSigmoidGradMaker); REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp, ops::HierarchicalSigmoidGradOpGradVarTypeInference); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index f5a1b32e5c240933d79a524937b5a8222118fdd9..4eb5b7ad9d1fe128ade904cf61e0178d59b374b8 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -58,10 +58,14 @@ class ScatterGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("Updates"), - ctx->GetInputDim("Updates")); - ctx->SetOutputDim(framework::GradVarName("X"), - ctx->GetInputDim(framework::GradVarName("Out"))); + if (ctx->HasOutput(framework::GradVarName("Updates"))) { + ctx->SetOutputDim(framework::GradVarName("Updates"), + ctx->GetInputDim("Updates")); + } + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), + ctx->GetInputDim(framework::GradVarName("Out"))); + } } protected: diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu index e9ad347538157342adb24813546e927040b4f9d2..e17617b40da356d74bdffcf53a6c9189d13c64f1 100644 --- a/paddle/fluid/operators/scatter_op.cu +++ b/paddle/fluid/operators/scatter_op.cu @@ -47,12 +47,15 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel { auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); auto *Ids = ctx.Input("Ids"); auto *dOut = ctx.Input(framework::GradVarName("Out")); - - // In place gradient: dX = dO - dX->ShareDataWith(*dOut); - dUpdates->mutable_data(ctx.GetPlace()); - // Gradient by Gather: dUpdates = dO[Ids] - GPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + if (dX) { + // In place gradient: dX = dO + framework::TensorCopy(*dOut, ctx.GetPlace(), dX); + } + if (dUpdates) { + dUpdates->mutable_data(ctx.GetPlace()); + // Gradient by Gather: dUpdates = dO[Ids] + GPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + } } }; diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h index 9c237dc0f1f115ce76a3b982a8c6ca1dfccb0b87..3b6184de77f4fc05aa2f2900ebc656ed06a8edfc 100644 --- a/paddle/fluid/operators/scatter_op.h +++ b/paddle/fluid/operators/scatter_op.h @@ -74,11 +74,15 @@ class ScatterGradientOpKernel : public framework::OpKernel { auto *Ids = ctx.Input("Ids"); auto *dOut = ctx.Input(framework::GradVarName("Out")); - // In place gradient: dX = dO - framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); - dUpdates->mutable_data(ctx.GetPlace()); - // Gradient by Gather: dUpdates = dO[Ids] - CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + if (dX) { + // In place gradient: dX = dO + framework::TensorCopySync(*dOut, ctx.GetPlace(), dX); + } + if (dUpdates) { + dUpdates->mutable_data(ctx.GetPlace()); + // Gradient by Gather: dUpdates = dO[Ids] + CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); + } } }; diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 3771361cd2d87489186c4242b84f5e27d0106852..3e8669f0356a22d24ce8f15f630f449706f0abb3 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -247,6 +247,125 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): return op_descs +def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set): + """ + Pruning Program with Structural Analysis Method of Computational Graph. + The nodes of the computational graph composed of backward OPS should be + interconnected. If there are unconnected sub-graphs in the computational graph, + these sub-graphs should be cut off. + + Args: + grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs. + forward_ops(list[Operator]): The forward ops. + input_grad_names_set(set): this set is used to store the gradients' name + which is generated by backward ops, and input_grad_names_set can help + to prune the unnecessary backward ops. + + Return: + (list[core.OpDesc]): A list of OpDescs which should be pruned. + """ + + class Var(object): + def __init__(self, var_name): + self.var_name = var_name + self.gen_op = None + self.pendding_ops = [] + + def set_gen_op(self, gen_op): + assert isinstance(gen_op, Op) + assert self.gen_op is None + self.gen_op = gen_op + + def add_pending_op(self, op): + assert isinstance(op, Op) + self.pendding_ops.append(op) + + class Op(object): + def __init__(self, op_desc): + self.op_desc = op_desc + self.inputs = [] + self.outputs = [] + + def insert_input(self, var): + assert isinstance(var, Var) + self.inputs.append(var) + + def insert_output(self, var): + assert isinstance(var, Var) + self.outputs.append(var) + + var_versions = dict() + + def _create_node(name): + if name not in var_versions.keys(): + var_versions[name] = [Var(name)] + else: + var_versions[name].append(Var(name)) + return var_versions[name][-1] + + def _create_or_get_last_version_node(name): + if name not in var_versions.keys(): + var_versions[name] = [Var(name)] + return var_versions[name][-1] + + def _create_op_node(op_desc): + op_node = Op(op_desc) + for input in op_desc.input_arg_names(): + var = _create_or_get_last_version_node(name=input) + var.add_pending_op(op_node) + op_node.insert_input(var) + for output in op_desc.output_arg_names(): + var = _create_node(name=output) + var.set_gen_op(op_node) + op_node.insert_output(var) + return op_node + + # Record the forward vars + forward_vars_set = set() if input_grad_names_set is None else set( + input_grad_names_set) + for op in forward_ops: + forward_vars_set.update(op.desc.input_arg_names()) + forward_vars_set.update(op.desc.output_arg_names()) + + # Record the vars which are created during backward and is not generated by op. + backward_vars_set = set() + # special_op_nodes is the candidate sub-graph head node. + special_op_nodes = set() + for op_desc in grad_op_descs: + input_set = set(op_desc.input_arg_names()) + # The new_vars are created during backward and is not generated by op. + new_vars = input_set - forward_vars_set - backward_vars_set + backward_vars_set.update(op_desc.output_arg_names()) + + op_node = _create_op_node(op_desc) + if len(new_vars) == len(input_set): + special_op_nodes.add(op_node) + + not_need_op_descs = [] + # Start traversing all candidate sub-graph headers to check whether + # they are connected to backward computational graphs, and if they are + # not, list them in not_need_op_descs + for special_op_node in special_op_nodes: + op_list = [special_op_node] + ready_vars = set(special_op_node.inputs) + remove_ops = True + candidate_ops = [special_op_node] + while len(candidate_ops) > 0: + op_node = candidate_ops.pop(0) + if _all_in_set_(op_node.inputs, ready_vars): + for out_var in op_node.outputs: + candidate_ops.extend(out_var.pendding_ops) + op_list.extend(out_var.pendding_ops) + ready_vars.update(op_node.outputs) + else: + remove_ops = False + break + if remove_ops: + not_need_op_descs.extend([node.op_desc for node in op_list]) + + return set(not_need_op_descs) + + from .proto import framework_pb2 @@ -276,7 +395,10 @@ def _append_backward_ops_(block, grad_to_var(dict)(output argument): key(str): grad variable name val(str): corresponding forward variable name - callback(callable object): a callable object used to decorate new generated grad ops + callbacks(callable object): a callable object used to decorate new generated grad ops + input_grad_names_set(set): this set is used to store the gradients' name which is + generated by backward ops, and input_grad_names_set can help to prune the unnecessary + backward ops. """ if callbacks is not None: assert (isinstance(callbacks, list)) @@ -342,6 +464,10 @@ def _append_backward_ops_(block, grad_op_descs = _remove_no_grad_branch_(grad_op_descs, no_grad_dict[block.idx]) + not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set) + grad_op_descs = [ + op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops + ] # append op_desc in grad_op_descs to target_block op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() backward = core.op_proto_and_checker_maker.OpRole.Backward @@ -552,7 +678,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0])) op_path = _find_op_path_(root_block, [loss], [], block_no_grad_set) - + no_grad_vars = _find_no_grad_vars(root_block, op_path, [loss], + block_no_grad_set) + block_no_grad_set.update(no_grad_vars) no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set))) input_grad_names_set = None @@ -630,6 +758,26 @@ def _as_list(x): return list(x) if isinstance(x, collections.Sequence) else [x] +def _find_no_grad_vars(block, op_path, targets, no_grad_set): + """ + Find the vars which is not used in the program, and + those var belong to no_grad_var. + """ + output_names = set([out.name for out in targets]) + no_grad_var = [] + for i, op in reversed(list(enumerate(op_path))): + # If the op has sub_block, it is too complicated to find the correct no_grad_var. + if not op.has_attr("sub_block"): + for out_var in op.desc.output_arg_names(): + if out_var not in output_names and out_var not in op.desc.input_arg_names( + ) and not block.vars[out_var].stop_gradient: + no_grad_var.append(out_var) + for name in op.desc.input_arg_names(): + if name not in no_grad_set: + output_names.add(name) + return set(no_grad_var) + + def _find_op_path_(block, outputs, inputs, no_grad_set): """ no_grad_set will also be changed diff --git a/python/paddle/fluid/tests/unittests/test_backward.py b/python/paddle/fluid/tests/unittests/test_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f4b47f7d4dca4f079917d505a0ce249f3241e7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_backward.py @@ -0,0 +1,70 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid +from simple_nets import init_data + + +def simple_net1(): + x = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + feature = fluid.layers.fc(input=x, size=20, act=None) + part1, part2 = fluid.layers.split(feature, num_or_sections=[10, 10], dim=1) + # Note that: part2 is not used. + loss = fluid.layers.cross_entropy(input=part1, label=label) + loss = fluid.layers.mean(loss) + return loss + + +def simple_net2(): + x = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + feature = fluid.layers.fc(input=x, size=10, act=None) + label = fluid.layers.cast(label, dtype="float32") + label = fluid.layers.cast(label, dtype='int64') + # Note that the label is not persistable in fluid.layers.cross_entropy. + loss = fluid.layers.cross_entropy(input=feature, label=label) + loss = fluid.layers.mean(loss) + return loss + + +class TestBackward(unittest.TestCase): + def check_backward(self, model): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + main = fluid.Program() + startup = fluid.Program() + batch_size = 2 + + with fluid.program_guard(main, startup): + loss = model() + + optimizer = fluid.optimizer.SGD(learning_rate=0.1) + optimizer.minimize(loss) + + exe.run(fluid.default_startup_program()) + img, label = init_data(batch_size, img_shape=[784], label_range=9) + exe.run(feed={'image': img, 'label': label}) + + def test_backward(self): + self.check_backward(simple_net1) + self.check_backward(simple_net2) + + +if __name__ == '__main__': + unittest.main()