提交 f4be1d99 编写于 作者: J JiabinYang

polish code and test

上级 b8ff0972
......@@ -115,7 +115,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
"[batch_size, code_length], where code_length represents the "
"maximum path length from root to leaf nodes.")
.AsIntermediate();
AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
AddAttr<AttrType>("num_classes", "(int, optional), The number of classes")
.SetDefault(2);
AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
......
......@@ -4348,12 +4348,14 @@ def nce(input,
def hsigmoid(input,
label,
num_classes,
ptabl=None,
num_classes=None,
non_leaf_num=None,
ptable=None,
pcode=None,
param_attr=None,
bias_attr=None,
name=None):
name=None,
is_costum=False):
"""
The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a
......@@ -4373,7 +4375,8 @@ def hsigmoid(input,
and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2.
num_classes: (int), The number of classes, must not be less than 2. It has to be set when the default tree is used.
non_leaf_num: (int), The number of non-leaf nodes in the custom tree. It has to be set when a custom tree is used.
ptable: (Variable|None) this variable stores, for each sample in the batch, its path to the root;
the path should be in leaf -> root order.
ptable should have the same shape as pcode, and for each sample i ptable[i] indicates a np.array like
......@@ -4409,20 +4412,33 @@ def hsigmoid(input,
out = helper.create_variable_for_type_inference(dtype)
pre_out = helper.create_variable_for_type_inference(dtype)
dim = input.shape[1]
if num_classes < 2:
raise ValueError("num_classes must not be less than 2.")
if (ptable is not None) and (pcode is None):
raise ValueError("pcode should not be None when ptable has been set")
elif (ptable is None) and (pcode is not None):
raise ValueError("ptable should not be None when pcode has been set")
if ((num_classes < 2) or (num_classes is None)) and (not is_costum):
raise ValueError(
"num_classes must not be less than 2 with default tree")
if (is_costum) and (pcode is None):
raise ValueError("pcode should not be None with costum tree")
elif (is_costum) and (ptable is None):
raise ValueError("ptable should not be None with costum tree")
elif (is_costum) and (non_leaf_num is None):
raise ValueError("non_leaf_num should not be None with costum tree")
else:
pass
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
weights = None
if not is_costum:
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
else:
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[non_leaf_num, dim],
is_bias=False,
dtype=input.dtype)
inputs = {
"X": input,
"W": weights,
......@@ -4431,12 +4447,20 @@ def hsigmoid(input,
"Label": label
}
if helper.bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, num_classes - 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
if not is_costum:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, num_classes - 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
else:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, non_leaf_num],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
helper.append_op(
type="hierarchical_sigmoid",
inputs=inputs,
......
......@@ -185,6 +185,23 @@ class TestBook(unittest.TestCase):
input=x, label=y, num_classes=2))
print(str(program))
program2 = Program()
with program_guard(program2):
x2 = layers.data(name='x2', shape=[4, 8], dtype='float32')
y2 = layers.data(name='y2', shape=[4], dtype='int64')
ptable = layers.data(name='ptable', shape=[4, 6], dtype='int64')
pcode = layers.data(name='pcode', shape=[4, 6], dtype='int64')
self.assertIsNotNone(
layers.hsigmoid(
input=x2,
label=y2,
non_leaf_num=6,
ptable=ptable,
pcode=pcode,
is_costum=True))
print(str(program2))
def test_sequence_expand(self):
program = Program()
with program_guard(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册