提交 c469334c 编写于 作者: J JiabinYang

polish python code and comment, test=develop

上级 87648f8e
...@@ -47,11 +47,11 @@ template <typename DeviceContext, typename T> ...@@ -47,11 +47,11 @@ template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto in = detail::Ref(ctx.Input<framework::LoDTensor>("X")); auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
auto w = detail::Ref(ctx.Input<framework::LoDTensor>("W")); auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
auto* path = ctx.Input<framework::LoDTensor>("PTable"); auto* path = ctx.Input<framework::LoDTensor>("PTable");
auto* code = ctx.Input<framework::LoDTensor>("PathCode"); auto* code = ctx.Input<framework::LoDTensor>("PathCode");
auto label = detail::Ref(ctx.Input<framework::LoDTensor>("Label")); auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
auto* bias = ctx.Input<framework::LoDTensor>("Bias"); auto* bias = ctx.Input<framework::LoDTensor>("Bias");
auto* out = ctx.Output<framework::LoDTensor>("Out"); auto* out = ctx.Output<framework::LoDTensor>("Out");
auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut"); auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
...@@ -114,8 +114,8 @@ template <typename DeviceContext, typename T> ...@@ -114,8 +114,8 @@ template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto in = detail::Ref(ctx.Input<framework::LoDTensor>("X")); auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
auto w = detail::Ref(ctx.Input<framework::LoDTensor>("W")); auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
auto* path = ctx.Input<framework::LoDTensor>("PTable"); auto* path = ctx.Input<framework::LoDTensor>("PTable");
auto* code = ctx.Input<framework::LoDTensor>("PathCode"); auto* code = ctx.Input<framework::LoDTensor>("PathCode");
auto* bias = ctx.Input<framework::LoDTensor>("Bias"); auto* bias = ctx.Input<framework::LoDTensor>("Bias");
...@@ -124,9 +124,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -124,9 +124,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
bool is_sparse = ctx.Attr<bool>("is_sparse"); bool is_sparse = ctx.Attr<bool>("is_sparse");
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero; math::SetConstant<DeviceContext, T> zero;
auto label = detail::Ref(ctx.Input<framework::LoDTensor>("Label")); auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
auto pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut")); auto& pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut"));
auto out_grad = detail::Ref( auto& out_grad = detail::Ref(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))); ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")));
framework::LoDTensor pre_out_grad; framework::LoDTensor pre_out_grad;
......
...@@ -4589,23 +4589,33 @@ def hsigmoid(input, ...@@ -4589,23 +4589,33 @@ def hsigmoid(input,
bias_attr=None, bias_attr=None,
name=None, name=None,
non_leaf_num=None, non_leaf_num=None,
ptable=None, path_table=None,
pcode=None, path_code=None,
is_costum=False, is_custom=False,
is_sparse=False): is_sparse=False):
""" """
The hierarchical sigmoid operator is used to accelerate the training The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a process of language model. This operator organizes the classes into a
complete binary tree, each leaf node represents a class(a word) and each complete binary tree, or you can use is_custom to pass your own tree to
implement hierarchical. Each leaf node represents a class(a word) and each
internal node acts as a binary classifier. For each word there's a unique internal node acts as a binary classifier. For each word there's a unique
path from root to it's leaf node, hsigmoid calculate the cost for each path from root to it's leaf node, hsigmoid calculate the cost for each
internal node on the path, and sum them to get a total cost. hsigmoid can internal node on the path, and sum them to get a total cost. hsigmoid can
achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N` achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
represents the size of word dict. represents the size of word dict.
Refer to `Hierarchical Probabilistic Neural Network Language Model Using default tree you can Refer to `Hierarchical Probabilistic Neural Network Language Model
<http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_ <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_
And if you want to use the costumed tree by set 'is_custom' as true you may need to do following things first:
1. using your word dict to build a binary tree, each leaf node should be an word of your word dict
2. build a dict to store word_id -> word's leaf to root path, we call it path_table.
3. build a dict to store word_id -> code of word's leaf to root path, we call it path_code. Code
means label of each binary classification, using 1 indicate true, 0 indicate false.
4. now, each word should has its path and code along the path, you can pass a batch of path and code
related to the same batch of inputs.
Args: Args:
input (Variable): The input tensor variable with shape input (Variable): The input tensor variable with shape
:math:`[N \\times D]`, where :math:`N` is the size of mini-batch, :math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
...@@ -4613,13 +4623,6 @@ def hsigmoid(input, ...@@ -4613,13 +4623,6 @@ def hsigmoid(input,
label (Variable): The tensor variable contains labels of training data. label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`. It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set
non_leaf_num: this defines the number of non-leaf nodes in costumed tree
ptable: (Variable|None) this variable can store each batch of samples' path to root,
it should be in leaf -> root order
ptable should have the same shape with pcode, and for each sample i ptable[i] indicates a np.array like
structure and each element in this array is indexes in parent nodes' Weight Matrix.
pcode: (Variable|None) this variable can store each batch of samples' code,
each code consist with every code of parent nodes. it should be in leaf -> root order
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid
will create ParamAttr as param_attr. If the Initializer of the param_attr will create ParamAttr as param_attr. If the Initializer of the param_attr
...@@ -4631,8 +4634,15 @@ def hsigmoid(input, ...@@ -4631,8 +4634,15 @@ def hsigmoid(input,
is not set, the bias is initialized zero. Default: None. is not set, the bias is initialized zero. Default: None.
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None. will be named automatically. Default: None.
is_costum: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is non_leaf_num: this defines the number of non-leaf nodes in costumed tree
set you need to set ptable/pcode/non_leaf_num, otherwise num_classes should be set path_table: (Variable|None) this variable can store each batch of samples' path to root,
it should be in leaf -> root order
path_table should have the same shape with path_code, and for each sample i path_table[i] indicates a np.array like
structure and each element in this array is indexes in parent nodes' Weight Matrix.
path_code: (Variable|None) this variable can store each batch of samples' code,
each code consist with every code of parent nodes. it should be in leaf -> root order
is_custom: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is
set you need to set path_table/path_code/non_leaf_num, otherwise num_classes should be set
is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient
of W and input will be sparse. of W and input will be sparse.
...@@ -4653,22 +4663,22 @@ def hsigmoid(input, ...@@ -4653,22 +4663,22 @@ def hsigmoid(input,
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
pre_out = helper.create_variable_for_type_inference(dtype) pre_out = helper.create_variable_for_type_inference(dtype)
dim = input.shape[1] dim = input.shape[1]
if ((num_classes is None) or (num_classes < 2)) and (not is_costum): if ((num_classes is None) or (num_classes < 2)) and (not is_custom):
raise ValueError( raise ValueError(
"num_classes must not be less than 2 with default tree") "num_classes must not be less than 2 with default tree")
if (is_costum) and (pcode is None): if (is_custom) and (path_code is None):
raise ValueError("pcode should not be None with costum tree") raise ValueError("path_code should not be None with costum tree")
elif (is_costum) and (ptable is None): elif (is_custom) and (path_table is None):
raise ValueError("ptable should not be None with costum tree") raise ValueError("path_table should not be None with costum tree")
elif (is_costum) and (non_leaf_num is None): elif (is_custom) and (non_leaf_num is None):
raise ValueError("non_leaf_num should not be None with costum tree") raise ValueError("non_leaf_num should not be None with costum tree")
else: else:
pass pass
weights = None weights = None
if not is_costum: if not is_custom:
weights = helper.create_parameter( weights = helper.create_parameter(
attr=helper.param_attr, attr=helper.param_attr,
shape=[num_classes - 1, dim], shape=[num_classes - 1, dim],
...@@ -4683,12 +4693,12 @@ def hsigmoid(input, ...@@ -4683,12 +4693,12 @@ def hsigmoid(input,
inputs = { inputs = {
"X": input, "X": input,
"W": weights, "W": weights,
"PTable": ptable, "PTable": path_table,
"PathCode": pcode, "PathCode": path_code,
"Label": label "Label": label
} }
if helper.bias_attr: if helper.bias_attr:
if not is_costum: if not is_custom:
bias = helper.create_parameter( bias = helper.create_parameter(
attr=helper.bias_attr, attr=helper.bias_attr,
shape=[num_classes - 1, 1], shape=[num_classes - 1, 1],
......
...@@ -43,9 +43,9 @@ class CodeTable(object): ...@@ -43,9 +43,9 @@ class CodeTable(object):
class CodeTableWithCustomTree(object): class CodeTableWithCustomTree(object):
def __init__(self, ptable, pcode, index): def __init__(self, path_table, path_code, index):
self.ptable_ = ptable self.ptable_ = path_table
self.pcode_ = pcode self.pcode_ = path_code
self.index_ = index self.index_ = index
def cal_index(self, bit): def cal_index(self, bit):
...@@ -102,9 +102,10 @@ def hsigmoid(x, w, label, bias, num_classes): ...@@ -102,9 +102,10 @@ def hsigmoid(x, w, label, bias, num_classes):
return pre_output, out return pre_output, out
def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): def hsigmoidWithCustomTree(x, w, path_table, path_code, label, bias,
num_classes):
batch_size = x.shape[0] batch_size = x.shape[0]
code_length = len(ptable[0]) code_length = len(path_table[0])
code_table = [0 for _ in range(code_length)] code_table = [0 for _ in range(code_length)]
# init pre_out with shape [N, code_length] # init pre_out with shape [N, code_length]
pre_output = np.zeros((batch_size, code_length)) pre_output = np.zeros((batch_size, code_length))
...@@ -112,13 +113,13 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): ...@@ -112,13 +113,13 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
out = np.zeros((batch_size, 1)).astype("float32") out = np.zeros((batch_size, 1)).astype("float32")
if isinstance(bias, np.ndarray): if isinstance(bias, np.ndarray):
for i in range(batch_size): for i in range(batch_size):
code_table = CodeTableWithCustomTree(ptable, pcode, i) code_table = CodeTableWithCustomTree(path_table, path_code, i)
length = code_table.get_length() length = code_table.get_length()
for j in range(length): for j in range(length):
idx = code_table.cal_index(j) idx = code_table.cal_index(j)
pre_output[i][j] += bias[idx][0] pre_output[i][j] += bias[idx][0]
for i in range(batch_size): for i in range(batch_size):
code_table = CodeTableWithCustomTree(ptable, pcode, i) code_table = CodeTableWithCustomTree(path_table, path_code, i)
length = code_table.get_length() length = code_table.get_length()
for j in range(length): for j in range(length):
idx = code_table.cal_index(j) idx = code_table.cal_index(j)
...@@ -127,7 +128,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes): ...@@ -127,7 +128,7 @@ def hsigmoidWithCustomTree(x, w, ptable, pcode, label, bias, num_classes):
pre_output = np.clip(pre_output, -40.0, 40.0) pre_output = np.clip(pre_output, -40.0, 40.0)
# out(i, 0) = \sum_j bit(i, j) * preout(i, j) # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for i in range(batch_size): for i in range(batch_size):
code_table = CodeTableWithCustomTree(ptable, pcode, i) code_table = CodeTableWithCustomTree(path_table, path_code, i)
length = code_table.get_length() length = code_table.get_length()
sum = 0.0 sum = 0.0
for j in range(length): for j in range(length):
...@@ -173,24 +174,24 @@ class TestHSigmoidOpSparse(OpTest): ...@@ -173,24 +174,24 @@ class TestHSigmoidOpSparse(OpTest):
x = np.random.random((batch_size, feature_size)).astype("float32") x = np.random.random((batch_size, feature_size)).astype("float32")
w = np.random.random((num_classes - 1, feature_size)).astype("float32") w = np.random.random((num_classes - 1, feature_size)).astype("float32")
label = np.array([0, 1, 4, 5]) label = np.array([0, 1, 4, 5])
ptable = np.array( path_table = np.array(
[(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
(0, 2, -1, -1, (0, 2, -1, -1,
-1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store
bias = np.random.random((num_classes - 1, 1)).astype("float32") bias = np.random.random((num_classes - 1, 1)).astype("float32")
self.attrs = {'num_classes': num_classes, 'is_sparse': True} self.attrs = {'num_classes': num_classes, 'is_sparse': True}
self.inputs = { self.inputs = {
'X': x, 'X': x,
'W': w, 'W': w,
'PTable': ptable, 'PTable': path_table,
'PathCode': pcode, 'PathCode': path_code,
'Label': label, 'Label': label,
'Bias': bias 'Bias': bias
} }
pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code,
bias, num_classes) label, bias, num_classes)
self.outputs = {'PreOut': pre_output, 'Out': out} self.outputs = {'PreOut': pre_output, 'Out': out}
def test_check_output(self): def test_check_output(self):
...@@ -200,11 +201,13 @@ class TestHSigmoidOpSparse(OpTest): ...@@ -200,11 +201,13 @@ class TestHSigmoidOpSparse(OpTest):
class TestHSigmoidOpWithSparseGrad(unittest.TestCase): class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
def hs_net_conf(self, is_sparse): def hs_net_conf(self, is_sparse):
input_word = fluid.layers.data(name="x", shape=[1], dtype='int64') input_word = fluid.layers.data(name="x", shape=[1], dtype='int64')
ptable = fluid.layers.data(name='ptable', shape=[3], dtype='int64') path_table = fluid.layers.data(
pcode = fluid.layers.data(name='pcode', shape=[3], dtype='int64') name='path_table', shape=[3], dtype='int64')
path_code = fluid.layers.data(
name='path_code', shape=[3], dtype='int64')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
data_list = [input_word, ptable, pcode, label] data_list = [input_word, path_table, path_code, label]
emb = fluid.layers.embedding( emb = fluid.layers.embedding(
input=input_word, input=input_word,
...@@ -218,9 +221,9 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): ...@@ -218,9 +221,9 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
label=label, label=label,
bias_attr=True, bias_attr=True,
non_leaf_num=3, non_leaf_num=3,
ptable=ptable, path_table=path_table,
pcode=pcode, path_code=path_code,
is_costum=True, is_custom=True,
is_sparse=is_sparse) is_sparse=is_sparse)
avg_cost = fluid.layers.reduce_mean(cost) avg_cost = fluid.layers.reduce_mean(cost)
...@@ -232,8 +235,8 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): ...@@ -232,8 +235,8 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
start_up = fluid.default_startup_program() start_up = fluid.default_startup_program()
start_up.random_seed = 1 # Fix random seed start_up.random_seed = 1 # Fix random seed
x = np.arange(6).reshape(6) x = np.arange(6).reshape(6)
ptable = np.array([(1, 2, -1), (1, 2, -1)]) path_table = np.array([(1, 2, -1), (1, 2, -1)])
pcode = np.array([(1, 0, -1), (0, 0, -1)]) path_code = np.array([(1, 0, -1), (0, 0, -1)])
label = np.array([1, 4]) label = np.array([1, 4])
loss, data_list = self.hs_net_conf(is_sparse) loss, data_list = self.hs_net_conf(is_sparse)
...@@ -248,8 +251,8 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase): ...@@ -248,8 +251,8 @@ class TestHSigmoidOpWithSparseGrad(unittest.TestCase):
exe.run(start_up) exe.run(start_up)
result = list() result = list()
for i in range(10): for i in range(10):
data = [([[x[i % 2]]], [list(ptable[i % 2])], data = [([[x[i % 2]]], [list(path_table[i % 2])],
[list(pcode[i % 2])], [label[i % 2]])] [list(path_code[i % 2])], [label[i % 2]])]
loss_val = exe.run(main_program, loss_val = exe.run(main_program,
feed=feeder.feed(data), feed=feeder.feed(data),
...@@ -273,24 +276,24 @@ class TestHSigmoidOpWithCostumTree(OpTest): ...@@ -273,24 +276,24 @@ class TestHSigmoidOpWithCostumTree(OpTest):
w = np.random.random( w = np.random.random(
(num_classes - 1, feature_size)).astype("float32") * 2 (num_classes - 1, feature_size)).astype("float32") * 2
label = np.array([0, 1, 4, 5]) label = np.array([0, 1, 4, 5])
ptable = np.array( path_table = np.array(
[(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
(0, 2, -1, -1, (0, 2, -1, -1,
-1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store
bias = np.random.random((num_classes - 1, 1)).astype("float32") bias = np.random.random((num_classes - 1, 1)).astype("float32")
self.attrs = {'num_classes': num_classes, 'is_sparse': False} self.attrs = {'num_classes': num_classes, 'is_sparse': False}
self.inputs = { self.inputs = {
'X': x, 'X': x,
'W': w, 'W': w,
'PTable': ptable, 'PTable': path_table,
'PathCode': pcode, 'PathCode': path_code,
'Label': label, 'Label': label,
'Bias': bias 'Bias': bias
} }
pre_output, out = hsigmoidWithCustomTree(x, w, ptable, pcode, label, pre_output, out = hsigmoidWithCustomTree(x, w, path_table, path_code,
bias, num_classes) label, bias, num_classes)
self.outputs = {'PreOut': pre_output, 'Out': out} self.outputs = {'PreOut': pre_output, 'Out': out}
def test_check_output(self): def test_check_output(self):
...@@ -310,26 +313,26 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest): ...@@ -310,26 +313,26 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest):
w = np.random.random( w = np.random.random(
(num_classes - 1, feature_size)).astype("float32") * 2 (num_classes - 1, feature_size)).astype("float32") * 2
label = np.array([0, 1, 4, 5]) label = np.array([0, 1, 4, 5])
ptable = np.array( path_table = np.array(
[(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1), [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
(0, 2, -1, -1, (0, 2, -1, -1,
-1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf) -1)]) #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
pcode = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), ( path_code = np.array([(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (
1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store 1, 0, 0, -1, -1), (0, 1, -1, -1, -1)]) #np.array to store
# bias = np.random.random((num_classes - 1, 1)).astype("float32") # bias = np.random.random((num_classes - 1, 1)).astype("float32")
self.attrs = {'num_classes': num_classes, 'is_sparse': False} self.attrs = {'num_classes': num_classes, 'is_sparse': False}
self.inputs = { self.inputs = {
'X': x, 'X': x,
'W': w, 'W': w,
'PTable': ptable, 'PTable': path_table,
'PathCode': pcode, 'PathCode': path_code,
'Label': label, 'Label': label,
} }
pre_output, out = hsigmoidWithCustomTree( pre_output, out = hsigmoidWithCustomTree(
x=x, x=x,
w=w, w=w,
ptable=ptable, path_table=path_table,
pcode=pcode, path_code=path_code,
label=label, label=label,
bias=None, bias=None,
num_classes=num_classes) num_classes=num_classes)
......
...@@ -190,16 +190,18 @@ class TestBook(unittest.TestCase): ...@@ -190,16 +190,18 @@ class TestBook(unittest.TestCase):
with program_guard(program2): with program_guard(program2):
x2 = layers.data(name='x2', shape=[4, 8], dtype='float32') x2 = layers.data(name='x2', shape=[4, 8], dtype='float32')
y2 = layers.data(name='y2', shape=[4], dtype='int64') y2 = layers.data(name='y2', shape=[4], dtype='int64')
ptable = layers.data(name='ptable', shape=[4, 6], dtype='int64') path_table = layers.data(
pcode = layers.data(name='pcode', shape=[4, 6], dtype='int64') name='path_table', shape=[4, 6], dtype='int64')
path_code = layers.data(
name='path_code', shape=[4, 6], dtype='int64')
self.assertIsNotNone( self.assertIsNotNone(
layers.hsigmoid( layers.hsigmoid(
input=x2, input=x2,
label=y2, label=y2,
non_leaf_num=6, non_leaf_num=6,
ptable=ptable, path_table=path_table,
pcode=pcode, path_code=path_code,
is_costum=True)) is_custom=True))
print(str(program2)) print(str(program2))
def test_sequence_expand(self): def test_sequence_expand(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册