提交 b8fe0843 编写于 作者: S SunAhong1993

add custom layer v1

上级 11f5f1c2
......@@ -230,3 +230,226 @@ def shape_batchnorm(layer, input_shape):
def shape_scale(layer, input_shape):
return input_shape
def shape_reshape(layer, input_shape):
def count(num_list):
return reduce(lambda a, b: a * b, num_list)
inshape = input_shape[0]
params = layer.reshape_param
axis = params.axis if hasattr(params, axis) else 0
num_axes = params.num_axes if hasattr(params, num_axes) else -1
if inshape[0] == -1:
inshape[0] = 1
input_count = count(inshape)
input_num_axes = len(inshape)
input_start_axis = axis
start_axis = input_start_axis if input_start_axis >= 0 \
else input_num_axes + input_start_axis + 1
assert start_axis >= 0, "[Reshape]axis %d out of range" % (input_start_axis)
assert start_axis <= input_num_axes, "[Reshape]axis %d out of range for %d-D input data"\
% (input_start_axis, input_num_axes)
assert num_axes >= -1, "[Reshape]num_axes must be >= 0, or -1 for all"
end_axis = input_num_axes if num_axes == -1 else start_axis + num_axes
assert end_axis <= input_num_axes, "end_axis[%d] = axis[%d] + num_axes[%d] is out of range"\
% (end_axis, start_axis, num_axes)
num_axes_replaced = end_axis - start_axis
num_axes_retained = input_num_axes - num_axes_replaced
num_new_axes = len(shape['dim'])
outshape = []
for i in range(start_axis):
outshape.append(inshape[i])
for i in range(num_new_axes):
outshape.append(shape['dim'][i])
for i in range(end_axis, input_num_axes):
outshape.append(inshape[i])
assert len(outshape) == num_axes_retained + num_new_axes,\
"[Reshape]invalid dims of output shape[%s]" % (str(outshape))
inferred_axis = -1
copy_axes = []
constant_count = 1
for i in range(num_new_axes):
top_dim = shape['dim'][i]
if top_dim == 0:
copy_axes.append(i)
copy_axis_index = start_axis + i
outshape[copy_axis_index] = inshape[copy_axis_index]
elif top_dim == -1:
assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims"
inferred_axis = i
else:
constant_count *= top_dim
if inferred_axis >= 0:
explicit_count = constant_count
l = inshape[0:start_axis]
if len(l) > 0:
explicit_count *= count(l)
l = inshape[end_axis:]
if len(l) > 0:
explicit_count *= count(l)
for i in range(len(copy_axes)):
explicit_count *= outshape[start_axis + copy_axes[i]]
assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\
"must be divisible by product of the specified dimensions[%d] "\
% (input_count, explicit_count)
outshape[start_axis + inferred_axis] = input_count / explicit_count
output_count = count(outshape)
assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
output_count, input_count)
if inshape[0] == -1:
outshape[0] = -1
return [outshape]
def shape_argmax(layer, input_shape):
inshape = input_shape[0]
params = layer.argmax_param
out_max_val = params.out_max_val if hasattr(params, out_max_val) else False
top_k = params.top_k if hasattr(params, top_k) else 1
axis = parmas.axis if hasattr(params, axis) else -1
if axis < 0:
axis += len(inshape)
assert (axis + 1 == len(inshape)
), 'only can be applied on the last dimension[axis:%d, %s] now,'\
'make sure you have set axis param in xxx.prototxt file' \
% (axis, str(inshape))
outshape = inshape
outshape[-1] = top_k
if out_max_val is True:
outshape[-1] *= 2
return [outshape]
def shape_axpy(layer, input_shape):
assert len(input_shapes) == 3, "not valid input shape for axpy layer"
assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims'
output_shape = input_shapes[1]
assert (input_shapes[2] == output_shape),\
"shape not consistent for axpy[%s <--> %s]" \
% (str(output_shape), str(input_shapes[2]))
return [output_shape]
def shape_crop(layer, input_shape):
assert len(input_shape) == 2, "the number of crop's inputs must be 2"
return [input_shape[1]]
def shape_detectionoutput(layer, input_shape):
return [[-1, 6]]
def shape_flatten(layer, input_shape):
assert len(input_shape) == 1, "the number of flatten's inputs must be 1"
params = layer.flatten_param
start_axis = params.axis
end_axis = params.end_axis
if start_axis < 0:
start_axis += len(input_shape[0])
if end_axis < 0:
end_axis += len(input_shape[0]) + 1
assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
% (start_axis, end_axis)
output_shape = [0] * (start_axis - 0) + [
-1
] + [0] * (len(input_shape[0]) - end_axis)
return [output_shape]
def shape_normalize(layer, input_shape):
return input_shape
def shape_permute(layer, input_shape):
params = layer.permute_param
order = list(params.order)
inshape = input_shape[0]
output_shape = []
for ii in order:
assert ii < len(inshape), "invalid order for permute[%s]" % (name)
output_shape.append(inshape[ii])
return [output_shape]
def shape_power(layer, input_shape):
return input_shape
def shape_priorbox(layer, input_shape):
params = layer.prior_box_param
min_size = list(params.min_size)
max_size = list(params.max_size)
aspect_ratio = list(params.aspect_ratio)
assert len(input_shapes[0]) == 2, "invalid inputs for Priorbox[%s]" % (name)
fc_shape = input_shapes[0][0]
N = 1
if not max_size == None:
N += 1
if not aspect_ratio == None:
N += 2 * len(aspect_ratio)
N_bbx = fc_shape[2] * fc_shape[3] * N
output_shape = [[1, 2, 4 * N_bbx]]
return output_shape
def shape_reduction(layer, input_shape):
params = layer.reduction_param
axis = params.axis
if axis < 0:
axis += len(input_shape[0]) + 1
assert axis <= len(input_shape[0]), 'invalid axis[%d] error' % (axis)
return [input_shape[0:axis]]
def shape_roipooling(layer, input_shape):
params = layer.roi_pooling_param
pooled_w = params.pooled_w
pooled_h = params.pooled_h
spatial_scale = params.spatial_scale
assert len(
input_shapes[0]) == 2, "not valid input shape for roipooling layer"
base_fea_shape = input_shapes[0][0]
rois_shape = input_shapes[0][1]
output_shape = base_fea_shape
output_shape[0] = rois_shape[0]
output_shape[2] = pooled_h
output_shape[3] = pooled_w
return [output_shape]
def shape_select(layer, input_shape):
input_shape = list(input_shape[0])
params = layer.select_param
axis = params.axis
slice_point = list(params.slice_point)
start = slice_point[0]
if len(slice_point) == 2:
end = slice_point[1]
else:
end = input_shape[axis]
assert end > start, "invalid slice_point with [start:%d, end:%d]"\
% (start, end)
output_shape = input_shape
output_shape[axis] = end - start
return [output_shape]
......@@ -267,6 +267,7 @@ class CaffeOpMapper(OpMapper):
def Pooling(self, node):
params = node.layer.pooling_param
ceil_mode = getattr(params, 'ceil_mode', True)
global_pool = getattr(params, 'global_pooling', False)
kernel_default = [1, 1]
channel, kernel, stride, pad, dilation, group = self.get_kernel_parameters(
......@@ -286,7 +287,7 @@ class CaffeOpMapper(OpMapper):
'pool_size': kernel,
'pool_stride': stride,
'pool_padding': pad,
'ceil_mode': True,
'ceil_mode': ceil_mode,
'pool_type': string(pool_type),
'exclusive': True,
'global_pooling': global_pool,
......@@ -737,7 +738,7 @@ class CaffeOpMapper(OpMapper):
else:
self.weights[node.layer_name + '_scale'] = np.squeeze(nose.data[0])
self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1])
params = node.layer.scale_params
params = node.layer.scale_param
axis = params.axis
num_axes = params.num_axes
assert num_axes == 1, "layer scale not support this num_axes[%d] now" % (
......@@ -811,3 +812,518 @@ class CaffeOpMapper(OpMapper):
node.layer_name, node.layer_name),
output=node,
param_attr=attr)
def Reshape(self, node):
assert len(node.inputs) == 1 and len(
node.outputs
) == 1, 'The count of Reshape node\'s input and output is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
top_count = len(input.layer.top)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
is_inplace, = False if top_count == 1 else True
output_shape = node.output_shape[0]
attr = {
'shape': output_shape,
'inplace': is_inplace,
'name': string(node.layer_name)
}
node.fluid_code.add_layer("reshape",
inputs=input,
output=node,
param_attr=attr)
def ArgMax(self, node):
assert len(node.inputs) == 1 and len(
node.outputs
) == 1, 'The count of ArgMax node\'s input and output is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
input_shape = node.input_shape[0]
params = node.layer.argmax_param
out_max_val = params.out_max_val if hasattr(params,
out_max_val) else False
top_k = params.top_k if hasattr(params, top_k) else 1
axis = parmas.axis if hasattr(params, axis) else -1
if axis < 0:
axis += len(input_shape)
if out_max_val is True:
attr = {'k': top_k, 'name': string(node.layer_name + '_topk')}
node.fluid_code.add_layer("topk",
inputs=input,
output='{}_topk_var, {}_index_var'.format(
node.layer_name, node.layer_name),
param_attr=attr)
attr = {'dtype': '{}_topk_var.dtype'.format(node.layer_name)}
node.fluid_code.add_layer(
"cast",
inputs='{}_index_var'.format(node.layer_name),
output='{}_index_var'.format(node.layer_name),
param_attr=attr)
attr = {'axis': axis, 'name': string(node.layer_name)}
node.fluid_code.add_layer("concat",
inputs='{}_topk_var, {}_index_var'.format(
node.layer_name, node.layer_name),
output=node,
param_attr=attr)
else:
attr = {'k': top_k, 'name': string(node.layer_name)}
node.fluid_code.add_layer("topk",
inputs=input,
output='_, {}'.format(node.layer_name),
param_attr=attr)
def Axpy(self, node):
assert len(
node.inputs) == 3, 'The count of Axpy node\'s input is not 3.'
alpha = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(alpha):
tmp = self.graph.get_bottom_node(alpha, idx=0, copy=True)
if self.is_BN(tmp):
alpha = tmp
x = self.graph.get_bottom_node(node, idx=1, copy=True)
if self.is_Scale(x):
tmp = self.graph.get_bottom_node(x, idx=0, copy=True)
if self.is_BN(tmp):
x = tmp
y = self.graph.get_bottom_node(node, idx=2, copy=True)
if self.is_Scale(y):
tmp = self.graph.get_bottom_node(y, idx=0, copy=True)
if self.is_BN(tmp):
y = tmp
attr = {'axis': 0, 'name': string(node.layer_name + '_mul')}
node.fluid_code.add_layer("elementwise_mul",
inputs={
'x': alpha,
'y': x
},
output=node,
param_attr=attr)
attr = {'name': string(node.layer_name + '_add')}
node.fluid_code.add_layer("elementwise_add",
inputs={
'x': node,
'y': y
},
output=node,
param_attr=attr)
def Crop(self, node):
assert len(
node.inputs) == 2, 'The count of Crop node\'s input is not 2.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
example = self.graph.get_bottom_node(node, idx=1, copy=True)
if self.is_Scale(example):
tmp = self.graph.get_bottom_node(example, idx=0, copy=True)
if self.is_BN(tmp):
example = tmp
params = node.layer.crop_param
axis = parmas.axis
input_shape = node.input_shape[0]
if axis < 0:
axis += len(input_shape)
offset_real = [0] * len(input_shape)
if hasattr(params, offset):
offset = list(params.offset)
assert (len(input_shape) - axis) == len(
offset), "invalid offset[%s] in crop layer" % (str(offset))
offset_real = [0] * axis + offset
attr = {'offsets': offset_real, 'name': string(node.layer_name)}
node.fluid_code.add_layer("crop",
inputs={
'x': input,
'y': example
},
output=node,
param_attr=attr)
def DetectionOutput(self, node):
assert len(
node.inputs
) == 3, 'The count of DetectionOutput node\'s input is not 3.'
mbox_loc = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(mbox_loc):
tmp = self.graph.get_bottom_node(mbox_loc, idx=0, copy=True)
if self.is_BN(tmp):
mbox_loc = tmp
mbox_conf_flatten = self.graph.get_bottom_node(node, idx=1, copy=True)
if self.is_Scale(mbox_conf_flatten):
tmp = self.graph.get_bottom_node(mbox_conf_flatten,
idx=0,
copy=True)
if self.is_BN(tmp):
mbox_conf_flatten = tmp
mbox_priorbox = self.graph.get_bottom_node(node, idx=2, copy=True)
if self.is_Scale(mbox_priorbox):
tmp = self.graph.get_bottom_node(mbox_priorbox, idx=0, copy=True)
if self.is_BN(tmp):
mbox_priorbox = tmp
params = node.layer.detection_output_param
nms_threshold = 0.3
top_k = 10
eta = 1.0
if hasattr(params, 'nms_param'):
nms_threshold = getattr(params.nms_param, 'nms_threshold', 0.3)
top_k = getattr(params.nms_param, 'top_k', 10)
eta = getattr(params.nms_param, 'eta', 1.0)
background_label = getattr(params, 'background_label_id', 0)
share_location = getattr(params, 'share_location', True)
keep_top_k = getattr(params, 'keep_top_k', 100)
confidence_threshold = getattr(params, 'confidence_threshold', 0.1)
attr = {
'num_or_sections': 2,
'dim': 1,
'name': string(node.layer_name + '_split')
}
node.fluid_code.add_layer("split",
inputs=mbox_priorbox,
output='mbox_priorbox_list',
param_attr=attr)
node.fluid_code.add_note('pb = mbox_priorbox_list[0]')
node.fluid_code.add_note('pbv = mbox_priorbox_list[1]')
attr = {'shape': [-1, 4], 'name': string(node.layer_name + '_reshape1')}
node.fluid_code.add_layer("reshape",
inputs='pb',
output='pb',
param_attr=attr)
attr = {'shape': [-1, 4], 'name': string(node.layer_name + '_reshape2')}
node.fluid_code.add_layer("reshape",
inputs='pbv',
output='pbv',
param_attr=attr)
# TODO(syf): need chaeck
attr = {
'shape': [-1, node.input_shape[1][1], 4],
'name': string(node.layer_name + '_reshape3')
}
node.fluid_code.add_layer("reshape",
inputs=mbox_loc,
output='mbox_loc',
param_attr=attr)
attr = {
'background_label': background_label,
'nms_threshold': nms_threshold,
'nms_top_k': top_k,
'keep_top_k': keep_top_k,
'score_threshold': confidence_threshold,
'nms_eta': eta
}
inputs_str = get_input_name(mbox_conf_flatten) + ', mbox_loc, pb, pbv'
node.fluid_code.add_layer("detection_output",
inputs=inputs_str,
output=node,
param_attr=attr)
def Flatten(self, noed):
assert len(
node.inputs
) == 1, 'The count of DetectionOutput node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
shape = node.output_shape[0]
attr = {'shape': shape, 'name': string(node.layer_name)}
node.fluid_code.add_layer("reshape",
inputs=input,
output=node,
param_attr=attr)
def Normalize(self, node):
assert len(
node.inputs) == 1, 'The count of Normalize node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
params = node.layer.norm_param
across_spatial = params.across_spatial
channel_shared = params.channel_shared
assert across_spatial == False, "Only support across_spatial == False for Normalize"
attr = {'axis': 1, 'name': string(node.layer_name + '_l2')}
node.fluid_code.add_layer("l2_normalize",
inputs=input,
output=node.layer_name + '_l2',
param_attr=attr)
input_name = self.get_input_name(input)
data = node.data
data = self.adjust_parameters(node, data)
self.weights[node.layer_name + '_scale'] = data[0]
node.fluid_code.add_note(
'{}_scale_attr = ParamAttr(name=\'{}\')'.format(
node.layer_name, node.layer_name + '_scale'))
attr = {
'shape': [1] if channel_shared else [node.input_shape[0][1]],
'dtype': '{}.dtype'.format(input_name),
'attr': '{}_scale_attr'.format(node.layer_name),
'name': string(node.layer_name + '_param')
}
node.fluid_code.add_layer("create_parameter",
inputs=None,
output=node.layer_name + '_scale_param',
param_attr=attr)
attr = {
'axis': -1 if channel_shared else 1,
'name': string(node.layer_name + '_mul')
}
node.fluid_code.add_layer("elementwise_mul",
inputs=node.layer_name + '_l2, ' +
node.layer_name + '_scale_param',
output=node,
param_attr=attr)
def Permute(self, node):
assert len(
node.inputs) == 1, 'The count of Permute node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
params = node.layer.permute_param
order = list(params.order)
attr = {'order': order, 'name': string(node.layer_name)}
node.fluid_code.add_layer("transpose",
inputs=input,
output=node,
param_attr=attr)
def Power(self, node):
assert len(
node.inputs) == 1, 'The count of Permute node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
params = node.layer.power_param
power = params.power
scale = params.scale
shift = params.shift
attr = {
'scale': scale,
'bias': shift,
'bias_after_scale': True,
'name': string(node.layer_name + '_scale')
}
node.fluid_code.add_layer("scale",
inputs=input,
output=node,
param_attr=attr)
attr = {'factor': power, 'name': string(node.layer_name)}
node.fluid_code.add_layer("pow",
inputs=node,
output=node,
param_attr=attr)
def PriorBox(self, node):
assert len(
node.inputs) == 2, 'The count of PriorBox node\'s input is not 2.'
input1 = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input1):
tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
if self.is_BN(tmp):
input1 = tmp
input2 = self.graph.get_bottom_node(node, idx=1, copy=True)
if self.is_Scale(input2):
tmp = self.graph.get_bottom_node(input2, idx=0, copy=True)
if self.is_BN(tmp):
input2 = tmp
input_dict = {'input': input1, 'image': input2}
params = node.layer.prior_box_param
step = getattr(params, 'step', 0.0)
offset = getattr(params, 'offset', 0.5)
min_size = list(params.min_size)
max_size = list(params.max_size)
aspect_ratio = list(params.aspect_ratio)
flip = getattr(params, 'flip', False)
clip = getattr(params, 'clip', False)
variance = list(getattr(params, 'variance', [0.1, 0.1, 0.2, 0.2]))
steps = tuple(step) if type(step) is list or type(step) is tuple else (
step, step)
attr = {
'min_sizes': min_size,
'max_sizes': max_size,
'aspect_ratios': aspect_ratio,
'variance': variance,
'flip': flip,
'clip': clip,
'step': steps,
'offset': offset,
'min_max_aspect_ratios_order': True,
'name': string(node.layer_name)
}
node.fluid_code.add_layer("prior_box",
inputs=input_dict,
output='{}_box, {}_var'.format(
node.layer_name, node.layer_name),
param_attr=attr)
attr = {
'shape': [1, 1, -1],
}
node.fluid_code.add_layer("reshape",
inputs='{}_box'.format(node.layer_name),
output='{}_box'.format(node.layer_name),
param_attr=attr)
attr = {
'shape': [1, 1, -1],
}
node.fluid_code.add_layer("reshape",
inputs='{}_var'.format(node.layer_name),
output='{}_var'.format(node.layer_name),
param_attr=attr)
attr = {'axis': 1, 'name': string(node.layer_name + '_concat')}
node.fluid_code.add_layer("concat",
inputs='[{}_box, {}_var]'.format(
node.layer_name, node.layer_name),
output=node,
param_attr=attr)
def Reduction(self, node):
assert len(
node.inputs) == 1, 'The count of Reduction node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
params = node.layer.reduction_param
operation = params.operation
axis = params.axis
coeff = params.coeff
assert operation >= 1 and operation <= 4, "reduction reduction [%s] error" % (
operation)
input_len = len(node.input_shape[0])
if axis < 0:
axis += input_len + 1
dim = list(range(input_len))
if operation == 1: ## operation = SUM
attr = {
'dim': dim[axis:],
'keep_dim': False,
'name': string(node.layer_name)
}
node.fluid_code.add_layer("reduce_sum",
inputs=input,
output=node,
param_attr=attr)
elif operation == 2: ## operation = ASUM
attr = {'name': string(node.layer_name + '_abs')}
node.fluid_code.add_layer("abs",
inputs=input,
output=node,
param_attr=attr)
attr = {
'dim': dim[axis:],
'keep_dim': False,
'name': string(node.layer_name)
}
node.fluid_code.add_layer("reduce_sum",
inputs=node,
output=node,
param_attr=attr)
elif operation == 3: ## operation = SUMSQ
attr = {'factor': 2.0, 'name': string(node.layer_name + '_pow')}
node.fluid_code.add_layer("pow",
inputs=input,
output=node,
param_attr=attr)
attr = {
'dim': dim[axis:],
'keep_dim': False,
'name': string(node.layer_name)
}
node.fluid_code.add_layer("reduce_sum",
inputs=node,
output=node,
param_attr=attr)
else: ## operation = MEAN
attr = {
'dim': dim[axis:],
'keep_dim': False,
'name': string(node.layer_name)
}
node.fluid_code.add_layer("reduce_mean",
inputs=node,
output=node,
param_attr=attr)
attr = {'scale': coeff}
node.fluid_code.add_layer("scale",
inputs=node,
output=node,
param_attr=attr)
def ROIPooling(self, node):
assert len(
node.inputs) == 2, 'The count of ROIPooling node\'s input is not 2.'
input1 = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input1):
tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
if self.is_BN(tmp):
input1 = tmp
input2 = self.graph.get_bottom_node(node, idx=1, copy=True)
if self.is_Scale(input2):
tmp = self.graph.get_bottom_node(input2, idx=0, copy=True)
if self.is_BN(tmp):
input2 = tmp
attr = {'axes': [1], 'starts': [1], 'ends': [5]}
node.fluid_code.add_layer("slice",
inputs=input2,
output=input2,
param_attr=attr)
input_dict = {'input': input1, 'rois': input2}
params = node.layer.roi_pooling_param
attr = {
'pooled_w': params.pooled_w,
'pooled_h': params.pooled_h,
'spatial_scale': params.spatial_scale,
'name': string(node.layer_name)
}
node.fluid_code.add_layer("roi_pool",
inputs=input_dict,
output=node,
param_attr=attr)
def Select(self, node):
assert len(
node.inputs) == 1, 'The count of Select node\'s input is not 2.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
if self.is_Scale(input):
tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
if self.is_BN(tmp):
input = tmp
params = node.layer.select_param
slice_point = list(params.slice_point)
axis = params.axis
maxint32 = 2147483647
slice_point = [0] + slice_point
slice_point.append(maxint32)
i = 0
node.fluid_code.add_note('{} = []'.format(node.layer_name))
for i in range(len(slice_point)):
attr = {
'axes': [axis],
'starts': [slice_point[i]],
'ends': [slice_point[i + 1]],
'name': string(node.layer_name + '_' + str(i))
}
node.fluid_code.add_layer("slice",
inputs=input,
output=string(node.layer_name + '_' +
str(i)),
param_attr=attr)
node.fluid_code.add_note('{}.append({})'.format(
node.layer_name, node.layer_name + '_' + str(i)))
if i == len(slice_point) - 2:
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册