// sequence_concat_compute.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cstdint>
#include <cstring>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

// Computes the level-0 LoD of the sequence-wise concatenation of `xs` and
// gathers the per-sequence slices of the inputs in the order they must appear
// in the output.
//
// For every sequence index i (outer loop) and every input j (inner loop), the
// rows [lod_j[i-1], lod_j[i]) of input j are sliced and appended to
// `xs_in_order`; empty ranges are skipped. The i-th output offset is the sum
// of the i-th offsets of all inputs.
//
// xs:          input tensors. Must be non-empty, and every tensor must carry
//              a level-0 LoD with the same number of offsets as xs[0];
//              violations abort through CHECK_EQ instead of reading out of
//              bounds.
// xs_in_order: output; slices are appended, existing elements are kept.
// Returns a LoD holding exactly one level.
template <typename T>
inline LoD ConcatLoD(const std::vector<lite::Tensor*>& xs,
                     std::vector<lite::Tensor>* xs_in_order) {
  CHECK_EQ(xs.empty(), false)
      << "ConcatLoD requires at least one input tensor";
  CHECK_EQ(xs[0]->lod().empty(), false)
      << "Inputs of sequence concat must have a level-0 LoD";
  const size_t offset_count = xs[0]->lod()[0].size();

  // Every input is indexed with [0, offset_count) below, so all level-0 LoDs
  // have to be present and equally long.
  for (size_t j = 1; j < xs.size(); ++j) {
    CHECK_EQ(xs[j]->lod().empty(), false)
        << "Inputs of sequence concat must have a level-0 LoD";
    CHECK_EQ(xs[j]->lod()[0].size(), offset_count)
        << "Inputs of sequence concat must have the same number of sequences";
  }

  // Value-initialized: result[0] stays 0 because the loop starts at 1.
  std::vector<size_t> result(offset_count);

  for (size_t i = 1; i < offset_count; ++i) {
    size_t sum = 0;
    for (size_t j = 0; j < xs.size(); ++j) {
      const auto& x_lod = xs[j]->lod()[0];
      // Only non-empty sequences contribute a slice to the output.
      if (x_lod[i - 1] < x_lod[i]) {
        xs_in_order->emplace_back(xs[j]->Slice<T>(x_lod[i - 1], x_lod[i]));
      }
      sum += x_lod[i];
    }
    result[i] = sum;
  }

  LoD lod;
  lod.emplace_back(result);
  return lod;
}

// x86 / float kernel for the sequence_concat operator.
//
// Run() sizes the output so that its first dimension is the sum of the first
// dimensions of all inputs (the remaining dimensions are taken from the first
// input), sets the output LoD from ConcatLoD, and copies the per-sequence
// slices returned by ConcatLoD back to back into the output buffer.
template <typename T>
class SequenceConcatCompute
    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::SequenceConcatParam;

  void Run() override {
    auto& param = *param_.get_mutable<param_t>();
    // Without inputs there is no shape to derive the output from.
    CHECK_EQ(param.X.empty(), false)
        << "Sequence concat requires at least one input";

    int64_t batch_size = 0;
    int64_t feature_size = 0;
    std::vector<int64_t> out_dims;
    for (const auto& tensor : param.X) {
      const std::vector<int64_t> shape = tensor->dims().Vectorize();
      CHECK_EQ(shape.empty(), false)
          << "Inputs of sequence concat must have at least one dimension";

      // Elements per row = product of all dimensions but the first. Computed
      // directly (instead of production() / dims[0]) so an input with zero
      // rows does not divide by zero.
      int64_t row_size = 1;
      for (size_t d = 1; d < shape.size(); ++d) {
        row_size *= shape[d];
      }

      if (out_dims.empty()) {
        // First input: it defines the output shape and the reference row
        // size every other input is compared against.
        out_dims = shape;
        feature_size = row_size;
      } else {
        CHECK_EQ(feature_size, row_size)
            << "Inputs of sequence concat must have same feature size";
      }
      batch_size += shape[0];
    }
    if (batch_size < 0) {
      batch_size = -1;  // Normalize batch size for compile time.
    }
    out_dims[0] = batch_size;
    param.Out->Resize(out_dims);

    T* dout = param.Out->mutable_data<T>();

    std::vector<lite::Tensor> x_in_order;
    param.Out->set_lod(ConcatLoD<T>(param.X, &x_in_order));

    // Number of elements the resized output holds; used to bound the copies.
    const int64_t out_numel = batch_size * feature_size;

    // 64-bit running offset: element counts are int64_t and must not be
    // truncated to int.
    int64_t offset = 0;
    for (auto& slice : x_in_order) {
      const int64_t len = static_cast<int64_t>(slice.numel());
      if (len == 0) {
        continue;  // Nothing to copy for this slice.
      }
      // Refuse to write past the end of the output buffer.
      CHECK_EQ((offset + len <= out_numel), true)
          << "Sequence concat slices exceed the output size";
      std::memcpy(
          dout + offset, slice.data<T>(), sizeof(T) * static_cast<size_t>(len));
      offset += len;
    }
  }

  virtual ~SequenceConcatCompute() = default;
};

}  // namespace x86
}  // namespace kernels
}  // namespace lite
}  // namespace paddle