提交 98a47232 编写于 作者: W walloollaw 提交者: qingqing01

add more help info in argmax when axis is not set (#879)

上级 81b673eb
......@@ -121,6 +121,7 @@ def generate_net_code(net_name, inputs_info):
net_codes = str(inspect.getsource(MyNet))
net_codes = net_codes.replace('MyNet(object)', '%s(Network)' % net_name)
net_codes = net_codes.replace('MyNet', net_name)
net_codes = net_codes.replace('"INPUTS_INFO"', inputs_info)
custom_layer_dir = os.path.dirname(os.path.abspath(__file__))
......
......@@ -64,6 +64,11 @@ class MaybeActivated(object):
if node.metadata.get('relu', False) != default:
self.inject_kwargs['relu'] = not default
default_slope = 0.0
slope = node.metadata.get('relu_negative_slope', default_slope)
if slope != default_slope:
self.inject_kwargs['relu_negative_slope'] = slope
def __call__(self, *args, **kwargs):
kwargs.update(self.inject_kwargs)
return TensorFlowNode(*args, **kwargs)
......@@ -108,11 +113,19 @@ class TensorFlowMapper(NodeMapper):
else:
# Stochastic pooling, for instance.
raise KaffeError('Unsupported pooling type.')
(kernel_params, padding) = self.get_kernel_params(node)
ceil_mode = getattr(node.layer.parameters, 'ceil_mode', True)
return TensorFlowNode(pool_op, kernel_params.kernel_h,
kernel_params.kernel_w, kernel_params.stride_h,
kernel_params.stride_w, ceil_mode, **padding)
global_pool = getattr(node.layer.parameters, 'global_pooling', False)
if global_pool:
input_shape = node.get_only_parent().output_shape
return TensorFlowNode(pool_op, input_shape.height,
input_shape.width, 1, 1, ceil_mode)
else:
(kernel_params, padding) = self.get_kernel_params(node)
return TensorFlowNode(pool_op, kernel_params.kernel_h,
kernel_params.kernel_w,
kernel_params.stride_h,
kernel_params.stride_w, ceil_mode, **padding)
def map_sigmoid(self, node):
return TensorFlowNode('sigmoid')
......@@ -169,6 +182,11 @@ class TensorFlowMapper(NodeMapper):
raise KaffeError('Unknown elementwise operation: {}'.format(
op_code))
def map_scale(self, node):
params = node.parameters
return TensorFlowNode(
'scale', axis=params.axis, num_axes=params.num_axes)
def commit(self, chains):
return chains
......
......@@ -95,12 +95,16 @@ def shape_convolution(node):
def shape_pool(node):
global_pool = getattr(node.layer.parameters, 'global_pooling', False)
if global_pool:
input_shape = node.get_only_parent().output_shape
return make_tensor(input_shape.batch_size, input_shape.channels, 1, 1)
ceil_mode = getattr(node.layer.parameters, 'ceil_mode', True)
if ceil_mode is True:
method = math.ceil
else:
method = math.floor
return get_strided_kernel_output_shape(node, method)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册