提交 09954701 编写于 作者: W wanghaoshuang

Add save checkpoint function to sensitive pruner.

上级 0e2b3039
......@@ -35,6 +35,7 @@ add_arg('config_file', str, None, "The config file for comp
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('checkpoints', str, "./checkpoints", "Checkpoints path.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
......@@ -188,17 +189,25 @@ def compress(args):
def eval_func(program):
return test(0, program)
pruner = SensitivePruner(place, eval_func)
if args.data == "mnist":
train(0, fluid.default_main_program())
pruner = SensitivePruner(place, eval_func, checkpoints=args.checkpoints)
pruned_program, pruned_val_program, iter = pruner.restore()
if pruned_program is None:
pruned_program = fluid.default_main_program()
if pruned_val_program is None:
pruned_val_program = val_program
for iter in range(6):
start = iter
end = 6
for iter in range(start, end):
pruned_program, pruned_val_program = pruner.prune(
pruned_program, pruned_val_program, params, 0.1)
train(iter, pruned_program)
test(iter, pruned_val_program)
pruner.save_checkpoint(pruned_program, pruned_val_program)
print("before flops: {}".format(flops(fluid.default_main_program())))
print("after flops: {}".format(flops(pruned_val_program)))
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import copy
from scipy.optimize import leastsq
......@@ -28,7 +29,7 @@ _logger = get_logger(__name__, level=logging.INFO)
class SensitivePruner(object):
def __init__(self, place, eval_func, scope=None):
def __init__(self, place, eval_func, scope=None, checkpoints=None):
"""
Pruner used to prune parameters iteratively according to sensitivities of parameters in each step.
Args:
......@@ -41,6 +42,50 @@ class SensitivePruner(object):
self._place = place
self._scope = fluid.global_scope() if scope is None else scope
self._pruner = Pruner()
self._checkpoints = checkpoints
def save_checkpoint(self, train_program, eval_program):
checkpoint = os.path.join(self._checkpoints, str(self._iter - 1))
exe = fluid.Executor(self._place)
fluid.io.save_persistables(
exe, checkpoint, main_program=train_program, filename="__params__")
with open(checkpoint + "/main_program", "wb") as f:
f.write(train_program.desc.serialize_to_string())
with open(checkpoint + "/eval_program", "wb") as f:
f.write(eval_program.desc.serialize_to_string())
def restore(self, checkpoints=None):
exe = fluid.Executor(self._place)
checkpoints = self._checkpoints if checkpoints is None else checkpoints
print("check points: {}".format(checkpoints))
main_program = None
eval_program = None
if checkpoints is not None:
cks = [dir for dir in os.listdir(checkpoints)]
if len(cks) > 0:
latest = max([int(ck) for ck in cks])
latest_ck_path = os.path.join(checkpoints, str(latest))
self._iter += 1
with open(latest_ck_path + "/main_program", "rb") as f:
program_desc_str = f.read()
main_program = fluid.Program.parse_from_string(
program_desc_str)
print main_program
with open(latest_ck_path + "/eval_program", "rb") as f:
program_desc_str = f.read()
eval_program = fluid.Program.parse_from_string(
program_desc_str)
with fluid.scope_guard(self._scope):
fluid.io.load_persistables(exe, latest_ck_path,
main_program, "__params__")
print("load checkpoint from: {}".format(latest_ck_path))
print("flops of eval program: {}".format(flops(eval_program)))
return main_program, eval_program, self._iter
def prune(self, train_program, eval_program, params, pruned_flops):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册