提交 20336428 编写于 作者: C caoying03

enable error clipping in FC layer.

上级 98378968
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import pdb
''' '''
The following functions are available in the config file: The following functions are available in the config file:
...@@ -761,8 +762,8 @@ class DotMulOperator(Operator): ...@@ -761,8 +762,8 @@ class DotMulOperator(Operator):
def check_dims(self): def check_dims(self):
for i in range(2): for i in range(2):
config_assert(self.operator_conf.input_sizes[i] == config_assert(self.operator_conf.input_sizes[
self.operator_conf.output_size, i] == self.operator_conf.output_size,
"DotMul input_size != output_size") "DotMul input_size != output_size")
def calc_output_size(self, input_sizes): def calc_output_size(self, input_sizes):
...@@ -1193,8 +1194,7 @@ def parse_image(image, input_layer_name, image_conf): ...@@ -1193,8 +1194,7 @@ def parse_image(image, input_layer_name, image_conf):
def parse_norm(norm, input_layer_name, norm_conf): def parse_norm(norm, input_layer_name, norm_conf):
norm_conf.norm_type = norm.norm_type norm_conf.norm_type = norm.norm_type
config_assert( config_assert(
norm.norm_type in norm.norm_type in ['rnorm', 'cmrnorm-projection', 'cross-channel-norm'],
['rnorm', 'cmrnorm-projection', 'cross-channel-norm'],
"norm-type %s is not in [rnorm, cmrnorm-projection, cross-channel-norm]" "norm-type %s is not in [rnorm, cmrnorm-projection, cross-channel-norm]"
% norm.norm_type) % norm.norm_type)
norm_conf.channels = norm.channels norm_conf.channels = norm.channels
...@@ -1571,7 +1571,13 @@ class MultiClassCrossEntropySelfNormCostLayer(LayerBase): ...@@ -1571,7 +1571,13 @@ class MultiClassCrossEntropySelfNormCostLayer(LayerBase):
@config_layer('fc') @config_layer('fc')
class FCLayer(LayerBase): class FCLayer(LayerBase):
def __init__(self, name, size, inputs, bias=True, **xargs): def __init__(self,
name,
size,
inputs,
bias=True,
error_clipping_threshold=None,
**xargs):
super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs) super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs)
for input_index in xrange(len(self.inputs)): for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index) input_layer = self.get_input_layer(input_index)
...@@ -1589,6 +1595,9 @@ class FCLayer(LayerBase): ...@@ -1589,6 +1595,9 @@ class FCLayer(LayerBase):
format) format)
self.create_bias_parameter(bias, self.config.size) self.create_bias_parameter(bias, self.config.size)
if error_clipping_threshold is not None:
self.config.error_clipping_threshold = error_clipping_threshold
@config_layer('selective_fc') @config_layer('selective_fc')
class SelectiveFCLayer(LayerBase): class SelectiveFCLayer(LayerBase):
...@@ -3425,7 +3434,8 @@ DEFAULT_SETTING = dict( ...@@ -3425,7 +3434,8 @@ DEFAULT_SETTING = dict(
settings = copy.deepcopy(DEFAULT_SETTING) settings = copy.deepcopy(DEFAULT_SETTING)
settings_deprecated = dict(usage_ratio=1., ) settings_deprecated = dict(
usage_ratio=1., )
trainer_settings = dict( trainer_settings = dict(
save_dir="./output/model", save_dir="./output/model",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册