/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/operators/beam_search_decode_op.h"
#include "paddle/fluid/platform/device_context.h"

// Forward declarations of framework / imperative types that are referenced
// by name later in this file (the InferShape signature, RunImpl's Scope
// parameter, and the REGISTER_OPERATOR grad-maker arguments).
namespace paddle {
namespace imperative {
class OpBase;
}  // namespace imperative

namespace framework {
class OpDesc;
class Scope;
class InferShapeContext;
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace operators {

struct BeamSearchDecodeFunctor {
  BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
                          const LoDTensorArray& step_scores,
39 40 41 42 43
                          LoDTensor* id_tensor, LoDTensor* score_tensor,
                          size_t beam_size, int end_id)
      : beam_size_(beam_size),
        end_id_(end_id),
        step_ids_origin_(step_ids),
44
        step_scores_origin_(step_scores),
45
        id_tensor_(id_tensor),
46 47 48 49 50 51 52 53 54 55 56
        score_tensor_(score_tensor) {
    tensor_on_gpu_ = false;
    // First make a copy of GPU data on CPU
    if (platform::is_gpu_place(step_ids_origin_[0].place())) {
      tensor_on_gpu_ = true;
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(step_ids_origin_[0].place());
      // Copy all tensors in the input tensor array
      for (auto& step_id : step_ids_origin_) {
        framework::LoDTensor out;
57 58 59 60 61
        if (step_id.numel() > 0) {
          dev_ctx->Wait();
          framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
          dev_ctx->Wait();
        }
62 63 64 65 66 67 68 69 70 71 72 73 74

        out.set_lod(step_id.lod());
        step_ids_.push_back(out);
      }
    }
    if (platform::is_gpu_place(step_scores_origin_[0].place())) {
      tensor_on_gpu_ = true;
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(step_scores_origin_[0].place());
      // Copy all tensors in the input tensor array
      for (auto& step_score : step_scores_origin_) {
        framework::LoDTensor out;
75 76 77 78 79 80
        if (step_score.numel() > 0) {
          dev_ctx->Wait();
          framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx,
                                &out);
          dev_ctx->Wait();
        }
81 82 83 84 85 86

        out.set_lod(step_score.lod());
        step_scores_.push_back(out);
      }
    }
  }
87 88

  template <typename T>
D
dzhwinter 已提交
89
  void apply() const;
90

91
  bool tensor_on_gpu_;
92 93
  size_t beam_size_;
  int end_id_;
Y
Yan Chunwei 已提交
94 95 96
  // TODO(Superjomn) Here might result serious performance issue in the
  // concurrency
  // scenarios.
97 98 99 100
  const LoDTensorArray& step_ids_origin_;
  const LoDTensorArray& step_scores_origin_;
  LoDTensorArray step_ids_ = LoDTensorArray();
  LoDTensorArray step_scores_ = LoDTensorArray();
101 102 103 104 105
  LoDTensor* id_tensor_;
  LoDTensor* score_tensor_;
};

template <typename T>
D
dzhwinter 已提交
106
void BeamSearchDecodeFunctor::apply() const {
107
  BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
108 109
  // Check if the tensor is on GPU. If so, use the CPU copy instead
  if (tensor_on_gpu_) {
110 111
    beam_search_decoder.Backtrace(step_ids_, step_scores_, id_tensor_,
                                  score_tensor_);
112
  } else {
113 114
    beam_search_decoder.Backtrace(step_ids_origin_, step_scores_origin_,
                                  id_tensor_, score_tensor_);
115
  }
116 117 118
}

template <>
D
dzhwinter 已提交
119
void BeamSearchDecodeFunctor::apply<bool>() const {
120 121
  PADDLE_THROW(platform::errors::InvalidArgument(
      "beam search decode op does not support bool!"));
122 123
}

class BeamSearchDecodeOp : public framework::OperatorBase {
 public:
  BeamSearchDecodeOp(const std::string& type,
                     const framework::VariableNameMap& inputs,
                     const framework::VariableNameMap& outputs,
                     const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
131 132 133 134

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
Y
Yu Yang 已提交
135 136
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto& dev_ctx = *pool.Get(dev_place);
D
dzhwinter 已提交
137

X
Xin Pan 已提交
138
    framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
139
    framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
140

Q
Qiao Longfei 已提交
141 142 143
    const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
    const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
    const size_t step_num = ids->size();
144 145 146 147 148 149 150
    PADDLE_ENFORCE_GT(
        step_num, 0UL,
        platform::errors::InvalidArgument(
            "beam search steps, which is the"
            "size of Input(Ids) LoDTensorArray. beam search steps should "
            "be larger than 0, but received %d. ",
            step_num));
Q
Qiao Longfei 已提交
151
    const size_t source_num = ids->at(0).lod().at(0).size() - 1;
152 153 154 155 156 157 158 159
    PADDLE_ENFORCE_GT(
        source_num, 0UL,
        platform::errors::InvalidArgument(
            "source_num is the sequence number of the"
            "first decoding step, indicating by Input(Ids)[0].lod[0].size. "
            "The number of source_num should be larger than"
            "0, but received %d. ",
            source_num));
Q
Qiao Longfei 已提交
160 161

    for (size_t i = 0; i < step_num; ++i) {
162 163 164 165 166 167 168
      PADDLE_ENFORCE_EQ(
          ids->at(i).lod().size(), 2UL,
          platform::errors::InvalidArgument(
              "For the i step in beam search steps,"
              "the size of Input(Ids)[i].lod() should larger than 2,"
              "but received %d. ",
              ids->at(i).lod().size()));
Q
Qiao Longfei 已提交
169 170
    }

171 172 173
    size_t beam_size = ctx.Attr<int>("beam_size");
    int end_id = ctx.Attr<int>("end_id");

Q
Qiao Longfei 已提交
174 175 176 177
    // prepare output
    LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
    LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");

178
    framework::VisitDataType(
Y
Yu Yang 已提交
179
        scores->at(0).type(),
180 181
        BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
                                beam_size, end_id));
Q
Qiao Longfei 已提交
182 183 184 185 186
  }
};

