提交 8a2219cb 编写于 作者: C chengmo

for merge

上级 d1f3d646
......@@ -104,12 +104,16 @@ class TdmInferNet(object):
)
for layer_idx in range(self.first_layer_idx, self.max_layers):
if layer_idx == 0:
current_layer_node_num = len(self.first_layer_node)
else:
current_layer_node_num = current_layer_node.shape[1] * \
current_layer_node.shape[2]
current_layer_node = fluid.layers.reshape(
current_layer_node, [self.batch_size, -1])
current_layer_child_mask = fluid.layer.reshape(
current_layer_child_mask, [self.batch_size, -1])
current_layer_node_num = current_layer_node.shape[1]
current_layer_node, [self.batch_size, current_layer_node_num])
current_layer_child_mask = fluid.layers.reshape(
current_layer_child_mask, [self.batch_size, current_layer_node_num])
node_emb = fluid.embedding(
input=current_layer_node,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册