/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/operators/beam_search_decode_op.h"
#include "paddle/fluid/platform/device_context.h"

// Forward declarations to avoid including heavy framework headers here.
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

namespace paddle {
namespace operators {

36 37 38
struct BeamSearchDecodeFunctor {
  BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
                          const LoDTensorArray& step_scores,
39 40 41 42 43
                          LoDTensor* id_tensor, LoDTensor* score_tensor,
                          size_t beam_size, int end_id)
      : beam_size_(beam_size),
        end_id_(end_id),
        step_ids_origin_(step_ids),
44
        step_scores_origin_(step_scores),
45
        id_tensor_(id_tensor),
46 47
        score_tensor_(score_tensor) {
    tensor_on_gpu_ = false;
48
    tensor_on_npu_ = false;
49
    // First make a copy of GPU data on CPU
50 51 52 53 54 55 56
    if (platform::is_gpu_place(step_ids_origin_[0].place()) ||
        platform::is_npu_place(step_ids_origin_[0].place())) {
      if (platform::is_gpu_place(step_ids_origin_[0].place())) {
        tensor_on_gpu_ = true;
      } else {
        tensor_on_npu_ = true;
      }
57 58 59 60 61 62
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(step_ids_origin_[0].place());
      // Copy all tensors in the input tensor array
      for (auto& step_id : step_ids_origin_) {
        framework::LoDTensor out;
63
        if (step_id.numel() > 0) {
64 65 66
          if (tensor_on_gpu_) {
            dev_ctx->Wait();
          }
67 68 69
          framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
          dev_ctx->Wait();
        }
70 71 72 73 74

        out.set_lod(step_id.lod());
        step_ids_.push_back(out);
      }
    }
75 76 77 78 79 80 81
    if (platform::is_gpu_place(step_scores_origin_[0].place()) ||
        platform::is_npu_place(step_scores_origin_[0].place())) {
      if (platform::is_gpu_place(step_scores_origin_[0].place())) {
        tensor_on_gpu_ = true;
      } else {
        tensor_on_npu_ = true;
      }
82 83 84 85 86 87
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(step_scores_origin_[0].place());
      // Copy all tensors in the input tensor array
      for (auto& step_score : step_scores_origin_) {
        framework::LoDTensor out;
88
        if (step_score.numel() > 0) {
89 90 91
          if (tensor_on_gpu_) {
            dev_ctx->Wait();
          }
92 93 94 95
          framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx,
                                &out);
          dev_ctx->Wait();
        }
96 97 98 99 100 101

        out.set_lod(step_score.lod());
        step_scores_.push_back(out);
      }
    }
  }
102 103

  template <typename T>
D
dzhwinter 已提交
104
  void apply() const;
105

106
  bool tensor_on_gpu_;
107
  bool tensor_on_npu_;
108 109
  size_t beam_size_;
  int end_id_;
Y
Yan Chunwei 已提交
110 111 112
  // TODO(Superjomn) Here might result serious performance issue in the
  // concurrency
  // scenarios.
113 114 115 116
  const LoDTensorArray& step_ids_origin_;
  const LoDTensorArray& step_scores_origin_;
  LoDTensorArray step_ids_ = LoDTensorArray();
  LoDTensorArray step_scores_ = LoDTensorArray();
117 118 119 120 121
  LoDTensor* id_tensor_;
  LoDTensor* score_tensor_;
};

template <typename T>
D
dzhwinter 已提交
122
void BeamSearchDecodeFunctor::apply() const {
123
  BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
124 125
  // Check if the tensor is on GPU or NPU. If so, use the CPU copy instead
  if (tensor_on_gpu_ || tensor_on_npu_) {
126 127
    beam_search_decoder.Backtrace(step_ids_, step_scores_, id_tensor_,
                                  score_tensor_);
128
  } else {
129 130
    beam_search_decoder.Backtrace(step_ids_origin_, step_scores_origin_,
                                  id_tensor_, score_tensor_);
131
  }
132 133 134
}

template <>
D
dzhwinter 已提交
135
void BeamSearchDecodeFunctor::apply<bool>() const {
136 137
  PADDLE_THROW(platform::errors::InvalidArgument(
      "beam search decode op does not support bool!"));
138 139
}

