提交 87afc6dc 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1767 from reyoung/feature/better_hsigmoid_interface

It is no need to config num_classes in hsigmoid
...@@ -1916,7 +1916,7 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None): ...@@ -1916,7 +1916,7 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
@layer_support() @layer_support()
def hsigmoid(input, def hsigmoid(input,
label, label,
num_classes, num_classes=None,
name=None, name=None,
bias_attr=None, bias_attr=None,
param_attr=None, param_attr=None,
...@@ -1932,8 +1932,7 @@ def hsigmoid(input, ...@@ -1932,8 +1932,7 @@ def hsigmoid(input,
.. code-block:: python .. code-block:: python
cost = hsigmoid(input=[layer1, layer2], cost = hsigmoid(input=[layer1, layer2],
label=data_layer, label=data_layer)
num_classes=3)
:param input: Input layers. It could be a LayerOutput or list/tuple of :param input: Input layers. It could be a LayerOutput or list/tuple of
LayerOutput. LayerOutput.
...@@ -1941,12 +1940,14 @@ def hsigmoid(input, ...@@ -1941,12 +1940,14 @@ def hsigmoid(input,
:param label: Label layer. :param label: Label layer.
:type label: LayerOutput :type label: LayerOutput
:param num_classes: number of classes. :param num_classes: number of classes.
:type num_classes: int :type num_classes: int|None
:param name: layer name :param name: layer name
:type name: basestring :type name: basestring
:param bias_attr: Bias attribute. None means default bias. :param bias_attr: Bias attribute. None means default bias.
False means no bias. False means no bias.
:type bias_attr: ParameterAttribute|False :type bias_attr: ParameterAttribute|False
:param param_attr: Parameter Attribute. None means default parameter.
:type param_attr: ParameterAttribute|None
:param layer_attr: Extra Layer Attribute. :param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute :type layer_attr: ExtraLayerAttribute
:return: LayerOutput object. :return: LayerOutput object.
...@@ -1966,6 +1967,11 @@ def hsigmoid(input, ...@@ -1966,6 +1967,11 @@ def hsigmoid(input,
assert isinstance(label, LayerOutput) assert isinstance(label, LayerOutput)
assert label.layer_type == LayerType.DATA assert label.layer_type == LayerType.DATA
if num_classes is None:
num_classes = label.size
if num_classes is None or num_classes <= 2:
raise ValueError("hsigmoid label size must larger than 2.")
ipts_for_layer = [] ipts_for_layer = []
parents = [] parents = []
for each_input, each_param_attr in zip(input, param_attr): for each_input, each_param_attr in zip(input, param_attr):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册