提交 b97f020f 编写于 作者: C caoying03

fix unittest error.

上级 7ff689f5
......@@ -70,9 +70,8 @@ void SequenceSliceLayer::checkInputs() {
const Argument& inputSeq = getInput(0);
CHECK(inputSeq.hasSeq()) << "The first input of sequence slic layer "
<< "must be a sequence.";
// Check inputs
const MatrixPtr indices1 = getInputValue(1);
CHECK_EQ(indices1->getHeight(),
CHECK_EQ(static_cast<size_t>(indices1->getHeight()),
inputSeq.hasSubseq() ? inputSeq.getNumSubSequences()
: inputSeq.getNumSequences())
<< "Height of the second input should be equal to number of sequence "
......
......@@ -6242,6 +6242,7 @@ def seq_slice_layer(input, starts, ends, name=None):
name, LayerType.SEQ_SLICE, parents=[input], size=input.size)
@wrap_name_default()
@layer_support()
def kmax_sequence_score_layer(input, name=None, beam_size=1):
"""
......
type: "nn"
layers {
name: "input"
type: "data"
size: 300
active_type: ""
}
layers {
name: "data"
name: "input_seq"
type: "data"
size: 128
active_type: ""
......@@ -17,7 +11,7 @@ layers {
size: 1
active_type: "exponential"
inputs {
input_layer_name: "data"
input_layer_name: "input_seq"
input_parameter_name: "___fc_layer_0__.w0"
}
bias_parameter_name: "___fc_layer_0__.wbias"
......@@ -51,15 +45,14 @@ parameters {
initial_strategy: 0
initial_smart: false
}
input_layer_names: "data"
input_layer_names: "input_seq"
output_layer_names: "__kmax_sequence_score_layer_0__"
sub_models {
name: "root"
layer_names: "input"
layer_names: "data"
layer_names: "input_seq"
layer_names: "__fc_layer_0__"
layer_names: "__kmax_sequence_score_layer_0__"
input_layer_names: "data"
input_layer_names: "input_seq"
output_layer_names: "__kmax_sequence_score_layer_0__"
is_recurrent_layer_group: false
}
......
......@@ -2,9 +2,7 @@
#coding=utf-8
from paddle.trainer_config_helpers import *
data = data_layer(name='input', size=300)
data = data_layer(name="data", size=128)
data = data_layer(name="input_seq", size=128)
scores = fc_layer(input=data, size=1, act=ExpActivation())
kmax_seq_id = kmax_sequence_score_layer(input=scores, beam_size=5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册