diff --git a/demo/semantic_role_labeling/dataprovider.py b/demo/semantic_role_labeling/dataprovider.py index d4c137ef42c4e2ec609f3e6f809363e602dfd8dd..2c8e13462730a2e980fa1c3fe342ef0e062ab5d7 100644 --- a/demo/semantic_role_labeling/dataprovider.py +++ b/demo/semantic_role_labeling/dataprovider.py @@ -25,12 +25,13 @@ def hook(settings, word_dict, label_dict, predicate_dict, **kwargs): #all inputs are integral and sequential type settings.slots = [ integer_value_sequence(len(word_dict)), - integer_value_sequence(len(predicate_dict)), integer_value_sequence(len(word_dict)), integer_value_sequence(len(word_dict)), integer_value_sequence(len(word_dict)), integer_value_sequence(len(word_dict)), - integer_value_sequence(len(word_dict)), integer_value_sequence(2), + integer_value_sequence(len(word_dict)), + integer_value_sequence(len(predicate_dict)), + integer_value_sequence(2), integer_value_sequence(len(label_dict)) ] @@ -63,5 +64,5 @@ def process(settings, file_name): label_list = label.split() label_slot = [settings.label_dict.get(w) for w in label_list] - yield word_slot, predicate_slot, ctx_n2_slot, ctx_n1_slot, \ - ctx_0_slot, ctx_p1_slot, ctx_p2_slot, mark_slot, label_slot + yield word_slot, ctx_n2_slot, ctx_n1_slot, \ + ctx_0_slot, ctx_p1_slot, ctx_p2_slot, predicate_slot, mark_slot, label_slot diff --git a/demo/semantic_role_labeling/predict.py b/demo/semantic_role_labeling/predict.py index 2761814e1811e701122e0be4850526c5b290c457..a7f1e8f81f59f6fe95fd29593ef1a826e652e570 100644 --- a/demo/semantic_role_labeling/predict.py +++ b/demo/semantic_role_labeling/predict.py @@ -55,18 +55,14 @@ class Prediction(): slots = [ integer_value_sequence(len_dict), - integer_value_sequence(len_pred), integer_value_sequence(len_dict), integer_value_sequence(len_dict), integer_value_sequence(len_dict), integer_value_sequence(len_dict), integer_value_sequence(len_dict), + integer_value_sequence(len_pred), integer_value_sequence(2) ] - integer_value_sequence(len_dict), integer_value_sequence(len_dict), - integer_value_sequence(len_dict), integer_value_sequence(len_dict), - integer_value_sequence(len_dict), integer_value_sequence(2) - ] self.converter = DataProviderConverter(slots) def load_dict_label(self, dict_file, label_file, predicate_dict_file): @@ -104,8 +100,8 @@ class Prediction(): marks = mark.split() mark_slot = [int(w) for w in marks] - yield word_slot, predicate_slot, ctx_n2_slot, ctx_n1_slot, \ - ctx_0_slot, ctx_p1_slot, ctx_p2_slot, mark_slot + yield word_slot, ctx_n2_slot, ctx_n1_slot, \ + ctx_0_slot, ctx_p1_slot, ctx_p2_slot, predicate_slot, mark_slot def predict(self, data_file, output_file): """ diff --git a/demo/semantic_role_labeling/predict.sh b/demo/semantic_role_labeling/predict.sh index d0acdb0bd093974485475cf796c6d41ac7899135..88ab5898f7d41056f4fe549b3145760783b27bf9 100644 --- a/demo/semantic_role_labeling/predict.sh +++ b/demo/semantic_role_labeling/predict.sh @@ -18,7 +18,7 @@ set -e function get_best_pass() { cat $1 | grep -Pzo 'Test .*\n.*pass-.*' | \ sed -r 'N;s/Test.* cost=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g' | \ - sort | head -n 1 + sort -n | head -n 1 } log=train.log diff --git a/demo/semantic_role_labeling/test.sh b/demo/semantic_role_labeling/test.sh index c4ab44f5ca08aefd18f2851a1410aa08563925a9..f9e1bdcd4c752474329d36c4de3378f7d58e7b4b 100644 --- a/demo/semantic_role_labeling/test.sh +++ b/demo/semantic_role_labeling/test.sh @@ -18,7 +18,7 @@ set -e function get_best_pass() { cat $1 | grep -Pzo 'Test .*\n.*pass-.*' | \ sed -r 'N;s/Test.* cost=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g' |\ - sort | head -n 1 + sort -n | head -n 1 } log=train.log