提交 0db126d8 编写于 作者: W walloollaw 提交者: qingqing01

caffe2fluid: support prelu convertion (#1161)

上级 a42bfcc3
...@@ -45,7 +45,7 @@ def calc_diff(f1, f2): ...@@ -45,7 +45,7 @@ def calc_diff(f1, f2):
sq_df = np.mean(df * df) sq_df = np.mean(df * df)
return max_df, sq_df return max_df, sq_df
except Exception as e: except Exception as e:
return -1.0, -1.0 return 1.0, 1.0
def compare(path1, path2, no_exception): def compare(path1, path2, no_exception):
......
...@@ -245,10 +245,18 @@ class Network(object): ...@@ -245,10 +245,18 @@ class Network(object):
@layer @layer
def prelu(self, input, channel_shared, name): def prelu(self, input, channel_shared, name):
#fluid = import_fluid() fluid = import_fluid()
#output = fluid.layers.relu(input) if channel_shared:
#return output mode = 'all'
raise NotImplementedError('prelu not implemented') else:
mode = 'channel'
prefix = name + '_'
output = fluid.layers.prelu(
input,
mode=mode,
param_attr=fluid.ParamAttr(name=prefix + 'negslope'))
return output
def pool(self, pool_type, input, k_h, k_w, s_h, s_w, ceil_mode, padding, def pool(self, pool_type, input, k_h, k_w, s_h, s_w, ceil_mode, padding,
name): name):
......
...@@ -176,6 +176,7 @@ class DataReshaper(object): ...@@ -176,6 +176,7 @@ class DataReshaper(object):
del node.reshaped_data del node.reshaped_data
return graph return graph
class CropFuser(object): class CropFuser(object):
''' '''
Crop is to return a scalar output Blob for an input Blob of arbitrary size. Crop is to return a scalar output Blob for an input Blob of arbitrary size.
...@@ -197,7 +198,8 @@ class CropFuser(object): ...@@ -197,7 +198,8 @@ class CropFuser(object):
cls._traced_names[fname] = [] cls._traced_names[fname] = []
cls._traced_names[fname].append(tname) cls._traced_names[fname].append(tname)
def __init__(self, allowed_parent_types=[NodeKind.Input, NodeKind.DummyData]): def __init__(self,
allowed_parent_types=[NodeKind.Input, NodeKind.DummyData]):
self.allowed_parent_types = allowed_parent_types self.allowed_parent_types = allowed_parent_types
def __call__(self, graph): def __call__(self, graph):
...@@ -232,7 +234,11 @@ class CropFuser(object): ...@@ -232,7 +234,11 @@ class CropFuser(object):
def merge(self, parent, child): def merge(self, parent, child):
'''Merge the parent node into the child.''' '''Merge the parent node into the child.'''
child.metadata['shape'] = [parent.output_shape.batch_size, parent.output_shape.channels, parent.output_shape.height, parent.output_shape.width] child.metadata['shape'] = [
parent.output_shape.batch_size, parent.output_shape.channels,
parent.output_shape.height, parent.output_shape.width
]
class SubNodeFuser(object): class SubNodeFuser(object):
''' '''
...@@ -395,6 +401,8 @@ class ParameterNamer(object): ...@@ -395,6 +401,8 @@ class ParameterNamer(object):
names = ('scale', ) names = ('scale', )
if getattr(node.parameters, 'bias_term', False): if getattr(node.parameters, 'bias_term', False):
names = ('scale', 'offset') names = ('scale', 'offset')
elif node.kind == NodeKind.PReLU:
names = ('negslope', )
elif node.kind == "Normalize": elif node.kind == "Normalize":
names = ('scale', ) names = ('scale', )
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册