Commit 6720681c authored by kexinzhao, committed by Yiqun Liu

Enable is_test attr of batch norm and drop out op for test program (#8642)

* fix is_test issue

* add paddle enforce

* fix bug

* add new func

* small fix

* address comments
Parent f45a82be
@@ -27,8 +27,6 @@ namespace framework {
 const std::string kFeedOpType = "feed";
 const std::string kFetchOpType = "fetch";
-const std::string kDropOutOpType = "dropout";
-const std::string kBatchNormOpType = "batch_norm";
 bool HasDependentVar(const proto::OpDesc& op_desc,
                      const std::set<std::string>& dependent_vars) {
@@ -186,13 +184,9 @@ void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
   prune_impl(input, output, 0, -1, dependent_vars);
 }
-void inference_optimize_impl(const proto::ProgramDesc& input,
-                             proto::ProgramDesc* output, int block_id) {
-  *output = input;
-  auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
+void inference_optimize_impl(proto::ProgramDesc* input, int block_id) {
+  auto* op_field = input->mutable_blocks(block_id)->mutable_ops();
   for (auto& op_desc : *op_field) {
-    if (op_desc.type() == kDropOutOpType ||
-        op_desc.type() == kBatchNormOpType) {
-      for (auto& attr : *op_desc.mutable_attrs()) {
-        if (attr.name() == "is_test") {
-          attr.set_b(true);
+    for (auto& attr : *op_desc.mutable_attrs()) {
+      if (attr.name() == "is_test") {
+        attr.set_b(true);
@@ -200,12 +194,16 @@ void inference_optimize_impl(const proto::ProgramDesc& input,
       }
     }
   }
-  }
 }
 void InferenceOptimize(const proto::ProgramDesc& input,
                        proto::ProgramDesc* output) {
-  inference_optimize_impl(input, output, 0);
+  *output = input;
+  int num_blocks = output->blocks_size();
+  PADDLE_ENFORCE_GT(num_blocks, 0, "ProgramDesc must have at least one block");
+  for (int i = 0; i < num_blocks; ++i) {
+    inference_optimize_impl(output, i);
+  }
 }
 }  // namespace framework
...
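Conceptually, the reworked pass copies the whole ProgramDesc and then flips every is_test attribute it finds to true, for every op in every block, instead of special-casing dropout and batch_norm in block 0 only. A minimal Python sketch of that traversal over a toy program representation (plain dicts standing in for the real protobuf ProgramDesc; the names are illustrative, not the actual C++ implementation):

# Toy sketch of the traversal InferenceOptimize performs; dicts stand in
# for proto::ProgramDesc, so this is only an illustration of the logic.
import copy


def inference_optimize(program):
    """Return a copy of `program` with every `is_test` attr set to True."""
    out = copy.deepcopy(program)          # mirrors `*output = input;`
    blocks = out["blocks"]
    assert len(blocks) > 0, "ProgramDesc must have at least one block"
    for block in blocks:                  # the new code visits every block
        for op in block["ops"]:
            attrs = op["attrs"]
            if "is_test" in attrs:        # previously only dropout/batch_norm
                attrs["is_test"] = True
    return out


# Usage with a toy two-block program:
prog = {
    "blocks": [
        {"ops": [{"type": "dropout", "attrs": {"is_test": False}},
                 {"type": "mul", "attrs": {}}]},
        {"ops": [{"type": "batch_norm", "attrs": {"is_test": False}}]},
    ]
}
test_prog = inference_optimize(prog)
assert test_prog["blocks"][0]["ops"][0]["attrs"]["is_test"] is True
assert test_prog["blocks"][1]["ops"][0]["attrs"]["is_test"] is True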
@@ -956,8 +956,25 @@ class Program(object):
     def get_desc(self):
         return self.desc
 
-    def clone(self):
+    def clone(self, for_test=False):
+        """Clone the Program object
+
+        Set for_test to False when we want to clone the program for training.
+        Set for_test to True when we want to clone the program for testing.
+
+        Args:
+            for_test(bool): Some operators, such as batch_norm and drop_out ops,
+                behave differently in training and testing. If for_test is True,
+                the is_test attributes in these operators will be set to True for
+                testing purposes, otherwise, they remain unchanged.
+
+        Returns(Program):
+            The cloned Program object.
+        """
         p = Program()
-        p.desc = core.ProgramDesc(self.desc)
+        if for_test:
+            p.desc = core.inference_optimize(self.desc)
+        else:
+            p.desc = core.ProgramDesc(self.desc)
         p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
         p.sync_with_cpp()
...
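With the new signature, a test-mode copy of the main program is obtained with clone(for_test=True), taken before the optimizer appends backward and update ops, which is exactly how the book tests below use it. A hedged usage sketch, assuming the fluid layers API of this release (the specific layers, sizes, and names are illustrative):

import paddle.fluid as fluid

img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

hidden = fluid.layers.fc(input=img, size=128, act='relu')
hidden = fluid.layers.dropout(hidden, dropout_prob=0.5)   # carries an is_test attr
hidden = fluid.layers.batch_norm(input=hidden)            # carries an is_test attr
predict = fluid.layers.fc(input=hidden, size=10, act='softmax')

cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)

# Clone for testing *before* appending optimizer ops, so the test program
# contains only the forward pass, with is_test flipped to True everywhere.
test_program = fluid.default_main_program().clone(for_test=True)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_cost)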
@@ -115,7 +115,7 @@ def train(net_type, use_cuda, save_dirname, is_local):
     acc = fluid.layers.accuracy(input=predict, label=label)
 
     # Test program
-    test_program = fluid.default_main_program().clone()
+    test_program = fluid.default_main_program().clone(for_test=True)
 
     optimizer = fluid.optimizer.Adam(learning_rate=0.001)
     optimize_ops, params_grads = optimizer.minimize(avg_cost)
...
@@ -92,7 +92,7 @@ def train(nn_type,
     else:
         prediction, avg_loss, acc = net_conf(img, label)
 
-    test_program = fluid.default_main_program().clone()
+    test_program = fluid.default_main_program().clone(for_test=True)
 
     optimizer = fluid.optimizer.Adam(learning_rate=0.001)
     optimize_ops, params_grads = optimizer.minimize(avg_loss)
...
@@ -157,7 +157,7 @@ def train(use_cuda, save_dirname, is_local=True):
     scale_infer, avg_cost = model()
 
     # test program
-    test_program = fluid.default_main_program().clone()
+    test_program = fluid.default_main_program().clone(for_test=True)
 
     sgd_optimizer = SGDOptimizer(learning_rate=0.2)
     optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
...
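For evaluation, the cloned test_program is then run with an ordinary Executor. A minimal sketch, assuming the usual fluid Executor and feed API and reusing the names from the usage sketch above (the random batch only illustrates the feed format; the real book tests feed reader data):

import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()                  # or fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())  # initialize parameters once

img_data = np.random.random((8, 1, 28, 28)).astype('float32')
label_data = np.random.randint(0, 10, (8, 1)).astype('int64')

# Running test_program executes dropout/batch_norm with is_test == True:
# dropout does not drop units and batch_norm uses its accumulated statistics.
loss_val, = exe.run(test_program,
                    feed={'img': img_data, 'label': label_data},
                    fetch_list=[avg_cost])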