未验证 提交 f2a3405f 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix save inference model conditional op (#37579) (#38739)

上级 5925b82e
...@@ -145,6 +145,23 @@ int FindMapByValue(const std::map<int, int>& m, int val) { ...@@ -145,6 +145,23 @@ int FindMapByValue(const std::map<int, int>& m, int val) {
return -1; return -1;
} }
// In other two cases,the op that has feed vars as output vars is dependent:
// 1. op has subblock, like while/for/ifelse/recurrent
// 2. op is in subblock
bool IsSubBlockDependent(const proto::OpDesc& op_desc,
const std::set<std::string>& feed_vars,
int parent_block_id) {
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if ((HasSubBlock(op_desc) || parent_block_id != -1) &&
feed_vars.count(argu) != 0) {
return true;
}
}
}
return false;
}
// block_id is the idx of the current block in the input desc // block_id is the idx of the current block in the input desc
// parent_block_id is the idx of the parent of the current block // parent_block_id is the idx of the parent of the current block
// in the output desc, -1 means the current block is global block // in the output desc, -1 means the current block is global block
...@@ -210,7 +227,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -210,7 +227,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// } // }
if (IsTarget(op_desc) || if (IsTarget(op_desc) ||
(HasDependentOutputVar(op_desc, *dependent_vars) && ((HasDependentOutputVar(op_desc, *dependent_vars) ||
(IsSubBlockDependent(op_desc, feed_var_names, parent_block_id))) &&
(GetOpRole(op_desc) & static_cast<int>(OpRole::kOptimize)) == 0)) { (GetOpRole(op_desc) & static_cast<int>(OpRole::kOptimize)) == 0)) {
// NOTE(zhiqiu): since optimize op takes the trainable parameters as // NOTE(zhiqiu): since optimize op takes the trainable parameters as
// inputs and output, it may introduce wrong dependency graph. // inputs and output, it may introduce wrong dependency graph.
...@@ -227,30 +245,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -227,30 +245,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
should_run.push_back(true); should_run.push_back(true);
} else { } else {
should_run.push_back(false); should_run.push_back(false);
// If the output of an op modifies feed vars, the op should not clip.
// For example, in the transformer structure, the third parameter returned
// by beam_search op is generally assigned to a feed var. Cutting the
// assign op will cause an error.
if (parent_block_id != -1) {
bool flag = false;
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu)) {
flag = true;
}
}
}
if (flag) {
should_run.back() = true;
// If any op should run, then there inputs are dependent_vars
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
dependent_vars->insert(argu);
}
}
}
}
} }
} }
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
def getModelOp(model_path):
model_bytes = paddle.static.load_from_file(model_path)
pg = paddle.static.deserialize_program(model_bytes)
main_block = pg.desc.block(0)
size = main_block.op_size()
result = set()
for i in range(0, size):
#print(main_block.op(i).type())
result.add(main_block.op(i).type())
return result
class WhileNet(paddle.nn.Layer):
def __init__(self):
super(WhileNet, self).__init__()
def forward(self, x):
y = paddle.rand(shape=[1, 3, 4, 4])
w1 = paddle.shape(y)[0]
w2 = paddle.shape(x)[0]
while w2 != w1:
x = F.avg_pool2d(x, kernel_size=3, padding=1, stride=2)
w2 = paddle.shape(x)[0]
return x + y
class ForNet(paddle.nn.Layer):
def __init__(self):
super(ForNet, self).__init__()
def forward(self, x):
y = paddle.randint(low=0, high=5, shape=[1], dtype='int32')
z = paddle.randint(low=0, high=5, shape=[1], dtype='int32')
for i in range(0, z):
x = x + i
return x + y
class IfElseNet(paddle.nn.Layer):
def __init__(self):
super(IfElseNet, self).__init__()
def forward(self, x):
y = paddle.to_tensor([5])
if x > y:
x = x + 1
else:
x = x - 1
return x
class TestConditionalOp(unittest.TestCase):
def test_while_op(self):
paddle.disable_static()
net = WhileNet()
net = paddle.jit.to_static(
net,
input_spec=[
paddle.static.InputSpec(
shape=[1, 3, 8, 8], dtype='float32')
])
paddle.jit.save(net, './while_net')
right_pdmodel = set([
"uniform_random", "shape", "slice", "not_equal", "while",
"elementwise_add"
])
paddle.enable_static()
pdmodel = getModelOp("while_net.pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0,
"The while op is pruned by mistake.")
def test_for_op(self):
paddle.disable_static()
net = ForNet()
net = paddle.jit.to_static(
net,
input_spec=[paddle.static.InputSpec(
shape=[1], dtype='int32')])
paddle.jit.save(net, './for_net')
right_pdmodel = set([
"randint", "fill_constant", "cast", "less_than", "while",
"elementwise_add"
])
paddle.enable_static()
pdmodel = getModelOp("for_net.pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0,
"The for op is pruned by mistake.")
def test_if_op(self):
paddle.disable_static()
net = IfElseNet()
net = paddle.jit.to_static(
net,
input_spec=[paddle.static.InputSpec(
shape=[1], dtype='int32')])
paddle.jit.save(net, './if_net')
right_pdmodel = set([
"assign_value", "greater_than", "cast", "conditional_block",
"logical_not", "select_input"
])
paddle.enable_static()
pdmodel = getModelOp("if_net.pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0,
"The if op is pruned by mistake.")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册