Commit e1e5f5d1 authored by wanghaoshuang

Support for dynamic graph.

Parent bb0f8fbb
......@@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .graph_wrapper import GraphWrapper, VarWrapper, OpWrapper
from .graph_wrapper import GraphWrapper
from .registry import Registry
__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry']
__all__ = ['GraphWrapper', 'Registry']
try:
from .dy_graph import DyGraph
__all__ += ['DyGraph']
except Exception as e:
pass
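The guarded import above makes DyGraph an optional export: it is only added to __all__ when the torch-backed dy_graph module imports cleanly. A quick check of that behaviour, assuming the package is installed, might look like this (illustrative, not part of this commit):
import paddleslim.core as core
# "DyGraph" appears in the public API only when the optional torch dependency
# (imported inside dy_graph.py) is available; otherwise the except branch hides it.
print("DyGraph" in core.__all__)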
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import pickle
import numpy as np
from collections import OrderedDict
from collections.abc import Iterable
import torch
__all__ = ['DyGraph', 'VarWrapper', 'OpWrapper']
class VarWrapper(object):
def __init__(self, id, is_parameter=False, tensor=None):
self._id = id
self._inputs = []
self._outputs = []
self._is_parameter = is_parameter
self._tensor = tensor
def __eq__(self, v):
"""
        Override this method to support the "in" membership syntax in Python.
"""
return self._id == v._id
def name(self):
"""
Get the name of the variable.
"""
return self._id
def __repr__(self):
return "id: {};".format(self._id)
def shape(self):
"""
        Get the shape of the variable.
"""
return self._tensor.shape
def set_shape(self, shape):
"""
Set the shape of the variable.
"""
assert ("Unimplement")
def inputs(self):
"""
Get all the operators that use this variable as output.
Returns:
list<OpWrapper>: A list of operators.
"""
return self._inputs
def outputs(self):
"""
Get all the operators that use this variable as input.
Returns:
list<OpWrapper>: A list of operators.
"""
return self._outputs
def is_parameter(self):
return self._is_parameter
class OpWrapper(object):
def __init__(self, id, name):
self._id = id
self.name = name
self.module = None
self._inputs = []
self._outputs = []
def __eq__(self, op):
"""
        Override this method to support the "in" membership syntax in Python.
"""
return self.id() == op.id()
def all_inputs(self):
"""
Get all the input variables of this operator.
"""
return self._inputs
def all_outputs(self):
"""
Get all the output variables of this operator.
"""
return self._outputs
def id(self):
"""
Get the id of this operator.
"""
return self._id
def type(self):
"""
Get the type of this operator.
"""
if self.module is not None:
return self.module.__class__.__name__
else:
if self.name.startswith("aten::"):
return self.name.split(":")[-1]
def __repr__(self):
return "op[id: {}, type: {}; inputs: {}]".format(self.id(),
self.type(),
self.all_inputs())
def is_bwd_op(self):
"""
Whether this operator is backward op.
"""
return False
def is_opt_op(self):
"""
Whether this operator is optimizer op.
"""
return False
def inputs(self, name):
"""
        Get all the variables by the given input name.
"""
return [self._graph.var(var_name) for var_name in self._op.input(name)]
def outputs(self, name):
"""
        Get all the variables by the given output name.
"""
return [
self._graph.var(var_name) for var_name in self._op.output(name)
]
def set_attr(self, key, value):
"""
Set the value of attribute by attribute's name.
Args:
key(str): the attribute name.
value(bool|int|str|float|list): the value of the attribute.
"""
self._op._set_attr(key, value)
def attr(self, name):
"""
Get the attribute by name.
Args:
name(str): the attribute name.
Returns:
bool|int|str|float|list: The attribute value. The return value
can be any valid attribute type.
"""
        print(dir(self.module))
return self._op.attr(name)
class DyGraph(object):
"""
    A wrapper of the graph traced from a torch.nn.Module by torch.jit.trace,
    with some helper functions for the PaddleSlim framework.
    Args:
        module(torch.nn.Module): The dynamic-graph module to be wrapped.
        input_shape(list|tuple): The shape of the dummy input used to trace `module`.
"""
def __init__(self, module, input_shape):
"""
"""
super(DyGraph, self).__init__()
self.module = module
self._graph = torch.jit.trace(self.module,
torch.rand(input_shape)).graph
        print(self._graph)
self.children = {}
for name, child in self.module.named_children():
self.children[name] = child
self.id2child = {}
for node in self._graph.nodes():
if "prim::GetAttr" == node.kind() and "self.1" == node.inputsAt(
0).debugName():
# print dir(node)
self.id2child[node.output().debugName()] = node["name"]
        print(self.id2child)
self.vars = {}
self.nodes = {}
for node in self._graph.nodes():
if "prim::CallMethod" == node.kind() and "forward" == node["name"]:
module_id = node.inputsAt(0).debugName()
node_id = node.output().debugName() + "-" + module_id
in_var_id = node.inputsAt(1).debugName()
out_var_id = node.output().debugName()
if node_id not in self.nodes:
self.nodes[node_id] = OpWrapper(node_id,
self.id2child[module_id])
self.nodes[node_id].module = self.children[self.id2child[
module_id]]
for param_id, param in self.nodes[
node_id].module.named_parameters():
param_id = ".".join([self.id2child[module_id], param_id])
if param_id not in self.vars:
self.vars[param_id] = VarWrapper(
param_id, is_parameter=True, tensor=param)
self.nodes[node_id].all_inputs().append(self.vars[
param_id])
self.vars[param_id].outputs().append(self.nodes[
node_id])
if in_var_id not in self.vars:
self.vars[in_var_id] = VarWrapper(in_var_id)
if out_var_id not in self.vars:
self.vars[out_var_id] = VarWrapper(out_var_id)
self.nodes[node_id].all_inputs().append(self.vars[in_var_id])
self.nodes[node_id].all_outputs().append(self.vars[out_var_id])
self.vars[in_var_id].outputs().append(self.nodes[node_id])
self.vars[out_var_id].inputs().append(self.nodes[node_id])
elif node.kind().startswith("aten::"):
# print dir(node)
node_id = node.output().debugName() + "-" + node.kind()
# node_id = node.debugName()
if node_id not in self.nodes:
self.nodes[node_id] = OpWrapper(node_id, node.kind())
# self.nodes[node_id].type = node.kind()
for input in node.inputs():
in_var_id = input.debugName()
if in_var_id not in self.vars:
self.vars[in_var_id] = VarWrapper(in_var_id)
self.vars[in_var_id].outputs().append(self.nodes[node_id])
self.nodes[node_id].all_inputs().append(self.vars[
in_var_id])
for output in node.outputs():
out_var_id = output.debugName()
if out_var_id not in self.vars:
self.vars[out_var_id] = VarWrapper(out_var_id)
self.vars[out_var_id].inputs().append(self.nodes[node_id])
self.nodes[node_id].all_outputs().append(self.vars[
out_var_id])
def all_parameters(self):
"""
Get all the parameters in this graph.
Returns:
list<VarWrapper>: A list of VarWrapper instances.
"""
params = []
for var in self.vars.values():
if var.is_parameter():
params.append(var)
return params
def is_parameter(self, var):
"""
        Whether the given variable is a parameter.
        Args:
            var(VarWrapper): The given variable.
"""
return var.is_parameter()
def ops(self):
"""
        Return all the operator nodes included in the graph.
"""
return self.nodes.values()
def vars(self):
"""
Get all the variables.
"""
return self.vars.values()
def var(self, name):
"""
Get the variable by variable name.
"""
return self.vars[name]
def clone(self, for_test=False):
"""
Clone a new graph from current graph.
Returns:
(DyGraph): The wrapper of a new graph.
"""
return DyGraph(
self.program.clone(for_test),
copy.deepcopy(self.in_nodes), copy.deepcopy(self.out_nodes))
def program(self):
"""
Get the program in current wrapper.
"""
return self.program
def pre_ops(self, op):
"""
Get all the previous operators of target operator.
Args:
op(OpWrapper): Target operator.
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for p in self.ops():
for in_var in op.all_inputs():
if in_var in p.all_outputs():
ops.append(p)
return ops
def next_ops(self, op):
"""
Get all the next operators of target operator.
Args:
op(OpWrapper): Target operator.
Returns:
list<OpWrapper>: A list of operators.
"""
ops = []
for p in self.ops():
for out_var in op.all_outputs():
if out_var in p.all_inputs():
ops.append(p)
return ops
def get_param_by_op(self, op):
"""
Get the parameters used by target operator.
"""
assert isinstance(op, OpWrapper)
params = []
for var in op.all_inputs():
            if var.is_parameter():
params.append(var)
assert len(params) > 0
return params
def numel_params(self):
"""
Get the number of elements in all parameters.
"""
ret = 0
for param in self.all_parameters():
            ret += np.prod(param.shape())
return ret
def update_param_shape(self, scope):
"""
Update the shape of parameters in the graph according to tensors in scope.
It is used after loading pruned parameters from file.
"""
for param in self.all_parameters():
tensor_shape = np.array(
scope.find_var(param.name()).get_tensor()).shape
param.set_shape(tensor_shape)
def infer_shape(self):
"""
        Re-infer the shapes of the variables in the graph from their operators.
        It is used after loading pruned parameters from file.
"""
for op in self.ops():
if op.type() != 'conditional_block':
op._op.desc.infer_shape(op._op.block.desc)
def update_groups_of_conv(self):
for op in self.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
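A minimal usage sketch of the DyGraph wrapper defined above, assuming torch is installed; the TinyNet module and the input shape are illustrative, not part of this commit:
import torch
import torch.nn as nn

from paddleslim.core import DyGraph


class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


# Trace the module with a dummy input of the given shape, then walk the wrapped graph.
graph = DyGraph(TinyNet(), [1, 3, 32, 32])
for op in graph.ops():
    print(op.type(), [v.name() for v in op.all_inputs()])
print("parameter elements:", graph.numel_params())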
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from paddleslim.core import Registry
from paddleslim.common import get_logger
__all__ = ["PRUNE_WORKER", "Conv2d"]
_logger = get_logger(__name__, level=logging.INFO)
PRUNE_WORKER = Registry('prune_worker')
class PruneWorker(object):
def __init__(self, op, pruned_params=[], visited={}):
"""
        A wrapper of an operator, used to infer the pruning information of all the related variables.
        Args:
            op(Operator): The operator to be pruned.
            pruned_params(list): The list used to store the pruning information inferred by this worker.
            visited(dict): The auxiliary dict used to record visited operators and variables. The key is an encoded string of the operator id and the variable name.
        Returns: An instance of PruneWorker.
"""
self.op = op
self.pruned_params = pruned_params
self.visited = visited
def prune(self, var, pruned_axis, pruned_idx):
"""
        Infer the shapes of the variables related to the current operator, its predecessors and its successors.
        It searches the graph for all variables related to `var` and records the pruning information.
        Args:
            var(Variable): The root variable of the search. It can be an input or an output of the current operator.
            pruned_axis(int): The axis of the root variable to be pruned.
            pruned_idx(list): The indexes to be pruned along `pruned_axis` of the root variable.
"""
if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, pruned_idx)
def _visit(self, var, pruned_axis):
key = "_".join([str(self.op.id()), var.name()])
if pruned_axis not in self.visited:
self.visited[pruned_axis] = {}
if key in self.visited[pruned_axis]:
return False
else:
self.visited[pruned_axis][key] = True
return True
def _prune(self, var, pruned_axis, pruned_idx):
raise NotImplementedError('Abstract method.')
def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
if op.type().endswith("_grad"):
return
if visited is not None:
self.visited = visited
cls = PRUNE_WORKER.get(op.type())
assert cls is not None, "The walker of {} is not registered.".format(
op.type())
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name()))
walker = cls(op,
pruned_params=self.pruned_params,
visited=self.visited)
walker.prune(var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class Conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(Conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
channel_axis = 1
        print(self.op.all_inputs())
if var == self.op.all_inputs()[-1]: # input
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.all_inputs()[0]
self._visit(filter_var, 1)
self.pruned_params.append((filter_var, 1, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
elif var == self.op.all_inputs()[0]: # filter
assert pruned_axis in [0, 1]
self.pruned_params.append((var, pruned_axis, pruned_idx))
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
if pruned_axis == 0:
if len(self.op.all_inputs()) > 2: # has bias
self.pruned_params.append(
(self.op.all_inputs()[1], channel_axis, pruned_idx))
output_var = self.op.all_outputs()[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif pruned_axis == 1:
input_var = self.op.all_inputs()[-1]
self._visit(input_var, channel_axis)
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
elif var in self.op.all_outputs():
assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.all_inputs()[0]
self._visit(filter_var, 0)
self.pruned_params.append((filter_var, 0, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.all_inputs()) > 2:
self.pruned_params.append(
(self.op.all_inputs()[1], channel_axis, pruned_idx))
output_var = self.op.all_outputs()[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
class batch_norm(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(batch_norm, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if (var not in self.op.outputs("Y")) and (
var not in self.op.inputs("X")):
return
if var in self.op.outputs("Y"):
in_var = self.op.inputs("X")[0]
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
for param in ["Scale", "Bias", "Mean", "Variance"]:
param_var = self.op.inputs(param)[0]
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self.pruned_params.append((param_var, 0, pruned_idx))
out_var = self.op.outputs("Y")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
class elementwise_op(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(elementwise_op, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
axis = self.op.attr("axis")
if axis == -1: # TODO
axis = 0
if var in self.op.outputs("Out"):
for name in ["X", "Y"]:
actual_axis = pruned_axis
if name == "Y":
actual_axis = pruned_axis - axis
in_var = self.op.inputs(name)[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, actual_axis, pruned_idx)
else:
if var in self.op.inputs("X"):
in_var = self.op.inputs("Y")[0]
if in_var.is_parameter():
self.pruned_params.append(
(in_var, pruned_axis - axis, pruned_idx))
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis - axis, pruned_idx)
elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0]
pre_ops = in_var.inputs()
pruned_axis = pruned_axis + axis
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class elementwise_add(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_add, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class elementwise_sub(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_sub, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class elementwise_mul(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_mul, self).__init__(op, pruned_params, visited)
class activation(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(activation, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.all_outputs():
in_var = self.op.all_inputs()[0]
for op in in_var.inputs():
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.all_outputs()[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class uniform_random_batch_size_like(activation):
def __init__(self, op, pruned_params, visited):
super(uniform_random_batch_size_like, self).__init__(op, pruned_params,
visited)
self.input_name = "Input"
self.output_name = "Out"
@PRUNE_WORKER.register
class bilinear_interp(activation):
def __init__(self, op, pruned_params, visited):
super(bilinear_interp, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class nearest_interp(activation):
def __init__(self, op, pruned_params, visited):
super(nearest_interp, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class relu(activation):
def __init__(self, op, pruned_params, visited):
super(relu, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class leaky_relu(activation):
def __init__(self, op, pruned_params, visited):
super(leaky_relu, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class floor(activation):
def __init__(self, op, pruned_params, visited):
super(floor, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class relu6(activation):
def __init__(self, op, pruned_params, visited):
super(relu6, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class MaxPool2d(activation):
def __init__(self, op, pruned_params, visited):
super(MaxPool2d, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class sum(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(sum, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
for in_var in self.op.inputs("X"):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
elif var in self.op.inputs("X"):
for in_var in self.op.inputs("X"):
if in_var != var:
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class concat(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(concat, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
idx = []
axis = self.op.attr("axis")
if var in self.op.outputs("Out"):
start = 0
if axis == pruned_axis:
for _, in_var in enumerate(self.op.inputs("X")):
idx = []
for i in pruned_idx:
r_idx = i - start
if r_idx < in_var.shape()[pruned_axis] and r_idx >= 0:
idx.append(r_idx)
start += in_var.shape()[pruned_axis]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, idx)
idx = pruned_idx[:]
else:
for _, in_var in enumerate(self.op.inputs("X")):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
elif var in self.op.inputs("X"):
if axis == pruned_axis:
idx = []
start = 0
for v in self.op.inputs("X"):
if v.name() == var.name():
idx = [i + start for i in pruned_idx]
else:
start += v.shape()[pruned_axis]
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, idx, visited={})
else:
for v in self.op.inputs("X"):
for op in v.inputs():
self._prune_op(op, v, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class depthwise_conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
if var in self.op.inputs("Input"):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis)
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
self._visit(filter_var, 0)
new_groups = filter_var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0]
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
                    self.pruned_params.append(
                        (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
self.pruned_params.append((var, 0, pruned_idx))
new_groups = var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups)
for op in var.outputs():
self._prune_op(op, var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
self._visit(filter_var, 0)
new_groups = filter_var.shape()[0] - len(pruned_idx)
op.set_attr("groups", new_groups)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
in_var = self.op.inputs("Input")[0]
self._visit(in_var, channel_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
class mul(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(mul, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
assert pruned_axis == 1, "The Input of conv2d can only be pruned at axis 1, but got {}".format(
pruned_axis)
idx = []
feature_map_size = var.shape()[2] * var.shape()[3]
range_idx = np.array(range(feature_map_size))
for i in pruned_idx:
idx += list(range_idx + i * feature_map_size)
param_var = self.op.inputs("Y")[0]
self.pruned_params.append((param_var, 0, idx))
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
@PRUNE_WORKER.register
class scale(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(scale, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
out_var = self.op.outputs("Out")[0]
for op in out_var.outputs():
self._prune_op(op, out_var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
for op in in_var.inputs():
self._prune_op(op, in_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class momentum(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(momentum, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
velocity_var = self.op.inputs("Velocity")[0]
self.pruned_params.append((velocity_var, pruned_axis, pruned_idx))
@PRUNE_WORKER.register
class adam(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(adam, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
moment1_var = self.op.inputs("Moment1")[0]
self.pruned_params.append((moment1_var, pruned_axis, pruned_idx))
moment2_var = self.op.inputs("Moment2")[0]
self.pruned_params.append((moment2_var, pruned_axis, pruned_idx))
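For reference, a sketch of how the registered workers above are driven, in the same style as PruneWorker._prune_op; the collect_pruned_params helper and the starting axis are illustrative assumptions, not part of this commit:
from paddleslim.prune.dy_prune_walker import PRUNE_WORKER


def collect_pruned_params(conv_op, pruned_idx):
    """Walk the graph from the filter of conv_op, pruning the output channels in pruned_idx."""
    pruned_params = []
    # Look up the worker class registered for this operator type (e.g. "Conv2d").
    cls = PRUNE_WORKER.get(conv_op.type())
    assert cls is not None, "The walker of {} is not registered.".format(conv_op.type())
    worker = cls(conv_op, pruned_params=pruned_params, visited={})
    # In the DyGraph wrapper the filter parameter is the first input of a conv op;
    # axis 0 of the filter is the output-channel dimension.
    filter_var = conv_op.all_inputs()[0]
    worker.prune(filter_var, pruned_axis=0, pruned_idx=pruned_idx)
    # Each recorded entry is a (VarWrapper, axis, indexes) tuple.
    return pruned_params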
......@@ -17,8 +17,13 @@ import sys
import numpy as np
import paddle.fluid as fluid
import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..core import GraphWrapper
try:
    import torch
    from ..core import DyGraph
except Exception as e:
    pass
from .prune_walker import conv2d as conv2d_walker
from .dy_prune_walker import Conv2d as dy_conv2d_walker
from ..common import get_logger
__all__ = ["Pruner"]
......@@ -38,7 +43,7 @@ class Pruner():
self.criterion = criterion
def prune(self,
program,
graph,
scope,
params,
ratios,
......@@ -68,7 +73,13 @@ class Pruner():
"""
self.pruned_list = []
graph = GraphWrapper(program.clone())
if isinstance(graph, fluid.Program):
            graph = GraphWrapper(graph.clone())
elif isinstance(graph, torch.nn.Module):
graph = DyGraph(graph)
conv2d_walker = dy_conv2d_walker
else:
raise NotImplementedError('The type of graph is not supported.')
param_backup = {} if param_backup else None
param_shape_backup = {} if param_shape_backup else None
......
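A self-contained sketch of the graph dispatch this last hunk introduces, with the static and dynamic paths side by side; the wrap_graph helper and its input_shape argument are assumptions for illustration (DyGraph needs a shape to trace the module):
import paddle.fluid as fluid

from paddleslim.core import GraphWrapper

try:
    import torch
    from paddleslim.core import DyGraph
except Exception:
    torch = None


def wrap_graph(graph, input_shape=None):
    """Wrap a static fluid.Program or a dynamic torch.nn.Module for pruning."""
    if isinstance(graph, fluid.Program):
        return GraphWrapper(graph.clone())
    elif torch is not None and isinstance(graph, torch.nn.Module):
        # The dynamic path traces the module, so a dummy input shape is required.
        return DyGraph(graph, input_shape)
    raise NotImplementedError('The type of graph is not supported.')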