// Declares the inputs, outputs, attributes and documentation of the
// beam_search_decode operator.
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Ids",
             "(LodTensorArray) "
             "The LodTensorArray containing the selected ids of all steps");
    AddInput("Scores",
             "(LodTensorArray) "
             "The LodTensorArray containing the selected scores of all steps");
    AddOutput(
        "SentenceIds",
        "(LodTensor) "
        "An LodTensor containing all generated id sequences for all source "
        "sentences");
    AddOutput(
        "SentenceScores",
        "(LodTensor) "
        "An LodTensor containing scores corresponding to Output(SentenceIds)");
    AddAttr<int>("beam_size", "beam size for beam search");
    AddAttr<int>("end_id",
                 "the token id which indicates the end of a sequence");
    AddComment(R"DOC(
Beam Search Decode Operator. This Operator constructs the full hypotheses for
each source sentence by walking back along the LoDTensorArray Input(ids)
whose lods can be used to restore the path in the beam search tree.

The Output(SentenceIds) and Output(SentenceScores) separately contain the
generated id sequences and the corresponding scores. The shapes and lods of the
two LodTensor are same. The lod level is 2 and the two levels separately
indicate how many hypotheses each source sentence has and how many ids each
hypothesis has.
)DOC");
  }
};

// Compile-time shape inference: only verifies that all inputs and outputs
// are wired up. No output dims are set here; presumably the output shapes
// are produced at run time by BeamSearchDecoder::Backtrace (via RunImpl).
class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* context) const override {
    OP_INOUT_CHECK(context->HasInput("Ids"), "Input", "Ids",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasInput("Scores"), "Input", "Scores",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasOutput("SentenceIds"), "Output", "SentenceIds",
                   "BeamSearchDecode");
    OP_INOUT_CHECK(context->HasOutput("SentenceScores"), "Output",
                   "SentenceScores", "BeamSearchDecode");
  }
};

// Marks both outputs as LOD_TENSOR variables (applied to all elements),
// since the inputs are LoDTensorArrays but the decoded results are single
// LoDTensors.
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    ctx->SetOutputType("SentenceIds", framework::proto::VarType::LOD_TENSOR,
                       framework::ALL_ELEMENTS);
    ctx->SetOutputType("SentenceScores", framework::proto::VarType::LOD_TENSOR,
                       framework::ALL_ELEMENTS);
  }
};

}  // namespace operators
}  // namespace paddle

// Registers beam_search_decode with its op class, proto maker, shape and
// var-type inference. Both grad makers are EmptyGradOpMaker (static-graph
// OpDesc and imperative OpBase variants), i.e. no gradient op is generated.
REGISTER_OPERATOR(
    beam_search_decode, paddle::operators::BeamSearchDecodeOp,
    paddle::operators::BeamSearchDecodeOpProtoMaker,
    paddle::operators::BeamSearchDecodeInferShape,
    paddle::operators::BeamSearchDecodeInferVarType,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);