提交 4836ee68 编写于 作者: G Ghost Under Moon 提交者: hong

warning when user save a inference model which contains auc op test=develop (#19838)

上级 5452b6a1
......@@ -1031,6 +1031,15 @@ def save_inference_model(dirname,
main_program = _get_valid_program(main_program)
# remind user to set auc_states to zeros if the program contains auc op
all_ops = main_program.global_block().ops
for op in all_ops:
if op.type == 'auc':
warnings.warn(
"please ensure that you have set the auc states to zeros before saving inference model"
)
break
# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
......
......@@ -19,6 +19,8 @@ import unittest
import six
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
import warnings
import paddle.fluid.executor as executor
import paddle.fluid.layers as layers
......@@ -111,6 +113,33 @@ class TestSaveInferenceModel(unittest.TestCase):
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)
def test_save_inference_model_with_auc(self):
MODEL_DIR = "./tmp/inference_model4"
init_program = Program()
program = Program()
# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')
predict = fluid.layers.fc(input=x, size=2, act='softmax')
acc = fluid.layers.accuracy(input=predict, label=y)
auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict,
label=y)
cost = fluid.layers.cross_entropy(input=predict, label=y)
avg_cost = fluid.layers.mean(x=cost)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe,
program)
expected_warn = "please ensure that you have set the auc states to zeros before saving inference model"
self.assertTrue(len(w) > 0)
self.assertTrue(expected_warn == str(w[0].message))
class TestInstance(unittest.TestCase):
def test_save_inference_model(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册