提交 e022c065 编写于 作者: C caoying03

enable dropout in average and max layer.

上级 9e9ae2f3
...@@ -2305,9 +2305,10 @@ class MaxLayer(LayerBase): ...@@ -2305,9 +2305,10 @@ class MaxLayer(LayerBase):
active_type='linear', active_type='linear',
device=None, device=None,
bias=False, bias=False,
output_max_index=None): output_max_index=None,
**xargs):
super(MaxLayer, self).__init__( super(MaxLayer, self).__init__(
name, 'max', 0, inputs=inputs, device=device) name, 'max', 0, inputs=inputs, device=device, **xargs)
config_assert(len(self.inputs) == 1, 'MaxLayer must have 1 input') config_assert(len(self.inputs) == 1, 'MaxLayer must have 1 input')
self.config.trans_type = trans_type self.config.trans_type = trans_type
self.config.active_type = active_type self.config.active_type = active_type
...@@ -2609,14 +2610,16 @@ class AverageLayer(LayerBase): ...@@ -2609,14 +2610,16 @@ class AverageLayer(LayerBase):
trans_type='non-seq', trans_type='non-seq',
active_type='linear', active_type='linear',
device=None, device=None,
bias=False): bias=False,
**xargs):
super(AverageLayer, self).__init__( super(AverageLayer, self).__init__(
name, name,
'average', 'average',
0, 0,
inputs=inputs, inputs=inputs,
device=device, device=device,
active_type=active_type) active_type=active_type,
**xargs)
self.config.average_strategy = average_strategy self.config.average_strategy = average_strategy
self.config.trans_type = trans_type self.config.trans_type = trans_type
config_assert(len(inputs) == 1, 'AverageLayer must have 1 input') config_assert(len(inputs) == 1, 'AverageLayer must have 1 input')
...@@ -3490,7 +3493,7 @@ def parse_config(config_file, config_arg_str): ...@@ -3490,7 +3493,7 @@ def parse_config(config_file, config_arg_str):
def parse_config_and_serialize(config_file, config_arg_str): def parse_config_and_serialize(config_file, config_arg_str):
try: try:
config = parse_config(config_file, config_arg_str) config = parse_config(config_file, config_arg_str)
#logger.info(config) # logger.info(config)
return config.SerializeToString() return config.SerializeToString()
except: except:
traceback.print_exc() traceback.print_exc()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册