Unverified · Commit b2c1be85 · Authored by Leo Chen · Committed by GitHub

support cond in clone, test=develop (#22657)

* support cond in clone, test=develop

* refine code, test=develop

* refine code, test=develop

* follow comments, test=develop

* refine code, test=develop
Parent 2143bd57
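The change is easiest to see from the Python side: `clone(for_test=True)` must now prune the backward and optimizer ops that live inside the sub-blocks created by cond/case, not only those in the main block. A minimal sketch of such a program, mirroring the cond_net unit test added further below (names, sizes and the learning rate here are illustrative, not taken from the commit):

import paddle.fluid as fluid

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    prediction = fluid.layers.fc(input=x, size=1, act=None)

    def loss1():
        loss = fluid.layers.cross_entropy(input=prediction, label=label)
        return fluid.layers.mean(loss)

    def loss2():
        loss = fluid.layers.softmax_with_cross_entropy(
            logits=prediction, label=label)
        return fluid.layers.mean(loss)

    # The case op puts each branch into its own sub-block.
    two = fluid.layers.fill_constant([1], 'int32', 2)
    loss = fluid.layers.case([(two == 0, loss1)], default=loss2)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

# Backward and optimizer ops, including those inside the case/cond
# sub-blocks, are pruned away and block ids are remapped consistently.
test_prog = main_prog.clone(for_test=True)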
@@ -18,8 +18,10 @@ limitations under the License. */
 
 #include <algorithm>
 #include <memory>
+#include <queue>
 #include <set>
 #include <string>
+#include <tuple>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
@@ -81,19 +83,50 @@ bool HasFalseTarget(const proto::OpDesc& op_desc) {
 }
 
 int GetSubBlockIndex(const proto::OpDesc& op_desc) {
+  // The block index >= 0, so -1 is used to indicate "NotFound".
   for (auto& attr : op_desc.attrs()) {
     if (attr.type() == proto::AttrType::BLOCK) {
-      PADDLE_ENFORCE(attr.has_block_idx());
+      PADDLE_ENFORCE_EQ(attr.has_block_idx(), true,
+                        platform::errors::NotFound(
+                            "Attribute sub_block is not found in operator %s",
+                            op_desc.type()));
       return attr.block_idx();
     }
   }
   return -1;
 }
 
+void SetSubBlockIndex(proto::OpDesc* op_desc, int sub_idx) {
+  for (auto& attr : *op_desc->mutable_attrs()) {
+    if (attr.type() == proto::AttrType::BLOCK) {
+      PADDLE_ENFORCE_EQ(attr.has_block_idx(), true,
+                        platform::errors::NotFound(
+                            "Attribute sub_block is not found in operator %s",
+                            op_desc->type()));
+      attr.set_block_idx(sub_idx);
+    }
+  }
+}
+
 bool HasSubBlock(const proto::OpDesc& op_desc) {
   return GetSubBlockIndex(op_desc) > 0;
 }
 
+int GetOpRole(const proto::OpDesc& op_desc) {
+  // The op role >= 0, so -1 is used to indicate "NotFound".
+  for (auto& attr : op_desc.attrs()) {
+    if (attr.name() == OpProtoAndCheckerMaker::OpRoleAttrName()) {
+      PADDLE_ENFORCE_EQ(
+          attr.has_i(), true,
+          platform::errors::NotFound("Attribute %s is empty in operator %s",
+                                     OpProtoAndCheckerMaker::OpRoleAttrName(),
+                                     op_desc.type()));
+      return attr.i();
+    }
+  }
+  return -1;
+}
+
 void AppendOpInputVarNames(const proto::OpDesc& op_desc,
                            std::unordered_set<std::string>* vars_set) {
   for (auto& var : op_desc.inputs()) {
@@ -259,134 +292,159 @@ void Prune(const proto::ProgramDesc& input,
   prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
 }
 
-void CloneWholeBlock(proto::ProgramDesc* input, proto::ProgramDesc* output,
-                     int block_id, int parent_block_id) {
-  auto* block_field = output->mutable_blocks();
-  *block_field->Add() = input->blocks(block_id);
-  int output_block_id = output->blocks_size() - 1;
-  auto* output_block = output->mutable_blocks(output_block_id);
-  output_block->set_idx(output_block_id);
-  output_block->set_parent_idx(parent_block_id);
-}
+int FindMapByValue(const std::map<int, int>& m, int val) {
+  // The content in map should be >= 0, so -1 is used to indicate "NotFound".
+  for (auto& pair : m) {
+    if (pair.second == val) {
+      return pair.first;
+    }
+  }
+  return -1;
+}
 
-void PruneBackwardImpl(proto::ProgramDesc* input, proto::ProgramDesc* output,
-                       int block_id, int parent_block_id) {
-  // Step 1. Copy the current input block to output
-  CloneWholeBlock(input, output, block_id, parent_block_id);
-  int output_block_id = output->blocks_size() - 1;
-  auto* output_block = output->mutable_blocks(output_block_id);
-
-  // Step 2. Mark forward ops on main branch
-  auto* ops = input->mutable_blocks(block_id)->mutable_ops();
+void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
   std::unordered_set<std::string> op_input_vars;
   std::unordered_set<std::string> op_output_vars;
-  for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) {
-    auto& op_desc = *op_iter;
-    if (HasTrueTarget(op_desc) ||
-        HasDependentOutputVar(op_desc, op_input_vars)) {
-      op_desc.set_is_target(true);
-      AppendOpInputVarNames(op_desc, &op_input_vars);
-      AppendOpOutputVarNames(op_desc, &op_output_vars);
-    }
-  }
-
-  // Step 3. Mark backward & optimize ops on main branch
-  std::unordered_set<std::string> gradop_input_vars;
-  std::unordered_set<std::string> gradop_output_vars;
-  for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) {
-    auto& op_desc = *op_iter;
-    if (HasFalseTarget(op_desc) ||
-        HasDependentInputVar(op_desc, gradop_output_vars)) {
-      op_desc.set_is_target(false);
-      AppendOpInputVarNames(op_desc, &gradop_input_vars);
-      AppendOpOutputVarNames(op_desc, &gradop_output_vars);
-    }
-  }
-
-  // Step 4. Mark ops need to be reserved on sub-branch
-  for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) {
-    auto& op_desc = *op_iter;
-    if (!op_desc.has_is_target()) {
-      if (HasDependentOutputVar(op_desc, gradop_input_vars)) {
-        op_desc.set_is_target(false);
-        AppendOpInputVarNames(op_desc, &gradop_input_vars);
-      } else {
-        op_desc.set_is_target(true);
-        AppendOpInputVarNames(op_desc, &op_input_vars);
-        AppendOpOutputVarNames(op_desc, &op_output_vars);
-      }
-    }
-  }
+
+  // Step 1. Mark backward, optimize and lrsched ops in the block
+  auto* ops = origin->mutable_ops();
+  for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) {
+    auto& op_desc = *op_iter;
+    auto op_role = GetOpRole(op_desc);
+    if (op_role & static_cast<int>(OpRole::kOptimize) ||
+        op_role & static_cast<int>(OpRole::kBackward) ||
+        op_role & static_cast<int>(OpRole::kLRSched)) {
+      op_desc.set_is_target(false);
+    }
+  }
 
-  // Step 5. Copy the forward ops to new ProgramDesc
+  // Step 2. Copy the forward ops which have not been set false target to new
+  // ProgramDesc
   // Note: The proto::ProgramDesc doesn't have interface
   // to remove op and var
-  auto* op_field = output_block->mutable_ops();
+  auto* op_field = pruned->mutable_ops();
   op_field->Clear();
   for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) {
-    if (IsTarget(*op_iter)) {
+    if (!HasFalseTarget(*op_iter)) {
       auto* op = op_field->Add();
+      AppendOpInputVarNames(*op_iter, &op_input_vars);
+      AppendOpOutputVarNames(*op_iter, &op_output_vars);
       *op = *op_iter;
-      if (HasSubBlock(*op)) {
-        CloneWholeBlock(input, output, GetSubBlockIndex(*op), output_block_id);
-      }
     }
   }
 
-  // Step 6. Copy the forward vars to new ProgramDesc
+  // Step 3. Copy the forward vars to new ProgramDesc,
   // construct all var's map before clear
-  auto* var_field = output_block->mutable_vars();
+  auto* origin_vars = origin->mutable_vars();
+  auto* pruned_vars = pruned->mutable_vars();
   std::unordered_map<std::string, proto::VarDesc> var_map;
-  for (const auto& var : *var_field) {
+  for (const auto& var : *origin_vars) {
     var_map[var.name()] = var;
   }
+  pruned_vars->Clear();
 
   std::unordered_set<std::string> var_names;
   var_names.insert(op_input_vars.begin(), op_input_vars.end());
   var_names.insert(op_output_vars.begin(), op_output_vars.end());
-  var_field->Clear();
   for (const auto& name : var_names) {
-    *var_field->Add() = var_map[name];
+    if (var_map.count(name)) {
+      // NOTE(zhiqiu): For operator in a conditional block, the related vars
+      // may not exist in current block, but in its father block.
+      *pruned_vars->Add() = var_map[name];
+    }
   }
 }
 
-std::unique_ptr<framework::ProgramDesc> PruneBackward(
+std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
     const framework::ProgramDesc& origin) {
   // Copy original ProgramDesc, origin can't be change
   framework::ProgramDesc origin_clone(origin);
 
-  // Step 1. Update loss op's role & set loss op to be target
-  // The loss op's op_role is (kForward | kLoss)
-  // The input ProgramDesc should have loss operator.
-  auto ops = origin_clone.Block(0).AllOps();
-  bool has_loss_op = false;
-  for (auto op : ops) {
-    int op_role =
-        boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
-    if (op_role == (static_cast<int>(OpRole::kForward) |
-                    static_cast<int>(OpRole::kLoss))) {
-      op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                  static_cast<int>(OpRole::kForward));
-      op->SetIsTarget(true);
-      has_loss_op = true;
-    } else if (op_role == (static_cast<int>(OpRole::kBackward) |
-                           static_cast<int>(OpRole::kLoss))) {
-      op->SetIsTarget(false);
-      break;
-    }
-  }
-  PADDLE_ENFORCE_EQ(has_loss_op, true,
-                    "The Program need to be pruned its backward part"
-                    "should have loss operator.");
+  // Step 1. check if the program contains grad loss operator.
+  // If not, the program need no pruning.
+  bool has_loss_grad_op = false;
+  std::queue<int> block_contains_loss;
+  std::queue<int> block_contains_loss_grad;
+  for (size_t i = 0; i < origin_clone.Size(); i++) {
+    auto block_ops = origin_clone.Block(i).AllOps();
+    for (auto op : block_ops) {
+      int op_role = boost::get<int>(
+          op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
+      if (op_role == (static_cast<int>(OpRole::kBackward) |
+                      static_cast<int>(OpRole::kLoss))) {
+        op->SetIsTarget(false);
+        has_loss_grad_op = true;
+        break;
+      }
+    }
+  }
+
+  std::map<int, int> pruned_progin_block_id_map;
+  if (!has_loss_grad_op) {
+    // No pruning, fast return a copy of the origin ProgramDesc with an empty
+    // map, means default mapped, i.e.{0:0, 1:1, ..., n:n}.
+    return std::make_tuple(framework::ProgramDesc(origin_clone),
+                           pruned_progin_block_id_map);
+  }
 
-  // Step 2. Prune backward
   proto::ProgramDesc pruned_desc;
   pruned_desc.clear_blocks();
-  PruneBackwardImpl(origin_clone.Proto(), &pruned_desc, 0, -1);
 
-  // Step 3. Contruct new framework::ProgramDesc
-  return std::unique_ptr<framework::ProgramDesc>(
-      new framework::ProgramDesc(pruned_desc));
-}
+  // Step 2. Prune backward for each block.
+  for (size_t i = 0; i < origin_clone.Size(); i++) {
+    auto pruned = proto::BlockDesc();
+    auto origin = origin_clone.Proto()->mutable_blocks(i);
+    PruneBackwardImpl(origin, &pruned);
+    // If pruned block contains no operator, it means the block is a
+    // backward block and should be pruned.
+    // Else, add the block to pruned_desc and update its id & parent_id.
+    if (pruned.ops_size() > 0) {
+      auto* block_field = pruned_desc.mutable_blocks();
+      *block_field->Add() = pruned;
+      auto pruned_block_id = pruned_desc.blocks_size() - 1;
+      pruned_progin_block_id_map[pruned_block_id] = origin->idx();
+      auto* pruned_block = pruned_desc.mutable_blocks(pruned_block_id);
+      pruned_block->set_idx(pruned_block_id);
+      if (origin->parent_idx() == -1) {
+        pruned_block->set_parent_idx(-1);
+      } else {
+        auto parent_idx =
+            FindMapByValue(pruned_progin_block_id_map, origin->parent_idx());
+        PADDLE_ENFORCE_NE(parent_idx, -1,
+                          platform::errors::NotFound(
+                              "The origin parent block id is not found in "
+                              "pruned_progin_block_id_map"));
+        pruned_block->set_parent_idx(parent_idx);
+      }
+    }
+  }
+
+  // Step 3. Update subblock attribute for conditional operator.
+  // This should be performed after all blocks pruned.
+  for (int i = 0; i < pruned_desc.blocks_size(); i++) {
+    auto* pruned = pruned_desc.mutable_blocks(i);
+    auto* ops = pruned->mutable_ops();
+    for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) {
+      auto& op_desc = *op_iter;
+      if (HasSubBlock(op_desc)) {
+        int origin_sub_idx = GetSubBlockIndex(op_desc);
+        auto sub_idx =
+            FindMapByValue(pruned_progin_block_id_map, origin_sub_idx);
+        PADDLE_ENFORCE_NE(sub_idx, -1,
+                          platform::errors::NotFound(
+                              "The origin sub block id is not found in "
+                              "pruned_progin_block_id_map"));
+        SetSubBlockIndex(&op_desc, sub_idx);
+      }
+    }
+  }
+
+  // Step 4. Return a tuple
+  return std::make_tuple(framework::ProgramDesc(pruned_desc),
+                         pruned_progin_block_id_map);
+}
 
 }  // namespace framework
 }  // namespace paddle
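For intuition, here is a small, self-contained Python sketch of the block-id bookkeeping performed by PruneBackward above (the block ids used are made up, not taken from a real program): pruned_progin_block_id_map records, for every block kept in the pruned program, the id of the origin block it came from, and FindMapByValue is the inverse lookup used to remap parent and sub_block ids.

def find_map_by_value(m, val):
    """Mirror of FindMapByValue: inverse lookup; -1 means 'not found'."""
    for k, v in m.items():
        if v == val:
            return k
    return -1

# Suppose origin blocks 0, 1, 2, 3 exist and block 2 is a pure backward
# block that gets dropped. The pruned program then has blocks 0, 1, 2 with:
pruned_origin = {0: 0, 1: 1, 2: 3}

# An op in the pruned program whose sub_block attribute still points at
# origin block 3 must be redirected to pruned block 2:
origin_sub_idx = 3
new_sub_idx = find_map_by_value(pruned_origin, origin_sub_idx)
assert new_sub_idx == 2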
@@ -14,9 +14,11 @@ limitations under the License. */
 
 #pragma once
 
+#include <map>
 #include <memory>
 #include <set>
 #include <string>
+#include <tuple>
 
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -28,7 +30,7 @@ void Prune(const proto::ProgramDesc& input,
            const std::set<std::string>& feed_var_names,
            proto::ProgramDesc* output);
 
-std::unique_ptr<framework::ProgramDesc> PruneBackward(
+std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
    const framework::ProgramDesc& origin);
 
 }  // namespace framework
......
@@ -1136,9 +1136,23 @@ All parameter, weight, gradient are variables in Paddle.
         Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
         return new ProgramDesc(pruned_desc);
       });
-  m.def("prune_backward", [](const framework::ProgramDesc &program) {
-    return PruneBackward(program);
-  });
+  m.def("prune_backward",
+        [](const framework::ProgramDesc &program) {
+          return PruneBackward(program);
+        },
+        R"DOC(
+             Prune the backward part of a program, mostly called in
+             program.clone(for_test=True).
+
+             Args:
+                 program (ProgramDesc): The original program.
+
+             Returns:
+                 tuple(ProgramDesc, map<int, int>): The first part is
+                 the pruned program desc, and the second part is a map
+                 which contains the id pair of pruned block and corresponding
+                 origin block.
+           )DOC");
   m.def("empty_var_name",
        []() { return std::string(framework::kEmptyVarName); });
   m.def("grad_var_suffix",
......
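A hedged usage sketch of the new binding; in normal use it is reached through Program.clone(for_test=True) (see the framework.py change below) rather than called directly, and the program built here is only a placeholder assumption.

import paddle.fluid as fluid
from paddle.fluid import core

prog = fluid.default_main_program()
pruned_desc, pruned_origin_block_id_map = core.prune_backward(prog.desc)
# pruned_desc: a ProgramDesc with backward/optimize/lr-sched ops removed.
# pruned_origin_block_id_map: {pruned_block_id: origin_block_id}; an empty
# map means the identity mapping {0:0, 1:1, ..., n:n}.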
@@ -3991,18 +3991,17 @@ class Program(object):
         The two code snippets above will generate and print same programs.
         """
+        pruned_origin_block_id_map = None
         if for_test:
-            if self._appending_grad_times > 0:
-                forward_prog = Program()
-                forward_prog.desc = core.prune_backward(self.desc)
-                forward_prog.blocks = [
-                    Block(forward_prog, i)
-                    for i in six.moves.range(forward_prog.desc.num_blocks())
-                ]
-                forward_prog._sync_with_cpp()
-                p = forward_prog._inference_optimize(prune_read_op=False)
-            else:
-                p = self._inference_optimize(prune_read_op=False)
+            forward_prog = Program()
+            forward_prog.desc, pruned_origin_block_id_map = core.prune_backward(
+                self.desc)
+            forward_prog.blocks = [
+                Block(forward_prog, i)
+                for i in six.moves.range(forward_prog.desc.num_blocks())
+            ]
+            forward_prog._sync_with_cpp()
+            p = forward_prog._inference_optimize(prune_read_op=False)
         else:
             p = Program()
             p.current_block_idx = self.current_block_idx
@@ -4019,7 +4018,7 @@ class Program(object):
             p._sync_with_cpp()
 
         p._copy_param_info_from(self)
-        p._copy_data_info_from(self)
+        p._copy_data_info_from(self, pruned_origin_block_id_map)
         p._copy_dist_param_info_from(self)
         return p
@@ -4445,9 +4444,6 @@ class Program(object):
             raise TypeError("_copy_param_info_from should be invoked with "
                             "Program")
 
-        if len(self.blocks) != len(other.blocks):
-            raise ValueError("_copy_param_info_from should be invoked with two "
-                             "program, with represent the same topology")
         self.global_block()._copy_param_info_from(other.global_block())
 
     def _copy_dist_param_info_from(self, other):
@@ -4470,7 +4466,7 @@ class Program(object):
         self._ps_endpoint = other._ps_endpoint
         self._distributed_lookup_table = other._distributed_lookup_table
 
-    def _copy_data_info_from(self, other):
+    def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
         """
         Copy the information of data variables from other program.
@@ -4479,6 +4475,10 @@ class Program(object):
 
         Args:
             other(Program): Other program
+            pruned_origin_block_id_map(dict{int:int}): A dict which maps the block id in program
+                self to the block id in program other. For example, {0:0, 1:1, 2:3} means block 0 in self is
+                cloned from block 0 in other, etc. Default is None, which means default mapped,
+                {0:0, 1:1, ..., n:n}.
 
         Returns:
             None
@@ -4487,22 +4487,24 @@ class Program(object):
             raise TypeError("_copy_data_info_from should be invoked with "
                             "Program")
 
-        if len(self.blocks) != len(other.blocks):
-            raise ValueError("_copy_data_info_from should be invoked with two "
-                             "program, with represent the same topology")
+        if not pruned_origin_block_id_map:
+            pruned_origin_block_id_map = {
+                i: i
+                for i in six.moves.range(self.desc.num_blocks())
+            }
 
         # NOTE(zhiqiu): All vars in cloned program exist in original program.
         # The reverse is not true, due to backward pruning.
-        for i, block in enumerate(other.blocks):
+        for i, block in enumerate(self.blocks):
+            other_block = other.blocks[pruned_origin_block_id_map[i]]
             for var in list(block.vars.values()):
-                if not self.blocks[i].has_var(var.name):
-                    continue
-                if var.is_data:
-                    self.blocks[i].var(var.name).is_data = True
-                if var.desc.need_check_feed():
-                    self.blocks[i].var(var.name).desc.set_need_check_feed(True)
-                if var.stop_gradient:
-                    self.blocks[i].var(var.name).stop_gradient = True
+                other_var = other_block.var(var.name)
+                if other_var.is_data:
+                    var.is_data = True
+                if other_var.desc.need_check_feed():
+                    var.desc.set_need_check_feed(True)
+                if other_var.stop_gradient:
+                    var.stop_gradient = True
 
     @dygraph_not_support
     def list_vars(self):
......
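To illustrate the map-driven copy in _copy_data_info_from, a toy, self-contained example in which plain dicts stand in for Block and Variable objects (all names and the mapping are made up): block i of the clone reads its is_data flag from the origin block it was pruned from, so a clone with fewer blocks than its origin still copies the right flags.

origin_blocks = [
    {'x': {'is_data': True}},            # origin block 0 (global block)
    {'fc_0.tmp_0': {'is_data': False}},  # origin block 1 (cond sub-block)
    {'x@GRAD': {'is_data': False}},      # origin block 2 (backward, dropped)
    {'label': {'is_data': True}},        # origin block 3 (cond sub-block)
]
cloned_blocks = [
    {'x': {'is_data': False}},
    {'fc_0.tmp_0': {'is_data': False}},
    {'label': {'is_data': False}},
]
pruned_origin_block_id_map = {0: 0, 1: 1, 2: 3}

for i, block in enumerate(cloned_blocks):
    origin_block = origin_blocks[pruned_origin_block_id_map[i]]
    for name, var in block.items():
        if origin_block[name]['is_data']:
            var['is_data'] = True

assert cloned_blocks[2]['label']['is_data'] is True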
@@ -128,9 +128,9 @@ def check_if_mkldnn_batchnorm_primitives_exist_in_bwd(
     for arg in grad_op_desc.output_arg_names():
         grad_var = block.desc.find_var(arg.encode("ascii"))
         grad_var.set_dtype(core.VarDesc.VarType.FP32)
+    program._sync_with_cpp()
 
     exe = fluid.Executor(place)
 
     # Do at least 2 iterations
     for i in range(2):
         out = exe.run(
......
@@ -18,7 +18,7 @@ fluid.core._set_eager_deletion_mode(-1, -1, False)
 
 import paddle.fluid.layers.ops as ops
 from paddle.fluid.initializer import init_on_cpu
-from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.layers.learning_rate_scheduler import cosine_decay
 from simple_nets import init_data
 import math
 import os
@@ -161,20 +161,6 @@ def SE_ResNeXt50Small(use_feed):
     return loss
 
 
-def cosine_decay(learning_rate, step_each_epoch, epochs=120):
-    """
-    Applies cosine decay to the learning rate.
-    lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
-    """
-    global_step = _decay_step_counter()
-
-    with init_on_cpu():
-        epoch = ops.floor(global_step / step_each_epoch)
-        decayed_lr = learning_rate * \
-            (ops.cos(epoch * (math.pi / epochs)) + 1)/2
-
-    return decayed_lr
-
-
 def optimizer(learning_rate=0.01):
     optimizer = fluid.optimizer.Momentum(
         learning_rate=cosine_decay(
......
@@ -71,6 +71,58 @@ def simple_fc_net_with_accuracy(use_feed):
     return loss
 
 
+def cond_net(use_feed=None):
+    x = fluid.layers.data(name="x", shape=[4], dtype='float32')
+    label = fluid.layers.data('label', shape=[1], dtype='int64')
+    prediction = fluid.layers.fc(input=x, size=1, act=None)
+
+    def loss1(pred, label):
+        x = fluid.layers.data(name="x", shape=[4], dtype='float32')
+        loss = fluid.layers.cross_entropy(input=pred, label=label)
+        avg_loss = fluid.layers.mean(loss, name='mean_cross_entropy_loss')
+        return avg_loss
+
+    def loss2(pred, label):
+        loss = fluid.layers.softmax_with_cross_entropy(logits=pred, label=label)
+        avg_loss = fluid.layers.mean(loss, name='mean_softmax_loss')
+        return avg_loss
+
+    two = fluid.layers.fill_constant([1], 'int32', 2)
+    pred = (two == 0)
+    avg_loss = fluid.layers.case([(pred, lambda: loss1(prediction, label))],
+                                 lambda: loss2(prediction, label))
+    return avg_loss
+
+
+def optimization_in_cond_net(with_optimize=False):
+    x = fluid.layers.data(name="x", shape=[4], dtype='float32')
+    label = fluid.layers.data('label', shape=[1], dtype='int64')
+    prediction = fluid.layers.fc(input=x, size=1, act=None)
+
+    def loss1(opt, pred, label, with_optimize):
+        x = fluid.layers.data(name="x", shape=[4], dtype='float32')
+        loss = fluid.layers.cross_entropy(input=pred, label=label)
+        avg_loss = fluid.layers.mean(loss, name='mean_cross_entropy_loss')
+        if with_optimize:
+            opt.minimize(avg_loss)
+        return avg_loss
+
+    def loss2(opt, pred, label, with_optimize):
+        loss = fluid.layers.softmax_with_cross_entropy(logits=pred, label=label)
+        avg_loss = fluid.layers.mean(loss, name='mean_softmax_loss')
+        if with_optimize:
+            opt.minimize(avg_loss)
+        return avg_loss
+
+    sgd = fluid.optimizer.SGD(learning_rate=0.1)
+    two = fluid.layers.fill_constant([1], 'int32', 2)
+    pred = (two == 0)
+    avg_loss = fluid.layers.case(
+        [(pred, lambda: loss1(sgd, prediction, label, with_optimize))],
+        lambda: loss2(sgd, prediction, label, with_optimize))
+    return avg_loss
+
+
 class TestProgramPruneBackward(unittest.TestCase):
     def program_compare(self, program_a, program_b):
         assert isinstance(
@@ -99,9 +151,14 @@ class TestProgramPruneBackward(unittest.TestCase):
             test_prog_orig = main_program.clone(for_test=True)
             optimizer().minimize(loss)
             test_prog_prune = main_program.clone(for_test=True)
 
             self.program_compare(test_prog_orig, test_prog_prune)
 
-            place = core.CPUPlace()
-            exe = fluid.Executor(place)
-            exe.run(fluid.default_startup_program())
+            places = [core.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(core.CUDAPlace(0))
+
+            for place in places:
+                exe = fluid.Executor(place)
+                exe.run(fluid.default_startup_program())
@@ -198,6 +255,48 @@ class TestProgramPruneBackward(unittest.TestCase):
             self.check_prune_correctness(
                 method=lstm_net, feed_dict=feed_data, optimizer=optimizer)
 
+    def test_cond(self):
+        def optimizer():
+            optimizer = fluid.optimizer.SGD(learning_rate=0.01)
+            return optimizer
+
+        with self.program_scope_guard():
+            x_in = np.random.random(size=(10, 4)).astype('float32')
+            label_in = np.random.randint(1, size=(10, 1)).astype('int64')
+            feed_dict = {'x': x_in, 'label': label_in}
+            self.check_prune_correctness(
+                method=cond_net, feed_dict=feed_dict, optimizer=optimizer)
+
+    def test_optimization_in_cond(self):
+        x_in = np.random.random(size=(10, 4)).astype('float32')
+        label_in = np.random.randint(1, size=(10, 1)).astype('int64')
+        feed_dict = {'x': x_in, 'label': label_in}
+        with self.program_scope_guard():
+            loss = optimization_in_cond_net(False)
+            main_program = fluid.default_main_program()
+            test_prog_orig = main_program.clone(for_test=True)
+            place = core.CPUPlace()
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            loss_data_orig, = exe.run(test_prog_orig,
+                                      feed=feed_dict,
+                                      fetch_list=[loss.name])
+
+        with self.program_scope_guard():
+            loss = optimization_in_cond_net(True)
+            main_program = fluid.default_main_program()
+            test_prog_prune = main_program.clone(for_test=True)
+            place = core.CPUPlace()
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            loss_data_prune, = exe.run(test_prog_prune,
+                                       feed=feed_dict,
+                                       fetch_list=[loss.name])
+
+        self.program_compare(test_prog_orig, test_prog_prune)
+        self.assertEqual(loss_data_orig, loss_data_prune)
+
     @contextlib.contextmanager
     def program_scope_guard(self):
         prog = fluid.Program()
@@ -205,6 +304,7 @@ class TestProgramPruneBackward(unittest.TestCase):
         scope = fluid.core.Scope()
         with fluid.scope_guard(scope):
             with fluid.program_guard(prog, startup_prog):
-                yield
+                with fluid.unique_name.guard():
+                    yield
......