提交 0db126d8 编写于 作者: W walloollaw 提交者: qingqing01

caffe2fluid: support prelu convertion (#1161)

上级 a42bfcc3
......@@ -45,7 +45,7 @@ def calc_diff(f1, f2):
sq_df = np.mean(df * df)
return max_df, sq_df
except Exception as e:
return -1.0, -1.0
return 1.0, 1.0
def compare(path1, path2, no_exception):
......
......@@ -245,10 +245,18 @@ class Network(object):
@layer
def prelu(self, input, channel_shared, name):
#fluid = import_fluid()
#output = fluid.layers.relu(input)
#return output
raise NotImplementedError('prelu not implemented')
fluid = import_fluid()
if channel_shared:
mode = 'all'
else:
mode = 'channel'
prefix = name + '_'
output = fluid.layers.prelu(
input,
mode=mode,
param_attr=fluid.ParamAttr(name=prefix + 'negslope'))
return output
def pool(self, pool_type, input, k_h, k_w, s_h, s_w, ceil_mode, padding,
name):
......
......@@ -176,6 +176,7 @@ class DataReshaper(object):
del node.reshaped_data
return graph
class CropFuser(object):
'''
Crop is to return a scalar output Blob for an input Blob of arbitrary size.
......@@ -197,7 +198,8 @@ class CropFuser(object):
cls._traced_names[fname] = []
cls._traced_names[fname].append(tname)
def __init__(self, allowed_parent_types=[NodeKind.Input, NodeKind.DummyData]):
def __init__(self,
allowed_parent_types=[NodeKind.Input, NodeKind.DummyData]):
self.allowed_parent_types = allowed_parent_types
def __call__(self, graph):
......@@ -232,7 +234,11 @@ class CropFuser(object):
def merge(self, parent, child):
'''Merge the parent node into the child.'''
child.metadata['shape'] = [parent.output_shape.batch_size, parent.output_shape.channels, parent.output_shape.height, parent.output_shape.width]
child.metadata['shape'] = [
parent.output_shape.batch_size, parent.output_shape.channels,
parent.output_shape.height, parent.output_shape.width
]
class SubNodeFuser(object):
'''
......@@ -395,6 +401,8 @@ class ParameterNamer(object):
names = ('scale', )
if getattr(node.parameters, 'bias_term', False):
names = ('scale', 'offset')
elif node.kind == NodeKind.PReLU:
names = ('negslope', )
elif node.kind == "Normalize":
names = ('scale', )
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册