From b2c1be851a3b9e6d01dab1d741fcb05e5fc4b016 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 26 Feb 2020 14:12:09 +0800 Subject: [PATCH] support cond in clone, test=develop (#22657) * support cond in clone, test=develop * refine code, test=develop * refine code, test=develop * follow comments, test=develop * refine code, test=develop --- paddle/fluid/framework/prune.cc | 240 +++++++++++------- paddle/fluid/framework/prune.h | 4 +- paddle/fluid/pybind/pybind.cc | 20 +- python/paddle/fluid/framework.py | 58 +++-- .../tests/unittests/mkldnn/mkldnn_op_test.py | 2 +- .../fluid/tests/unittests/seresnext_net.py | 16 +- .../unittests/test_program_prune_backward.py | 122 ++++++++- 7 files changed, 312 insertions(+), 150 deletions(-) diff --git a/paddle/fluid/framework/prune.cc b/paddle/fluid/framework/prune.cc index c58cb8ad2ac..db9d794b3c6 100644 --- a/paddle/fluid/framework/prune.cc +++ b/paddle/fluid/framework/prune.cc @@ -18,8 +18,10 @@ limitations under the License. */ #include #include +#include #include #include +#include #include #include #include @@ -81,19 +83,50 @@ bool HasFalseTarget(const proto::OpDesc& op_desc) { } int GetSubBlockIndex(const proto::OpDesc& op_desc) { + // The block index >= 0, so -1 is used to indicate "NotFound". for (auto& attr : op_desc.attrs()) { if (attr.type() == proto::AttrType::BLOCK) { - PADDLE_ENFORCE(attr.has_block_idx()); + PADDLE_ENFORCE_EQ(attr.has_block_idx(), true, + platform::errors::NotFound( + "Attribute sub_block is not found in operator %s", + op_desc.type())); return attr.block_idx(); } } return -1; } +void SetSubBlockIndex(proto::OpDesc* op_desc, int sub_idx) { + for (auto& attr : *op_desc->mutable_attrs()) { + if (attr.type() == proto::AttrType::BLOCK) { + PADDLE_ENFORCE_EQ(attr.has_block_idx(), true, + platform::errors::NotFound( + "Attribute sub_block is not found in operator %s", + op_desc->type())); + attr.set_block_idx(sub_idx); + } + } +} + bool HasSubBlock(const proto::OpDesc& op_desc) { return GetSubBlockIndex(op_desc) > 0; } +int GetOpRole(const proto::OpDesc& op_desc) { + // The op role >= 0, so -1 is used to indicate "NotFound". + for (auto& attr : op_desc.attrs()) { + if (attr.name() == OpProtoAndCheckerMaker::OpRoleAttrName()) { + PADDLE_ENFORCE_EQ( + attr.has_i(), true, + platform::errors::NotFound("Attribute %s is empty in operator %s", + OpProtoAndCheckerMaker::OpRoleAttrName(), + op_desc.type())); + return attr.i(); + } + } + return -1; +} + void AppendOpInputVarNames(const proto::OpDesc& op_desc, std::unordered_set* vars_set) { for (auto& var : op_desc.inputs()) { @@ -259,134 +292,159 @@ void Prune(const proto::ProgramDesc& input, prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names); } -void CloneWholeBlock(proto::ProgramDesc* input, proto::ProgramDesc* output, - int block_id, int parent_block_id) { - auto* block_field = output->mutable_blocks(); - *block_field->Add() = input->blocks(block_id); - int output_block_id = output->blocks_size() - 1; - auto* output_block = output->mutable_blocks(output_block_id); - output_block->set_idx(output_block_id); - output_block->set_parent_idx(parent_block_id); +int FindMapByValue(const std::map& m, int val) { + // The content in map should be >= 0, so -1 is used to indicate "NotFound". + for (auto& pair : m) { + if (pair.second == val) { + return pair.first; + } + } + return -1; } -void PruneBackwardImpl(proto::ProgramDesc* input, proto::ProgramDesc* output, - int block_id, int parent_block_id) { - // Step 1. Copy the current input block to output - CloneWholeBlock(input, output, block_id, parent_block_id); - int output_block_id = output->blocks_size() - 1; - auto* output_block = output->mutable_blocks(output_block_id); - - // Step 2. Mark forward ops on main branch - auto* ops = input->mutable_blocks(block_id)->mutable_ops(); +void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) { std::unordered_set op_input_vars; std::unordered_set op_output_vars; - for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) { - auto& op_desc = *op_iter; - if (HasTrueTarget(op_desc) || - HasDependentOutputVar(op_desc, op_input_vars)) { - op_desc.set_is_target(true); - AppendOpInputVarNames(op_desc, &op_input_vars); - AppendOpOutputVarNames(op_desc, &op_output_vars); - } - } - // Step 3. Mark backward & optimize ops on main branch - std::unordered_set gradop_input_vars; - std::unordered_set gradop_output_vars; + // Step 1. Mark backward, optimize and lrsched ops in the block + auto* ops = origin->mutable_ops(); for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) { auto& op_desc = *op_iter; - if (HasFalseTarget(op_desc) || - HasDependentInputVar(op_desc, gradop_output_vars)) { + auto op_role = GetOpRole(op_desc); + if (op_role & static_cast(OpRole::kOptimize) || + op_role & static_cast(OpRole::kBackward) || + op_role & static_cast(OpRole::kLRSched)) { op_desc.set_is_target(false); - AppendOpInputVarNames(op_desc, &gradop_input_vars); - AppendOpOutputVarNames(op_desc, &gradop_output_vars); - } - } - - // Step 4. Mark ops need to be reserved on sub-branch - for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) { - auto& op_desc = *op_iter; - if (!op_desc.has_is_target()) { - if (HasDependentOutputVar(op_desc, gradop_input_vars)) { - op_desc.set_is_target(false); - AppendOpInputVarNames(op_desc, &gradop_input_vars); - } else { - op_desc.set_is_target(true); - AppendOpInputVarNames(op_desc, &op_input_vars); - AppendOpOutputVarNames(op_desc, &op_output_vars); - } } } - // Step 5. Copy the forward ops to new ProgramDesc - // Note: The proto::ProgramDesc doesn't have interface - // to remove op and var - auto* op_field = output_block->mutable_ops(); + // Step 2. Copy the forward ops which have not been set false target to new + // ProgramDesc + // Note: The proto::ProgramDesc doesn't have interface + // to remove op and var + auto* op_field = pruned->mutable_ops(); op_field->Clear(); for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) { - if (IsTarget(*op_iter)) { + if (!HasFalseTarget(*op_iter)) { auto* op = op_field->Add(); + AppendOpInputVarNames(*op_iter, &op_input_vars); + AppendOpOutputVarNames(*op_iter, &op_output_vars); *op = *op_iter; - if (HasSubBlock(*op)) { - CloneWholeBlock(input, output, GetSubBlockIndex(*op), output_block_id); - } } } - // Step 6. Copy the forward vars to new ProgramDesc + // Step 3. Copy the forward vars to new ProgramDesc, // construct all var's map before clear - auto* var_field = output_block->mutable_vars(); + auto* origin_vars = origin->mutable_vars(); + auto* pruned_vars = pruned->mutable_vars(); std::unordered_map var_map; - for (const auto& var : *var_field) { + for (const auto& var : *origin_vars) { var_map[var.name()] = var; } + pruned_vars->Clear(); + std::unordered_set var_names; var_names.insert(op_input_vars.begin(), op_input_vars.end()); var_names.insert(op_output_vars.begin(), op_output_vars.end()); - var_field->Clear(); for (const auto& name : var_names) { - *var_field->Add() = var_map[name]; + if (var_map.count(name)) { + // NOTE(zhiqiu): For operator in a conditional block, the related vars may + // not exist in current block, but in its futher block. + *pruned_vars->Add() = var_map[name]; + } } -} +} // namespace framework -std::unique_ptr PruneBackward( +std::tuple> PruneBackward( const framework::ProgramDesc& origin) { // Copy original ProgramDesc, origin can't be change framework::ProgramDesc origin_clone(origin); - // Step 1. Update loss op's role & set loss op to be target - // The loss op's op_role is (kForward | kLoss) - // The input ProgramDesc should have loss operator. - auto ops = origin_clone.Block(0).AllOps(); - bool has_loss_op = false; - for (auto op : ops) { - int op_role = - boost::get(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())); - if (op_role == (static_cast(OpRole::kForward) | - static_cast(OpRole::kLoss))) { - op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), - static_cast(OpRole::kForward)); - op->SetIsTarget(true); - has_loss_op = true; - } else if (op_role == (static_cast(OpRole::kBackward) | - static_cast(OpRole::kLoss))) { - op->SetIsTarget(false); - break; + // Step 1. check if the program contains grad loss operator. + // If not, the program need no pruning. + bool has_loss_grad_op = false; + std::queue block_contains_loss; + std::queue block_contains_loss_grad; + for (size_t i = 0; i < origin_clone.Size(); i++) { + auto block_ops = origin_clone.Block(i).AllOps(); + for (auto op : block_ops) { + int op_role = boost::get( + op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())); + if (op_role == (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss))) { + op->SetIsTarget(false); + has_loss_grad_op = true; + break; + } } } - PADDLE_ENFORCE_EQ(has_loss_op, true, - "The Program need to be pruned its backward part" - "should have loss operator."); - // Step 2. Prune backward + std::map pruned_progin_block_id_map; + if (!has_loss_grad_op) { + // No pruning, fast return a copy of the origin ProgramDesc with an empty + // map, means default mapped, i.e.{0:0, 1:1, ..., n:n}. + return std::make_tuple(framework::ProgramDesc(origin_clone), + pruned_progin_block_id_map); + } + proto::ProgramDesc pruned_desc; pruned_desc.clear_blocks(); - PruneBackwardImpl(origin_clone.Proto(), &pruned_desc, 0, -1); + // Step 2. Prune backward for each block. + for (size_t i = 0; i < origin_clone.Size(); i++) { + auto pruned = proto::BlockDesc(); + auto origin = origin_clone.Proto()->mutable_blocks(i); + + PruneBackwardImpl(origin, &pruned); + // If pruned block contains no operator, it means the block is a + // backward block and should be pruned. + // Else, add the block to pruned_desc and update its id & parent_id. + if (pruned.ops_size() > 0) { + auto* block_field = pruned_desc.mutable_blocks(); + *block_field->Add() = pruned; + + auto pruned_block_id = pruned_desc.blocks_size() - 1; + pruned_progin_block_id_map[pruned_block_id] = origin->idx(); + auto* pruned_block = pruned_desc.mutable_blocks(pruned_block_id); + pruned_block->set_idx(pruned_block_id); + + if (origin->parent_idx() == -1) { + pruned_block->set_parent_idx(-1); + } else { + auto parent_idx = + FindMapByValue(pruned_progin_block_id_map, origin->parent_idx()); + PADDLE_ENFORCE_NE(parent_idx, -1, + platform::errors::NotFound( + "The origin parent block id is not found in " + "pruned_progin_block_id_map")); + pruned_block->set_parent_idx(parent_idx); + } + } + } - // Step 3. Contruct new framework::ProgramDesc - return std::unique_ptr( - new framework::ProgramDesc(pruned_desc)); -} + // Step 3. Update subblock attribute for conditional operator. + // This should be performed after all blocks pruned. + for (int i = 0; i < pruned_desc.blocks_size(); i++) { + auto* pruned = pruned_desc.mutable_blocks(i); + auto* ops = pruned->mutable_ops(); + for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) { + auto& op_desc = *op_iter; + if (HasSubBlock(op_desc)) { + int origin_sub_idx = GetSubBlockIndex(op_desc); + auto sub_idx = + FindMapByValue(pruned_progin_block_id_map, origin_sub_idx); + PADDLE_ENFORCE_NE(sub_idx, -1, + platform::errors::NotFound( + "The origin sub block id is not found in " + "pruned_progin_block_id_map")); + SetSubBlockIndex(&op_desc, sub_idx); + } + } + } + + // Step 4. Return a tuple + return std::make_tuple(framework::ProgramDesc(pruned_desc), + pruned_progin_block_id_map); +} // namespace framework } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/prune.h b/paddle/fluid/framework/prune.h index f710106a263..857006d69c1 100644 --- a/paddle/fluid/framework/prune.h +++ b/paddle/fluid/framework/prune.h @@ -14,9 +14,11 @@ limitations under the License. */ #pragma once +#include #include #include #include +#include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/enforce.h" @@ -28,7 +30,7 @@ void Prune(const proto::ProgramDesc& input, const std::set& feed_var_names, proto::ProgramDesc* output); -std::unique_ptr PruneBackward( +std::tuple> PruneBackward( const framework::ProgramDesc& origin); } // namespace framework diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8120ac6a004..03078a767b9 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1136,9 +1136,23 @@ All parameter, weight, gradient are variables in Paddle. Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc); return new ProgramDesc(pruned_desc); }); - m.def("prune_backward", [](const framework::ProgramDesc &program) { - return PruneBackward(program); - }); + m.def("prune_backward", + [](const framework::ProgramDesc &program) { + return PruneBackward(program); + }, + R"DOC( + Prune the backward part of a program, mostly called in + program.clone(for_test=True). + + Args: + program (ProgramDesc): The original program. + + Returns: + tuple(ProgramDesc, map): The first part is + the pruned program desc, and the second part is a map + which contains the id pair of pruned block and corresponding + origin block. + )DOC"); m.def("empty_var_name", []() { return std::string(framework::kEmptyVarName); }); m.def("grad_var_suffix", diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 6a8bb211049..a13f6edf8de 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3991,18 +3991,17 @@ class Program(object): The two code snippets above will generate and print same programs. """ + pruned_origin_block_id_map = None if for_test: - if self._appending_grad_times > 0: - forward_prog = Program() - forward_prog.desc = core.prune_backward(self.desc) - forward_prog.blocks = [ - Block(forward_prog, i) - for i in six.moves.range(forward_prog.desc.num_blocks()) - ] - forward_prog._sync_with_cpp() - p = forward_prog._inference_optimize(prune_read_op=False) - else: - p = self._inference_optimize(prune_read_op=False) + forward_prog = Program() + forward_prog.desc, pruned_origin_block_id_map = core.prune_backward( + self.desc) + forward_prog.blocks = [ + Block(forward_prog, i) + for i in six.moves.range(forward_prog.desc.num_blocks()) + ] + forward_prog._sync_with_cpp() + p = forward_prog._inference_optimize(prune_read_op=False) else: p = Program() p.current_block_idx = self.current_block_idx @@ -4019,7 +4018,7 @@ class Program(object): p._sync_with_cpp() p._copy_param_info_from(self) - p._copy_data_info_from(self) + p._copy_data_info_from(self, pruned_origin_block_id_map) p._copy_dist_param_info_from(self) return p @@ -4445,9 +4444,6 @@ class Program(object): raise TypeError("_copy_param_info_from should be invoked with " "Program") - if len(self.blocks) != len(other.blocks): - raise ValueError("_copy_param_info_from should be invoked with two " - "program, with represent the same topology") self.global_block()._copy_param_info_from(other.global_block()) def _copy_dist_param_info_from(self, other): @@ -4470,7 +4466,7 @@ class Program(object): self._ps_endpoint = other._ps_endpoint self._distributed_lookup_table = other._distributed_lookup_table - def _copy_data_info_from(self, other): + def _copy_data_info_from(self, other, pruned_origin_block_id_map=None): """ Copy the information of data variables from other program. @@ -4479,6 +4475,10 @@ class Program(object): Args: other(Program): Other program + pruned_origin_block_id_map(dict{int:int}): A dict which maps the block id in program + self to the block id in program other. For example, {0:0, 1:1, 2:3} means block 0 in self is + cloned from block 0 in other, etc. Default is None, which means default mapped, + {0:0, 1:1,..., n:n}. Returns: None @@ -4487,22 +4487,24 @@ class Program(object): raise TypeError("_copy_data_info_from should be invoked with " "Program") - if len(self.blocks) != len(other.blocks): - raise ValueError("_copy_data_info_from should be invoked with two " - "program, with represent the same topology") + if not pruned_origin_block_id_map: + pruned_origin_block_id_map = { + i: i + for i in six.moves.range(self.desc.num_blocks()) + } # NOTE(zhiqiu): All vars in cloned program exist in original program. # The reverse is not true, due to backward pruning. - for i, block in enumerate(other.blocks): + for i, block in enumerate(self.blocks): + other_block = other.blocks[pruned_origin_block_id_map[i]] for var in list(block.vars.values()): - if not self.blocks[i].has_var(var.name): - continue - if var.is_data: - self.blocks[i].var(var.name).is_data = True - if var.desc.need_check_feed(): - self.blocks[i].var(var.name).desc.set_need_check_feed(True) - if var.stop_gradient: - self.blocks[i].var(var.name).stop_gradient = True + other_var = other_block.var(var.name) + if other_var.is_data: + var.is_data = True + if other_var.desc.need_check_feed(): + var.desc.set_need_check_feed(True) + if other_var.stop_gradient: + var.stop_gradient = True @dygraph_not_support def list_vars(self): diff --git a/python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py b/python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py index c47115c466f..ab9dc2455af 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py @@ -128,9 +128,9 @@ def check_if_mkldnn_batchnorm_primitives_exist_in_bwd( for arg in grad_op_desc.output_arg_names(): grad_var = block.desc.find_var(arg.encode("ascii")) grad_var.set_dtype(core.VarDesc.VarType.FP32) + program._sync_with_cpp() exe = fluid.Executor(place) - # Do at least 2 iterations for i in range(2): out = exe.run( diff --git a/python/paddle/fluid/tests/unittests/seresnext_net.py b/python/paddle/fluid/tests/unittests/seresnext_net.py index 17ffe39c0a3..ece1e8ce74b 100644 --- a/python/paddle/fluid/tests/unittests/seresnext_net.py +++ b/python/paddle/fluid/tests/unittests/seresnext_net.py @@ -18,7 +18,7 @@ fluid.core._set_eager_deletion_mode(-1, -1, False) import paddle.fluid.layers.ops as ops from paddle.fluid.initializer import init_on_cpu -from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter +from paddle.fluid.layers.learning_rate_scheduler import cosine_decay from simple_nets import init_data import math import os @@ -161,20 +161,6 @@ def SE_ResNeXt50Small(use_feed): return loss -def cosine_decay(learning_rate, step_each_epoch, epochs=120): - """ - Applies cosine decay to the learning rate. - lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1) - """ - global_step = _decay_step_counter() - - with init_on_cpu(): - epoch = ops.floor(global_step / step_each_epoch) - decayed_lr = learning_rate * \ - (ops.cos(epoch * (math.pi / epochs)) + 1)/2 - return decayed_lr - - def optimizer(learning_rate=0.01): optimizer = fluid.optimizer.Momentum( learning_rate=cosine_decay( diff --git a/python/paddle/fluid/tests/unittests/test_program_prune_backward.py b/python/paddle/fluid/tests/unittests/test_program_prune_backward.py index ed259b12f03..bf3aa33886c 100755 --- a/python/paddle/fluid/tests/unittests/test_program_prune_backward.py +++ b/python/paddle/fluid/tests/unittests/test_program_prune_backward.py @@ -71,6 +71,58 @@ def simple_fc_net_with_accuracy(use_feed): return loss +def cond_net(use_feed=None): + x = fluid.layers.data(name="x", shape=[4], dtype='float32') + label = fluid.layers.data('label', shape=[1], dtype='int64') + prediction = fluid.layers.fc(input=x, size=1, act=None) + + def loss1(pred, label): + x = fluid.layers.data(name="x", shape=[4], dtype='float32') + loss = fluid.layers.cross_entropy(input=pred, label=label) + avg_loss = fluid.layers.mean(loss, name='mean_cross_entropy_loss') + return avg_loss + + def loss2(pred, label): + loss = fluid.layers.softmax_with_cross_entropy(logits=pred, label=label) + avg_loss = fluid.layers.mean(loss, name='mean_softmax_loss') + return avg_loss + + two = fluid.layers.fill_constant([1], 'int32', 2) + pred = (two == 0) + avg_loss = fluid.layers.case([(pred, lambda: loss1(prediction, label))], + lambda: loss2(prediction, label)) + return avg_loss + + +def optimization_in_cond_net(with_optimize=False): + x = fluid.layers.data(name="x", shape=[4], dtype='float32') + label = fluid.layers.data('label', shape=[1], dtype='int64') + prediction = fluid.layers.fc(input=x, size=1, act=None) + + def loss1(opt, pred, label, with_optimize): + x = fluid.layers.data(name="x", shape=[4], dtype='float32') + loss = fluid.layers.cross_entropy(input=pred, label=label) + avg_loss = fluid.layers.mean(loss, name='mean_cross_entropy_loss') + if with_optimize: + opt.minimize(avg_loss) + return avg_loss + + def loss2(opt, pred, label, with_optimize): + loss = fluid.layers.softmax_with_cross_entropy(logits=pred, label=label) + avg_loss = fluid.layers.mean(loss, name='mean_softmax_loss') + if with_optimize: + opt.minimize(avg_loss) + return avg_loss + + sgd = fluid.optimizer.SGD(learning_rate=0.1) + two = fluid.layers.fill_constant([1], 'int32', 2) + pred = (two == 0) + avg_loss = fluid.layers.case( + [(pred, lambda: loss1(sgd, prediction, label, with_optimize))], + lambda: loss2(sgd, prediction, label, with_optimize)) + return avg_loss + + class TestProgramPruneBackward(unittest.TestCase): def program_compare(self, program_a, program_b): assert isinstance( @@ -99,19 +151,24 @@ class TestProgramPruneBackward(unittest.TestCase): test_prog_orig = main_program.clone(for_test=True) optimizer().minimize(loss) test_prog_prune = main_program.clone(for_test=True) + self.program_compare(test_prog_orig, test_prog_prune) - place = core.CPUPlace() - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) - loss_data_prune, = exe.run(test_prog_prune, - feed=feed_dict, - fetch_list=[loss.name]) - loss_data_orig, = exe.run(test_prog_orig, - feed=feed_dict, - fetch_list=[loss.name]) - self.assertEqual(loss_data_orig, loss_data_prune) + for place in places: + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + loss_data_prune, = exe.run(test_prog_prune, + feed=feed_dict, + fetch_list=[loss.name]) + loss_data_orig, = exe.run(test_prog_orig, + feed=feed_dict, + fetch_list=[loss.name]) + self.assertEqual(loss_data_orig, loss_data_prune) def test_simple_fc_net(self): def optimizer(): @@ -198,6 +255,48 @@ class TestProgramPruneBackward(unittest.TestCase): self.check_prune_correctness( method=lstm_net, feed_dict=feed_data, optimizer=optimizer) + def test_cond(self): + def optimizer(): + optimizer = fluid.optimizer.SGD(learning_rate=0.01) + return optimizer + + with self.program_scope_guard(): + x_in = np.random.random(size=(10, 4)).astype('float32') + label_in = np.random.randint(1, size=(10, 1)).astype('int64') + feed_dict = {'x': x_in, 'label': label_in} + self.check_prune_correctness( + method=cond_net, feed_dict=feed_dict, optimizer=optimizer) + + def test_optimization_in_cond(self): + x_in = np.random.random(size=(10, 4)).astype('float32') + label_in = np.random.randint(1, size=(10, 1)).astype('int64') + feed_dict = {'x': x_in, 'label': label_in} + with self.program_scope_guard(): + loss = optimization_in_cond_net(False) + main_program = fluid.default_main_program() + test_prog_orig = main_program.clone(for_test=True) + place = core.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + loss_data_orig, = exe.run(test_prog_orig, + feed=feed_dict, + fetch_list=[loss.name]) + + with self.program_scope_guard(): + loss = optimization_in_cond_net(True) + main_program = fluid.default_main_program() + test_prog_prune = main_program.clone(for_test=True) + + place = core.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + loss_data_prune, = exe.run(test_prog_prune, + feed=feed_dict, + fetch_list=[loss.name]) + + self.program_compare(test_prog_orig, test_prog_prune) + self.assertEqual(loss_data_orig, loss_data_prune) + @contextlib.contextmanager def program_scope_guard(self): prog = fluid.Program() @@ -205,7 +304,8 @@ class TestProgramPruneBackward(unittest.TestCase): scope = fluid.core.Scope() with fluid.scope_guard(scope): with fluid.program_guard(prog, startup_prog): - yield + with fluid.unique_name.guard(): + yield if __name__ == '__main__': -- GitLab