/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/operators/beam_search_decode_op.h"

#include "paddle/fluid/platform/device_context.h"

// Forward declarations only — avoids pulling in the heavy framework headers
// for types that are used solely by pointer/reference or as template args.
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

Q
Qiao Longfei 已提交
33 34 35
namespace paddle {
namespace operators {

36 37 38
struct BeamSearchDecodeFunctor {
  BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
                          const LoDTensorArray& step_scores,
39 40 41 42 43
                          LoDTensor* id_tensor, LoDTensor* score_tensor,
                          size_t beam_size, int end_id)
      : beam_size_(beam_size),
        end_id_(end_id),
        step_ids_origin_(step_ids),
44
        step_scores_origin_(step_scores),
45
        id_tensor_(id_tensor),
46 47 48 49 50 51 52 53 54 55 56
        score_tensor_(score_tensor) {
    tensor_on_gpu_ = false;
    // First make a copy of GPU data on CPU
    if (platform::is_gpu_place(step_ids_origin_[0].place())) {
      tensor_on_gpu_ = true;
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(step_ids_origin_[0].place());
      // Copy all tensors in the input tensor array
      for (auto& step_id : step_ids_origin_) {
        framework::LoDTensor out;
57 58 59 60 61
        if (step_id.numel() > 0) {
          dev_ctx->Wait();
          framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
          dev_ctx->Wait();
        }
62 63 64 65 66 67 68 69 70 71 72 73 74

        out.set_lod(step_id.lod());
        step_ids_.push_back(out);
      }
    }
    if (platform::is_gpu_place(step_scores_origin_[0].place())) {
      tensor_on_gpu_ = true;
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(step_scores_origin_[0].place());
      // Copy all tensors in the input tensor array
      for (auto& step_score : step_scores_origin_) {
        framework::LoDTensor out;
75 76 77 78 79 80
        if (step_score.numel() > 0) {
          dev_ctx->Wait();
          framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx,
                                &out);
          dev_ctx->Wait();
        }
81 82 83 84 85 86

        out.set_lod(step_score.lod());
        step_scores_.push_back(out);
      }
    }
  }
87 88

  template <typename T>
D
dzhwinter 已提交
89
  void apply() const;
90

91
  bool tensor_on_gpu_;
92 93
  size_t beam_size_;
  int end_id_;
Y
Yan Chunwei 已提交
94 95 96
  // TODO(Superjomn) Here might result serious performance issue in the
  // concurrency
  // scenarios.
97 98 99 100
  const LoDTensorArray& step_ids_origin_;
  const LoDTensorArray& step_scores_origin_;
  LoDTensorArray step_ids_ = LoDTensorArray();
  LoDTensorArray step_scores_ = LoDTensorArray();
101 102 103 104 105
  LoDTensor* id_tensor_;
  LoDTensor* score_tensor_;
};

template <typename T>
D
dzhwinter 已提交
106
void BeamSearchDecodeFunctor::apply() const {
107
  BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
108 109
  // Check if the tensor is on GPU. If so, use the CPU copy instead
  if (tensor_on_gpu_) {
110 111
    beam_search_decoder.Backtrace(step_ids_, step_scores_, id_tensor_,
                                  score_tensor_);
112
  } else {
113 114
    beam_search_decoder.Backtrace(step_ids_origin_, step_scores_origin_,
                                  id_tensor_, score_tensor_);
115
  }
116 117 118
}

template <>
D
dzhwinter 已提交
119
void BeamSearchDecodeFunctor::apply<bool>() const {
120 121 122
  PADDLE_THROW("beam search decode op does not support bool!");
}