Q
Qiao Longfei 已提交
140 141 142 143 144 145 146
class BeamSearchDecodeOp : public framework::OperatorBase {
 public:
  BeamSearchDecodeOp(const std::string& type,
                     const framework::VariableNameMap& inputs,
                     const framework::VariableNameMap& outputs,
                     const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
147 148 149 150

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
Y
Yu Yang 已提交
151 152
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto& dev_ctx = *pool.Get(dev_place);
D
dzhwinter 已提交
153

X
Xin Pan 已提交
154
    framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
155
    framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
156

Q
Qiao Longfei 已提交
157 158 159
    const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
    const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
    const size_t step_num = ids->size();
160 161 162 163 164 165 166
    PADDLE_ENFORCE_GT(
        step_num, 0UL,
        platform::errors::InvalidArgument(
            "beam search steps, which is the"
            "size of Input(Ids) LoDTensorArray. beam search steps should "
            "be larger than 0, but received %d. ",
            step_num));
Q
Qiao Longfei 已提交
167
    const size_t source_num = ids->at(0).lod().at(0).size() - 1;
168 169 170 171 172 173 174 175
    PADDLE_ENFORCE_GT(
        source_num, 0UL,
        platform::errors::InvalidArgument(
            "source_num is the sequence number of the"
            "first decoding step, indicating by Input(Ids)[0].lod[0].size. "
            "The number of source_num should be larger than"
            "0, but received %d. ",
            source_num));
Q
Qiao Longfei 已提交
176 177

    for (size_t i = 0; i < step_num; ++i) {
178 179 180 181 182 183 184
      PADDLE_ENFORCE_EQ(
          ids->at(i).lod().size(), 2UL,
          platform::errors::InvalidArgument(
              "For the i step in beam search steps,"
              "the size of Input(Ids)[i].lod() should larger than 2,"
              "but received %d. ",
              ids->at(i).lod().size()));
Q
Qiao Longfei 已提交
185 186
    }

187 188 189
    size_t beam_size = ctx.Attr<int>("beam_size");
    int end_id = ctx.Attr<int>("end_id");

Q
Qiao Longfei 已提交
190 191 192 193
    // prepare output
    LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
    LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");

194
    framework::VisitDataType(
Y
Yu Yang 已提交
195
        scores->at(0).type(),
196 197
        BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
                                beam_size, end_id));
Q
Qiao Longfei 已提交
198 199 200 201 202
  }
};

// Declares the inputs, outputs, attributes and documentation of the op.
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Ids",
             "(LodTensorArray)"
             "The LodTensorArray containing the selected ids of all steps");
    AddInput("Scores",
             "(LodTensorArray)"
             "The LodTensorArray containing the selected scores of all steps");
    AddOutput(
        "SentenceIds",
        "(LodTensor)"
        "An LodTensor containing all generated id sequences for all source "
        "sentences");
    AddOutput(
        "SentenceScores",
        "(LodTensor)"
        "An LodTensor containing scores corresponding to Output(SentenceIds)");
    AddAttr<int>("beam_size", "beam size for beam search");
    AddAttr<int>("end_id",
                 "the token id which indicates the end of a sequence");
    AddComment(R"DOC(
Beam Search Decode Operator. This Operator constructs the full hypotheses for
each source sentence by walking back along the LoDTensorArray Input(ids)
whose lods can be used to restore the path in the beam search tree.

The Output(SentenceIds) and Output(SentenceScores) separately contain the
generated id sequences and the corresponding scores. The shapes and lods of the
two LodTensor are same. The lod level is 2 and the two levels separately
indicate how many hypotheses each source sentence has and how many ids each
hypothesis has.
)DOC");
  }
};

// Verifies that all required inputs and outputs are present.
class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* context) const override {
    OP_INOUT_CHECK(context->HasInput("Ids"), "Input", "Ids",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasInput("Scores"), "Input", "Scores",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasOutput("SentenceIds"), "Output", "SentenceIds",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasOutput("SentenceScores"), "Output",
                   "SentenceScores", "BeamSearchDecode");
  }
};

// Declares both outputs to be LoDTensors.
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    ctx->SetOutputType("SentenceIds", framework::proto::VarType::LOD_TENSOR,
                       framework::ALL_ELEMENTS);
    ctx->SetOutputType("SentenceScores", framework::proto::VarType::LOD_TENSOR,
                       framework::ALL_ELEMENTS);
  }
};

}  // namespace operators
}  // namespace paddle

H
hong 已提交
263 264 265 266 267 268 269
REGISTER_OPERATOR(
    beam_search_decode, paddle::operators::BeamSearchDecodeOp,
    paddle::operators::BeamSearchDecodeOpProtoMaker,
    paddle::operators::BeamSearchDecodeInferShape,
    paddle::operators::BeamSearchDecodeInferVarType,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);