未验证 提交 a3f9bcb2 编写于 作者: G Guanghua Yu 提交者: GitHub

Modify quantization use tempfile to place the temporary files (#43267)

上级 d5afc1ba
...@@ -20,6 +20,7 @@ import random ...@@ -20,6 +20,7 @@ import random
import unittest import unittest
import logging import logging
import warnings import warnings
import tempfile
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -111,6 +112,16 @@ class ImperativeLenet(fluid.dygraph.Layer): ...@@ -111,6 +112,16 @@ class ImperativeLenet(fluid.dygraph.Layer):
class TestImperativeOutSclae(unittest.TestCase): class TestImperativeOutSclae(unittest.TestCase):
def setUp(self):
self.root_path = tempfile.TemporaryDirectory()
self.param_save_path = os.path.join(self.root_path.name,
"lenet.pdparams")
self.save_path = os.path.join(self.root_path.name,
"lenet_dynamic_outscale_infer_model")
def tearDown(self):
self.root_path.cleanup()
def func_out_scale_acc(self): def func_out_scale_acc(self):
seed = 1000 seed = 1000
lr = 0.001 lr = 0.001
...@@ -138,46 +149,16 @@ class TestImperativeOutSclae(unittest.TestCase): ...@@ -138,46 +149,16 @@ class TestImperativeOutSclae(unittest.TestCase):
loss_list = train_lenet(lenet, reader, adam) loss_list = train_lenet(lenet, reader, adam)
lenet.eval() lenet.eval()
param_save_path = "test_save_quantized_model/lenet.pdparams"
save_dict = lenet.state_dict() save_dict = lenet.state_dict()
paddle.save(save_dict, param_save_path) paddle.save(save_dict, self.param_save_path)
save_path = "./dynamic_outscale_infer_model/lenet"
imperative_out_scale.save_quantized_model(
layer=lenet,
path=save_path,
input_spec=[
paddle.static.InputSpec(shape=[None, 1, 28, 28],
dtype='float32')
])
for i in range(len(loss_list) - 1): for i in range(len(loss_list) - 1):
self.assertTrue(loss_list[i] > loss_list[i + 1], self.assertTrue(loss_list[i] > loss_list[i + 1],
msg='Failed to do the imperative qat.') msg='Failed to do the imperative qat.')
def test_out_scale_acc(self):
with _test_eager_guard():
self.func_out_scale_acc()
self.func_out_scale_acc()
class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
def func_save_quantized_model(self):
lr = 0.001
load_param_path = "test_save_quantized_model/lenet.pdparams"
save_path = "./dynamic_outscale_infer_model_from_checkpoint/lenet"
weight_quantize_type = 'abs_max'
activation_quantize_type = 'moving_average_abs_max'
imperative_out_scale = ImperativeQuantAware(
weight_quantize_type=weight_quantize_type,
activation_quantize_type=activation_quantize_type)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
lenet = ImperativeLenet() lenet = ImperativeLenet()
load_dict = paddle.load(load_param_path) load_dict = paddle.load(self.param_save_path)
imperative_out_scale.quantize(lenet) imperative_out_scale.quantize(lenet)
lenet.set_dict(load_dict) lenet.set_dict(load_dict)
...@@ -191,7 +172,7 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase): ...@@ -191,7 +172,7 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
imperative_out_scale.save_quantized_model( imperative_out_scale.save_quantized_model(
layer=lenet, layer=lenet,
path=save_path, path=self.save_path,
input_spec=[ input_spec=[
paddle.static.InputSpec(shape=[None, 1, 28, 28], paddle.static.InputSpec(shape=[None, 1, 28, 28],
dtype='float32') dtype='float32')
...@@ -201,10 +182,10 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase): ...@@ -201,10 +182,10 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
self.assertTrue(loss_list[i] > loss_list[i + 1], self.assertTrue(loss_list[i] > loss_list[i + 1],
msg='Failed to do the imperative qat.') msg='Failed to do the imperative qat.')
def test_save_quantized_model(self): def test_out_scale_acc(self):
with _test_eager_guard(): with _test_eager_guard():
self.func_save_quantized_model() self.func_out_scale_acc()
self.func_save_quantized_model() self.func_out_scale_acc()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,6 +22,7 @@ import time ...@@ -22,6 +22,7 @@ import time
import unittest import unittest
import copy import copy
import logging import logging
import tempfile
import paddle.nn as nn import paddle.nn as nn
import paddle import paddle
...@@ -73,10 +74,6 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -73,10 +74,6 @@ class TestImperativePTQ(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
cls.root_path = os.path.join(os.getcwd(), "imperative_ptq_" + timestamp)
cls.save_path = os.path.join(cls.root_path, "model")
cls.download_path = 'dygraph_int8/download' cls.download_path = 'dygraph_int8/download'
cls.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + cls.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
cls.download_path) cls.download_path)
...@@ -89,14 +86,6 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -89,14 +86,6 @@ class TestImperativePTQ(unittest.TestCase):
paddle.static.default_main_program().random_seed = seed paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed paddle.static.default_startup_program().random_seed = seed
@classmethod
def tearDownClass(cls):
try:
pass
# shutil.rmtree(cls.root_path)
except Exception as e:
print("Failed to delete {} due to {}".format(cls.root_path, str(e)))
def cache_unzipping(self, target_folder, zip_path): def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder): if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format( cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
...@@ -217,16 +206,18 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -217,16 +206,18 @@ class TestImperativePTQ(unittest.TestCase):
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32') paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
] ]
with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir:
save_path = os.path.join(tmpdir, "model")
self.ptq.save_quantized_model(model=quant_model, self.ptq.save_quantized_model(model=quant_model,
path=self.save_path, path=save_path,
input_spec=input_spec) input_spec=input_spec)
print('Quantized model saved in {%s}' % self.save_path) print('Quantized model saved in {%s}' % save_path)
after_acc_top1 = self.model_test(quant_model, self.batch_num, after_acc_top1 = self.model_test(quant_model, self.batch_num,
self.batch_size) self.batch_size)
paddle.enable_static() paddle.enable_static()
infer_acc_top1 = self.program_test(self.save_path, self.batch_num, infer_acc_top1 = self.program_test(save_path, self.batch_num,
self.batch_size) self.batch_size)
paddle.disable_static() paddle.disable_static()
...@@ -278,16 +269,18 @@ class TestImperativePTQfuse(TestImperativePTQ): ...@@ -278,16 +269,18 @@ class TestImperativePTQfuse(TestImperativePTQ):
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32') paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
] ]
with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir:
save_path = os.path.join(tmpdir, "model")
self.ptq.save_quantized_model(model=quant_model, self.ptq.save_quantized_model(model=quant_model,
path=self.save_path, path=save_path,
input_spec=input_spec) input_spec=input_spec)
print('Quantized model saved in {%s}' % self.save_path) print('Quantized model saved in {%s}' % save_path)
after_acc_top1 = self.model_test(quant_model, self.batch_num, after_acc_top1 = self.model_test(quant_model, self.batch_num,
self.batch_size) self.batch_size)
paddle.enable_static() paddle.enable_static()
infer_acc_top1 = self.program_test(self.save_path, self.batch_num, infer_acc_top1 = self.program_test(save_path, self.batch_num,
self.batch_size) self.batch_size)
paddle.disable_static() paddle.disable_static()
......
...@@ -21,6 +21,7 @@ import shutil ...@@ -21,6 +21,7 @@ import shutil
import time import time
import unittest import unittest
import logging import logging
import tempfile
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -46,10 +47,9 @@ class TestImperativeQatAmp(unittest.TestCase): ...@@ -46,10 +47,9 @@ class TestImperativeQatAmp(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) cls.root_path = tempfile.TemporaryDirectory(
cls.root_path = os.path.join(os.getcwd(), prefix="imperative_qat_amp_")
"imperative_qat_amp_" + timestamp) cls.save_path = os.path.join(cls.root_path.name, "model")
cls.save_path = os.path.join(cls.root_path, "model")
cls.download_path = 'dygraph_int8/download' cls.download_path = 'dygraph_int8/download'
cls.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + cls.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
...@@ -65,10 +65,7 @@ class TestImperativeQatAmp(unittest.TestCase): ...@@ -65,10 +65,7 @@ class TestImperativeQatAmp(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: cls.root_path.cleanup()
shutil.rmtree(cls.root_path)
except Exception as e:
print("Failed to delete {} due to {}".format(cls.root_path, str(e)))
def cache_unzipping(self, target_folder, zip_path): def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder): if not os.path.exists(target_folder):
......
...@@ -20,6 +20,7 @@ import math ...@@ -20,6 +20,7 @@ import math
import functools import functools
import contextlib import contextlib
import struct import struct
import tempfile
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -38,9 +39,9 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -38,9 +39,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.download_path = 'int8/download' self.download_path = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
self.download_path) self.download_path)
self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) self.root_path = tempfile.TemporaryDirectory()
self.int8_model_path = os.path.join(os.getcwd(), self.int8_model_path = os.path.join(self.root_path.name,
"post_training_" + self.timestamp) "post_training_quantization")
try: try:
os.system("mkdir -p " + self.int8_model_path) os.system("mkdir -p " + self.int8_model_path)
except Exception as e: except Exception as e:
...@@ -49,11 +50,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -49,11 +50,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
sys.exit(-1) sys.exit(-1)
def tearDown(self): def tearDown(self):
try: self.root_path.cleanup()
os.system("rm -rf {}".format(self.int8_model_path))
except Exception as e:
print("Failed to delete {} due to {}".format(
self.int8_model_path, str(e)))
def cache_unzipping(self, target_folder, zip_path): def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder): if not os.path.exists(target_folder):
......
...@@ -18,6 +18,7 @@ import sys ...@@ -18,6 +18,7 @@ import sys
import random import random
import math import math
import functools import functools
import tempfile
import contextlib import contextlib
import numpy as np import numpy as np
import paddle import paddle
...@@ -34,12 +35,12 @@ np.random.seed(0) ...@@ -34,12 +35,12 @@ np.random.seed(0)
class TestPostTrainingQuantization(unittest.TestCase): class TestPostTrainingQuantization(unittest.TestCase):
def setUp(self): def setUp(self):
self.root_path = tempfile.TemporaryDirectory()
self.int8_model_path = os.path.join(self.root_path.name,
"post_training_quantization")
self.download_path = 'int8/download' self.download_path = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
self.download_path) self.download_path)
self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
self.int8_model_path = os.path.join(os.getcwd(),
"post_training_" + self.timestamp)
try: try:
os.system("mkdir -p " + self.int8_model_path) os.system("mkdir -p " + self.int8_model_path)
except Exception as e: except Exception as e:
...@@ -48,11 +49,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -48,11 +49,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
sys.exit(-1) sys.exit(-1)
def tearDown(self): def tearDown(self):
try: self.root_path.cleanup()
os.system("rm -rf {}".format(self.int8_model_path))
except Exception as e:
print("Failed to delete {} due to {}".format(
self.int8_model_path, str(e)))
def cache_unzipping(self, target_folder, zip_path): def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder): if not os.path.exists(target_folder):
...@@ -123,7 +120,6 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -123,7 +120,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
scope = fluid.global_scope()
val_reader = paddle.dataset.mnist.train() val_reader = paddle.dataset.mnist.train()
ptq = PostTrainingQuantization(executor=exe, ptq = PostTrainingQuantization(executor=exe,
......
...@@ -19,6 +19,7 @@ import random ...@@ -19,6 +19,7 @@ import random
import math import math
import functools import functools
import contextlib import contextlib
import tempfile
import numpy as np import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
import paddle import paddle
...@@ -150,16 +151,12 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -150,16 +151,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.infer_iterations = 50000 if os.environ.get( self.infer_iterations = 50000 if os.environ.get(
'DATASET') == 'full' else 2 'DATASET') == 'full' else 2
self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) self.root_path = tempfile.TemporaryDirectory()
self.int8_model = os.path.join(os.getcwd(), self.int8_model = os.path.join(self.root_path.name,
"post_training_" + self.timestamp) "post_training_quantization")
def tearDown(self): def tearDown(self):
try: self.root_path.cleanup()
os.system("rm -rf {}".format(self.int8_model))
except Exception as e:
print("Failed to delete {} due to {}".format(
self.int8_model, str(e)))
def cache_unzipping(self, target_folder, zip_path): def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder): if not os.path.exists(target_folder):
......
...@@ -17,6 +17,7 @@ import unittest ...@@ -17,6 +17,7 @@ import unittest
import random import random
import numpy as np import numpy as np
import six import six
import tempfile
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
...@@ -166,15 +167,19 @@ class TestQuantizationScalePass(unittest.TestCase): ...@@ -166,15 +167,19 @@ class TestQuantizationScalePass(unittest.TestCase):
marked_nodes.add(op) marked_nodes.add(op)
test_graph.draw('.', 'quant_scale' + dev_name, marked_nodes) test_graph.draw('.', 'quant_scale' + dev_name, marked_nodes)
with open('quant_scale_model' + dev_name + '.txt', 'w') as f: tempdir = tempfile.TemporaryDirectory()
mapping_table_path = os.path.join(
tempdir.name, 'quant_scale_model' + dev_name + '.txt')
save_path = os.path.join(tempdir.name, 'quant_scale_model' + dev_name)
with open(mapping_table_path, 'w') as f:
f.write(str(server_program)) f.write(str(server_program))
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
fluid.io.save_inference_model('quant_scale_model' + dev_name, fluid.io.save_inference_model(save_path, ['image', 'label'], [loss],
['image', 'label'], [loss],
exe, exe,
server_program, server_program,
clip_extra=True) clip_extra=True)
tempdir.cleanup()
def test_quant_scale_cuda(self): def test_quant_scale_cuda(self):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
......
...@@ -18,6 +18,7 @@ import json ...@@ -18,6 +18,7 @@ import json
import random import random
import numpy as np import numpy as np
import six import six
import tempfile
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
...@@ -110,18 +111,20 @@ class TestUserDefinedQuantization(unittest.TestCase): ...@@ -110,18 +111,20 @@ class TestUserDefinedQuantization(unittest.TestCase):
def get_optimizer(): def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9) return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
def load_dict(): def load_dict(mapping_table_path):
with open('mapping_table_for_saving_inference_model', 'r') as file: with open(mapping_table_path, 'r') as file:
data = file.read() data = file.read()
data = json.loads(data) data = json.loads(data)
return data return data
def save_dict(Dict): def save_dict(Dict, mapping_table_path):
with open('mapping_table_for_saving_inference_model', 'w') as file: with open(mapping_table_path, 'w') as file:
file.write(json.dumps(Dict)) file.write(json.dumps(Dict))
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
tempdir = tempfile.TemporaryDirectory()
mapping_table_path = os.path.join(tempdir.name, 'inference')
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
...@@ -162,7 +165,7 @@ class TestUserDefinedQuantization(unittest.TestCase): ...@@ -162,7 +165,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
executor=exe) executor=exe)
test_transform_pass.apply(test_graph) test_transform_pass.apply(test_graph)
save_dict(test_graph.out_node_mapping_table) save_dict(test_graph.out_node_mapping_table, mapping_table_path)
add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place) add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
add_quant_dequant_pass.apply(main_graph) add_quant_dequant_pass.apply(main_graph)
...@@ -203,10 +206,11 @@ class TestUserDefinedQuantization(unittest.TestCase): ...@@ -203,10 +206,11 @@ class TestUserDefinedQuantization(unittest.TestCase):
activation_bits=8, activation_bits=8,
weight_quantize_type=weight_quant_type) weight_quantize_type=weight_quant_type)
mapping_table = load_dict() mapping_table = load_dict(mapping_table_path)
test_graph.out_node_mapping_table = mapping_table test_graph.out_node_mapping_table = mapping_table
if act_quantize_func == None and weight_quantize_func == None: if act_quantize_func == None and weight_quantize_func == None:
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
tempdir.cleanup()
def test_act_preprocess_cuda(self): def test_act_preprocess_cuda(self):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
......
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
import os import os
import six import six
import tempfile
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -46,8 +47,9 @@ class InferModel(object): ...@@ -46,8 +47,9 @@ class InferModel(object):
class TestBook(unittest.TestCase): class TestBook(unittest.TestCase):
def test_fit_line_inference_model(self): def test_fit_line_inference_model(self):
MODEL_DIR = "./tmp/inference_model" root_path = tempfile.TemporaryDirectory()
UNI_MODEL_DIR = "./tmp/inference_model1" MODEL_DIR = os.path.join(root_path.name, "inference_model")
UNI_MODEL_DIR = os.path.join(root_path.name, "inference_model1")
init_program = Program() init_program = Program()
program = Program() program = Program()
...@@ -118,6 +120,8 @@ class TestBook(unittest.TestCase): ...@@ -118,6 +120,8 @@ class TestBook(unittest.TestCase):
print("fetch %s" % str(model.fetch_vars[0])) print("fetch %s" % str(model.fetch_vars[0]))
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
root_path.cleanup()
self.assertRaises(ValueError, fluid.io.load_inference_model, None, exe, self.assertRaises(ValueError, fluid.io.load_inference_model, None, exe,
model_str, None) model_str, None)
...@@ -125,7 +129,8 @@ class TestBook(unittest.TestCase): ...@@ -125,7 +129,8 @@ class TestBook(unittest.TestCase):
class TestSaveInferenceModel(unittest.TestCase): class TestSaveInferenceModel(unittest.TestCase):
def test_save_inference_model(self): def test_save_inference_model(self):
MODEL_DIR = "./tmp/inference_model2" root_path = tempfile.TemporaryDirectory()
MODEL_DIR = os.path.join(root_path.name, "inference_model2")
init_program = Program() init_program = Program()
program = Program() program = Program()
...@@ -144,9 +149,11 @@ class TestSaveInferenceModel(unittest.TestCase): ...@@ -144,9 +149,11 @@ class TestSaveInferenceModel(unittest.TestCase):
exe.run(init_program, feed={}, fetch_list=[]) exe.run(init_program, feed={}, fetch_list=[])
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program) save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)
root_path.cleanup()
def test_save_inference_model_with_auc(self): def test_save_inference_model_with_auc(self):
MODEL_DIR = "./tmp/inference_model4" root_path = tempfile.TemporaryDirectory()
MODEL_DIR = os.path.join(root_path.name, "inference_model4")
init_program = Program() init_program = Program()
program = Program() program = Program()
...@@ -168,6 +175,7 @@ class TestSaveInferenceModel(unittest.TestCase): ...@@ -168,6 +175,7 @@ class TestSaveInferenceModel(unittest.TestCase):
warnings.simplefilter("always") warnings.simplefilter("always")
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe,
program) program)
root_path.cleanup()
expected_warn = "please ensure that you have set the auc states to zeros before saving inference model" expected_warn = "please ensure that you have set the auc states to zeros before saving inference model"
self.assertTrue(len(w) > 0) self.assertTrue(len(w) > 0)
self.assertTrue(expected_warn == str(w[0].message)) self.assertTrue(expected_warn == str(w[0].message))
...@@ -176,7 +184,8 @@ class TestSaveInferenceModel(unittest.TestCase): ...@@ -176,7 +184,8 @@ class TestSaveInferenceModel(unittest.TestCase):
class TestInstance(unittest.TestCase): class TestInstance(unittest.TestCase):
def test_save_inference_model(self): def test_save_inference_model(self):
MODEL_DIR = "./tmp/inference_model3" root_path = tempfile.TemporaryDirectory()
MODEL_DIR = os.path.join(root_path.name, "inference_model3")
init_program = Program() init_program = Program()
program = Program() program = Program()
...@@ -202,12 +211,14 @@ class TestInstance(unittest.TestCase): ...@@ -202,12 +211,14 @@ class TestInstance(unittest.TestCase):
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, cp_prog) save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, cp_prog)
self.assertRaises(TypeError, save_inference_model, self.assertRaises(TypeError, save_inference_model,
[MODEL_DIR, ["x", "y"], [avg_cost], [], cp_prog]) [MODEL_DIR, ["x", "y"], [avg_cost], [], cp_prog])
root_path.cleanup()
class TestSaveInferenceModelNew(unittest.TestCase): class TestSaveInferenceModelNew(unittest.TestCase):
def test_save_and_load_inference_model(self): def test_save_and_load_inference_model(self):
MODEL_DIR = "./tmp/inference_model5" root_path = tempfile.TemporaryDirectory()
MODEL_DIR = os.path.join(root_path.name, "inference_model5")
init_program = fluid.default_startup_program() init_program = fluid.default_startup_program()
program = fluid.default_main_program() program = fluid.default_main_program()
...@@ -303,6 +314,7 @@ class TestSaveInferenceModelNew(unittest.TestCase): ...@@ -303,6 +314,7 @@ class TestSaveInferenceModelNew(unittest.TestCase):
model = InferModel(paddle.static.io.load_inference_model( model = InferModel(paddle.static.io.load_inference_model(
MODEL_DIR, exe)) MODEL_DIR, exe))
root_path.cleanup()
outs = exe.run(model.program, outs = exe.run(model.program,
feed={ feed={
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import os import os
import unittest import unittest
import numpy as np import numpy as np
import tempfile
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -93,54 +94,60 @@ class TestConditionalOp(unittest.TestCase): ...@@ -93,54 +94,60 @@ class TestConditionalOp(unittest.TestCase):
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[1, 3, 8, 8], dtype='float32') shape=[1, 3, 8, 8], dtype='float32')
]) ])
paddle.jit.save(net, './while_net') root_path = tempfile.TemporaryDirectory()
model_file = os.path.join(root_path.name, "while_net")
paddle.jit.save(net, model_file)
right_pdmodel = set([ right_pdmodel = set([
"uniform_random", "shape", "slice", "not_equal", "while", "uniform_random", "shape", "slice", "not_equal", "while",
"elementwise_add" "elementwise_add"
]) ])
paddle.enable_static() paddle.enable_static()
pdmodel = getModelOp("while_net.pdmodel") pdmodel = getModelOp(model_file + ".pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue( self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0, len(right_pdmodel.difference(pdmodel)) == 0,
"The while op is pruned by mistake.") "The while op is pruned by mistake.")
root_path.cleanup()
def test_for_op(self): def test_for_op(self):
paddle.disable_static() paddle.disable_static()
net = ForNet() net = ForNet()
net = paddle.jit.to_static( net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[1], dtype='int32')]) net, input_spec=[paddle.static.InputSpec(shape=[1], dtype='int32')])
paddle.jit.save(net, './for_net') root_path = tempfile.TemporaryDirectory()
model_file = os.path.join(root_path.name, "for_net")
paddle.jit.save(net, model_file)
right_pdmodel = set([ right_pdmodel = set([
"randint", "fill_constant", "cast", "less_than", "while", "randint", "fill_constant", "cast", "less_than", "while",
"elementwise_add" "elementwise_add"
]) ])
paddle.enable_static() paddle.enable_static()
pdmodel = getModelOp("for_net.pdmodel") pdmodel = getModelOp(model_file + ".pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue( self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0, len(right_pdmodel.difference(pdmodel)) == 0,
"The for op is pruned by mistake.") "The for op is pruned by mistake.")
root_path.cleanup()
def test_if_op(self): def test_if_op(self):
paddle.disable_static() paddle.disable_static()
net = IfElseNet() net = IfElseNet()
net = paddle.jit.to_static( net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[1], dtype='int32')]) net, input_spec=[paddle.static.InputSpec(shape=[1], dtype='int32')])
paddle.jit.save(net, './if_net') root_path = tempfile.TemporaryDirectory()
model_file = os.path.join(root_path.name, "if_net")
paddle.jit.save(net, model_file)
right_pdmodel = set([ right_pdmodel = set([
"assign_value", "greater_than", "cast", "conditional_block", "assign_value", "greater_than", "cast", "conditional_block",
"logical_not", "select_input" "logical_not", "select_input"
]) ])
paddle.enable_static() paddle.enable_static()
pdmodel = getModelOp("if_net.pdmodel") pdmodel = getModelOp(model_file + ".pdmodel")
#print(len(right_pdmodel.difference(pdmodel)))
self.assertTrue( self.assertTrue(
len(right_pdmodel.difference(pdmodel)) == 0, len(right_pdmodel.difference(pdmodel)) == 0,
"The if op is pruned by mistake.") "The if op is pruned by mistake.")
root_path.cleanup()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册