From 65c859db7aadfdaccb1a04afe788d66d0e4a8694 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 24 Nov 2017 13:32:47 +0800 Subject: [PATCH] beam_search_decode support multi data type (#5847) * beam_search_decode support multi data type * add VisitDataType for beam search decode * use Specialization to handle bool * move Specialization of BeamSearchDecodeFunctor out of class --- paddle/operators/beam_search_decode_op.cc | 36 +++++++++++++++++-- .../fluid/tests/test_beam_search_decode_op.py | 6 ++-- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/paddle/operators/beam_search_decode_op.cc b/paddle/operators/beam_search_decode_op.cc index 3904a97d581..c796a0c5d08 100644 --- a/paddle/operators/beam_search_decode_op.cc +++ b/paddle/operators/beam_search_decode_op.cc @@ -17,6 +17,36 @@ limitations under the License. */ namespace paddle { namespace operators { +struct BeamSearchDecodeFunctor { + BeamSearchDecodeFunctor(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, + LoDTensor* id_tensor, LoDTensor* score_tensor) + : step_ids_(step_ids), + step_scores_(step_scores), + id_tensor_(id_tensor), + score_tensor_(score_tensor) {} + + template + void operator()() const; + + const LoDTensorArray& step_ids_; + const LoDTensorArray& step_scores_; + LoDTensor* id_tensor_; + LoDTensor* score_tensor_; +}; + +template +void BeamSearchDecodeFunctor::operator()() const { + BeamSearchDecoder beam_search_decoder; + beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_, + score_tensor_); +} + +template <> +void BeamSearchDecodeFunctor::operator()() const { + PADDLE_THROW("beam search decode op does not support bool!"); +} + class BeamSearchDecodeOp : public framework::OperatorBase { public: BeamSearchDecodeOp(const std::string& type, @@ -45,9 +75,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase { LoDTensor* sentenceIds = ctx.Output("SentenceIds"); LoDTensor* sentenceScores = ctx.Output("SentenceScores"); - BeamSearchDecoder beam_search_decoder; - beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds, - sentenceScores); + framework::VisitDataType( + framework::ToDataType(scores->at(0).type()), + BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores)); } }; diff --git a/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py b/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py index 8a11820d2ab..5fad7d8cce5 100644 --- a/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py +++ b/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py @@ -35,15 +35,15 @@ class TestBeamSearchDecodeOp(unittest.TestCase): self.append_lod_tensor( scores, [[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]], np.array( - [1, 2, 3, 4, 5, 6], dtype="float32")) + [1, 2, 3, 4, 5, 6], dtype="float64")) self.append_lod_tensor( scores, [[0, 3, 6], [0, 1, 1, 3, 5, 5, 6]], np.array( - [0, 1, 2, 3, 4, 5], dtype="float32")) + [0, 1, 2, 3, 4, 5], dtype="float64")) self.append_lod_tensor( scores, [[0, 3, 6], [0, 0, 1, 2, 3, 4, 5]], np.array( - [0, 1, 2, 3, 4], dtype="float32")) + [0, 1, 2, 3, 4], dtype="float64")) sentence_ids = self.scope.var("sentence_ids").get_tensor() sentence_scores = self.scope.var("sentence_scores").get_tensor() -- GitLab