提交 d1f3d646 编写于 作者: C chengmo

fix infer net word

上级 5d0e4a9c
......@@ -104,6 +104,11 @@ class TdmInferNet(object):
)
for layer_idx in range(self.first_layer_idx, self.max_layers):
current_layer_node = fluid.layers.reshape(
current_layer_node, [self.batch_size, -1])
current_layer_child_mask = fluid.layer.reshape(
current_layer_child_mask, [self.batch_size, -1])
current_layer_node_num = current_layer_node.shape[1]
node_emb = fluid.embedding(
......@@ -142,10 +147,10 @@ class TdmInferNet(object):
if layer_idx < self.max_layers - 1:
current_layer_node, current_layer_child_mask = \
fluid.contribs.layers.tdm_child(x=top_node,
node_nums=self.node_nums,
child_nums=self.child_nums,
param_attr=fluid.ParamAttr(name="TDM_Tree_Info"), dtype='int64')
fluid.contrib.layers.tdm_child(x=top_node,
node_nums=self.node_nums,
child_nums=self.child_nums,
param_attr=fluid.ParamAttr(name="TDM_Tree_Info"), dtype='int64')
total_node_score = fluid.layers.concat(node_score, axis=1)
total_node = fluid.layers.concat(node_list, axis=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册