diff --git a/paddle/operators/beam_search_decode_op.cc b/paddle/operators/beam_search_decode_op.cc index 3904a97d58166cfeeb2be7d2144700dbd8bc5721..c796a0c5d089499e7858c7a427825fdbeb05cb7f 100644 --- a/paddle/operators/beam_search_decode_op.cc +++ b/paddle/operators/beam_search_decode_op.cc @@ -17,6 +17,36 @@ limitations under the License. */ namespace paddle { namespace operators { +struct BeamSearchDecodeFunctor { + BeamSearchDecodeFunctor(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, + LoDTensor* id_tensor, LoDTensor* score_tensor) + : step_ids_(step_ids), + step_scores_(step_scores), + id_tensor_(id_tensor), + score_tensor_(score_tensor) {} + + template + void operator()() const; + + const LoDTensorArray& step_ids_; + const LoDTensorArray& step_scores_; + LoDTensor* id_tensor_; + LoDTensor* score_tensor_; +}; + +template +void BeamSearchDecodeFunctor::operator()() const { + BeamSearchDecoder beam_search_decoder; + beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_, + score_tensor_); +} + +template <> +void BeamSearchDecodeFunctor::operator()() const { + PADDLE_THROW("beam search decode op does not support bool!"); +} + class BeamSearchDecodeOp : public framework::OperatorBase { public: BeamSearchDecodeOp(const std::string& type, @@ -45,9 +75,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase { LoDTensor* sentenceIds = ctx.Output("SentenceIds"); LoDTensor* sentenceScores = ctx.Output("SentenceScores"); - BeamSearchDecoder beam_search_decoder; - beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds, - sentenceScores); + framework::VisitDataType( + framework::ToDataType(scores->at(0).type()), + BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores)); } }; diff --git a/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py b/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py index 8a11820d2aba2dd4d17d925f0e0fe9f324100418..5fad7d8cce5af3677aa77dc0abb64f1ecd380419 100644 --- a/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py +++ b/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py @@ -35,15 +35,15 @@ class TestBeamSearchDecodeOp(unittest.TestCase): self.append_lod_tensor( scores, [[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]], np.array( - [1, 2, 3, 4, 5, 6], dtype="float32")) + [1, 2, 3, 4, 5, 6], dtype="float64")) self.append_lod_tensor( scores, [[0, 3, 6], [0, 1, 1, 3, 5, 5, 6]], np.array( - [0, 1, 2, 3, 4, 5], dtype="float32")) + [0, 1, 2, 3, 4, 5], dtype="float64")) self.append_lod_tensor( scores, [[0, 3, 6], [0, 0, 1, 2, 3, 4, 5]], np.array( - [0, 1, 2, 3, 4], dtype="float32")) + [0, 1, 2, 3, 4], dtype="float64")) sentence_ids = self.scope.var("sentence_ids").get_tensor() sentence_scores = self.scope.var("sentence_scores").get_tensor()