Q
Qiao Longfei 已提交
123 124 125 126 127 128 129
class BeamSearchDecodeOp : public framework::OperatorBase {
 public:
  BeamSearchDecodeOp(const std::string& type,
                     const framework::VariableNameMap& inputs,
                     const framework::VariableNameMap& outputs,
                     const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
130 131 132 133

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
Y
Yu Yang 已提交
134 135
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto& dev_ctx = *pool.Get(dev_place);
D
dzhwinter 已提交
136

X
Xin Pan 已提交
137
    framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
138
    framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
139

Q
Qiao Longfei 已提交
140 141 142
    const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
    const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
    const size_t step_num = ids->size();
143 144 145 146 147 148 149
    PADDLE_ENFORCE_GT(
        step_num, 0UL,
        platform::errors::InvalidArgument(
            "beam search steps, which is the"
            "size of Input(Ids) LoDTensorArray. beam search steps should "
            "be larger than 0, but received %d. ",
            step_num));
Q
Qiao Longfei 已提交
150
    const size_t source_num = ids->at(0).lod().at(0).size() - 1;
151 152 153 154 155 156 157 158
    PADDLE_ENFORCE_GT(
        source_num, 0UL,
        platform::errors::InvalidArgument(
            "source_num is the sequence number of the"
            "first decoding step, indicating by Input(Ids)[0].lod[0].size. "
            "The number of source_num should be larger than"
            "0, but received %d. ",
            source_num));
Q
Qiao Longfei 已提交
159 160

    for (size_t i = 0; i < step_num; ++i) {
161 162 163 164 165 166 167
      PADDLE_ENFORCE_EQ(
          ids->at(i).lod().size(), 2UL,
          platform::errors::InvalidArgument(
              "For the i step in beam search steps,"
              "the size of Input(Ids)[i].lod() should larger than 2,"
              "but received %d. ",
              ids->at(i).lod().size()));
Q
Qiao Longfei 已提交
168 169
    }

170 171 172
    size_t beam_size = ctx.Attr<int>("beam_size");
    int end_id = ctx.Attr<int>("end_id");

Q
Qiao Longfei 已提交
173 174 175 176
    // prepare output
    LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
    LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");

177
    framework::VisitDataType(
Y
Yu Yang 已提交
178
        scores->at(0).type(),
179 180
        BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
                                beam_size, end_id));
Q
Qiao Longfei 已提交
181 182 183 184 185
  }
};

// Declares the inputs, outputs, attributes and documentation of the
// beam_search_decode operator.
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "Ids",
        "(LodTensorArray)The LodTensorArray containing the selected ids of "
        "all steps");
    AddInput(
        "Scores",
        "(LodTensorArray)The LodTensorArray containing the selected scores "
        "of all steps");
    AddOutput(
        "SentenceIds",
        "(LodTensor)An LodTensor containing all generated id sequences for "
        "all source sentences");
    AddOutput(
        "SentenceScores",
        "(LodTensor)An LodTensor containing scores corresponding to "
        "Output(SentenceIds)");
    AddAttr<int>("beam_size", "beam size for beam search");
    AddAttr<int>("end_id",
                 "the token id which indicates the end of a sequence");
    AddComment(R"DOC(
Beam Search Decode Operator. This Operator constructs the full hypotheses for
each source sentence by walking back along the LoDTensorArray Input(ids)
whose lods can be used to restore the path in the beam search tree.

The Output(SentenceIds) and Output(SentenceScores) separately contain the
generated id sequences and the corresponding scores. The shapes and lods of the
two LodTensor are same. The lod level is 2 and the two levels separately
indicate how many hypotheses each source sentence has and how many ids each
hypothesis has.
)DOC");
  }
};

// Compile-time check that all required inputs and outputs are wired up;
// the actual output shape is only known at run time.
class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* context) const override {
    OP_INOUT_CHECK(context->HasInput("Ids"), "Input", "Ids",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasInput("Scores"), "Input", "Scores",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasOutput("SentenceIds"), "Output", "SentenceIds",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasOutput("SentenceScores"), "Output",
                   "SentenceScores", "BeamSearchDecode");
  }
};

// Both outputs are dense LoDTensors regardless of the input variable types.
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    for (const char* out_name : {"SentenceIds", "SentenceScores"}) {
      ctx->SetOutputType(out_name, framework::proto::VarType::LOD_TENSOR,
                         framework::ALL_ELEMENTS);
    }
  }
};

}  // namespace operators
}  // namespace paddle

H
hong 已提交
246 247 248 249 250 251 252
// Register the operator with its proto maker and shape/var-type inference.
// EmptyGradOpMaker for both static (OpDesc) and imperative (OpBase) modes:
// beam search decode is a post-processing step with no gradient.
REGISTER_OPERATOR(
    beam_search_decode, paddle::operators::BeamSearchDecodeOp,
    paddle::operators::BeamSearchDecodeOpProtoMaker,
    paddle::operators::BeamSearchDecodeInferShape,
    paddle::operators::BeamSearchDecodeInferVarType,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);