PaddlePaddle / PaddleSlim · Issue #107
Opened Feb 12, 2020 by saxon_zh (Guest)

How should I understand the difference between greedy_prune and the ordinary prune in sensitivity-based pruning?

Created by: zhizunbao-y

1. As asked in the title.
2. In the code below, what is the purpose of the section marked off with `#`? Why multiply by 2?

```python
def flops_sensitivity(program,
                      place,
                      param_names,
                      eval_func,
                      sensitivities_file=None,
                      pruned_flops_rate=0.1):

    assert (1.0 / len(param_names) > pruned_flops_rate)

    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = load_sensitivities(sensitivities_file)

    for name in param_names:
        if name not in sensitivities:
            sensitivities[name] = {}
    base_flops = flops(program)
    target_pruned_flops = base_flops * pruned_flops_rate

    pruner = Pruner()
    baseline = None
    for name in sensitivities:

        pruned_program, _, _ = pruner.prune(
            program=graph.program,
            scope=None,
            params=[name],
            ratios=[0.5],
            place=None,
            lazy=False,
            only_graph=True)
        ################################################
        param_flops = (base_flops - flops(pruned_program)) * 2
        channel_size = graph.var(name).shape()[0]
        pruned_ratio = target_pruned_flops / float(param_flops)
        ################################################
        pruned_ratio = round(pruned_ratio, 3)
        pruned_size = round(pruned_ratio * channel_size)
        pruned_ratio = 1 if pruned_size >= channel_size else pruned_ratio

        if len(sensitivities[name].keys()) > 0:
            _logger.debug(
                '{} exist; pruned ratio: {}; excepted ratio: {}'.format(
                    name, sensitivities[name].keys(), pruned_ratio))
            continue
        if baseline is None:
            baseline = eval_func(graph.program)
        param_backup = {}
        pruner = Pruner()
        _logger.info("sensitive - param: {}; ratios: {}".format(name,
                                                                pruned_ratio))
        loss = 1
        if pruned_ratio < 1:
            pruned_program = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[pruned_ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=param_backup)
            pruned_metric = eval_func(pruned_program)
            loss = (baseline - pruned_metric) / baseline
        _logger.info("pruned param: {}; {}; loss={}".format(name, pruned_ratio,
                                                            loss))
        sensitivities[name][pruned_ratio] = loss
        _save_sensitivities(sensitivities, sensitivities_file)

        # restore pruned parameters
        for param_name in param_backup.keys():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(param_backup[param_name], place)
    return sensitivities
```
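A rough reading of the marked section, with made-up numbers (an illustration, not the library's code): a trial prune of 50% of a parameter's output channels with `only_graph=True` removes roughly half of the FLOPs attributable to that parameter, so doubling the measured drop estimates the parameter's total FLOPs; the target ratio then spreads the global FLOPs budget over that single layer.

```python
def estimate_param_flops(base_flops, flops_after_half_prune):
    """Estimate the FLOPs attributable to one parameter from a 50% trial prune:
    the trial removed ~half of them, so the total is twice the measured drop."""
    return (base_flops - flops_after_half_prune) * 2

# Hypothetical numbers: the whole model costs 1000 FLOPs; after pruning
# this layer's channels by 0.5, the model costs 800 FLOPs.
base_flops = 1000
flops_after_half_prune = 800

# The 0.5 prune removed 200 FLOPs, so this parameter accounts for ~400.
param_flops = estimate_param_flops(base_flops, flops_after_half_prune)

# To cut 10% of total FLOPs (100) using only this layer, prune
# 100 / 400 = 25% of its channels.
target_pruned_flops = base_flops * 0.1
pruned_ratio = target_pruned_flops / float(param_flops)
```

Under this reading, the `* 2` is there because only half of the parameter's channels were removed in the trial prune, not because of a multiply-accumulate convention.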

3. In the code below, min_loss and max_loss are both initialized to 0, so the while loop body never executes.

```python
def get_ratios_by_sensitive(self, sensitivities, pruned_flops,
                            eval_program):
    """
    Search a group of ratios for pruning target flops.
    Args:
      sensitivities(dict): The sensitivities used to generate a group of pruning ratios. The key of dict
                           is name of parameters to be pruned. The value of dict is a list of tuple with
                           format `(pruned_ratio, accuracy_loss)`.
      pruned_flops(float): The percent of FLOPS to be pruned.
      eval_program(Program): The program whose FLOPS is considered.
    Returns:
      dict: A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
    """

    min_loss = 0.
    max_loss = 0.
    # step 2: Find a group of ratios by binary searching.
    base_flops = flops(eval_program)
    ratios = None
    max_times = 20
    while min_loss < max_loss and max_times > 0:
        loss = (max_loss + min_loss) / 2
        _logger.info(
            '-----------Try pruned ratios while acc loss={}-----------'.
            format(loss))
        ratios = self.get_ratios_by_loss(sensitivities, loss)
        _logger.info('Pruned ratios={}'.format(
            [round(ratio, 3) for ratio in ratios.values()]))
        pruned_program = self._pruner.prune(
            eval_program,
            None,  # scope
            ratios.keys(),
            ratios.values(),
            None,  # place
            only_graph=True)
        pruned_ratio = 1 - (float(flops(pruned_program)) / base_flops)
        _logger.info('Pruned flops: {:.4f}'.format(pruned_ratio))

        # Check whether current ratios is enough
        if abs(pruned_ratio - pruned_flops) < 0.015:
            break
        if pruned_ratio > pruned_flops:
            max_loss = loss
        else:
            min_loss = loss
        max_times -= 1
    return ratios
```
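One possible fix for point 3 (a sketch, not the library's code): seed the binary-search bounds from the measured sensitivities instead of 0/0, e.g. set max_loss to the largest recorded accuracy loss, so that `min_loss < max_loss` holds on entry and the search has a meaningful upper bound.

```python
def init_loss_bounds(sensitivities):
    """Derive binary-search bounds for acceptable accuracy loss.

    sensitivities: {param_name: {pruned_ratio: accuracy_loss}},
    matching the structure built by flops_sensitivity above.
    """
    losses = [loss
              for ratio_to_loss in sensitivities.values()
              for loss in ratio_to_loss.values()]
    min_loss = 0.0
    # Largest observed loss; 0.0 if nothing was measured yet.
    max_loss = max(losses) if losses else 0.0
    return min_loss, max_loss

# Hypothetical sensitivities for two conv weights.
sens = {"conv1_weights": {0.25: 0.01, 0.5: 0.08},
        "conv2_weights": {0.25: 0.02}}
lo, hi = init_loss_bounds(sens)  # lo = 0.0, hi = 0.08
```

With these bounds the loop in `get_ratios_by_sensitive` would bisect between 0 and the worst measured loss rather than exiting immediately.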
Reference: paddlepaddle/PaddleSlim#107