sequence_arithmetic_compute.h 3.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <cstring>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

/// x86 kernel for the `sequence_arithmetic` op.
///
/// Combines LoD tensors X and Y sequence-by-sequence (using their level-0 LoD
/// offsets) with the element-wise operation selected by `param.op_type`:
///   1 -> sum (X + Y), 2 -> sub (X - Y), 3 -> mul (X * Y).
/// For sequence i, the first min(len_x, len_y) output elements are
/// `op(x, y)`; any remaining elements of X's sequence are copied through
/// unchanged, and any extra elements of Y's sequence are ignored.
/// Out takes X's dims and LoD. For any other `op_type`, Out is resized and
/// allocated but its data is not written.
template <typename T>
class SequenceArithmeticCompute
    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::SequenceArithmeticParam;

  void Run() override {
    auto& param = *param_.get_mutable<param_t>();
    auto x = param.X;
    auto y = param.Y;
    auto out = param.Out;
    const int op_type = param.op_type;

    out->Resize(x->dims());
    out->set_lod(x->lod());

    auto x_data = x->data<T>();
    auto y_data = y->data<T>();
    auto out_data = out->mutable_data<T>();
    // NOTE(review): assumes both X and Y carry at least one LoD level and
    // that Y's level 0 describes at least as many sequences as X's;
    // otherwise the offset lookups below read out of bounds. Confirm that
    // the op's shape checks enforce this.
    auto x_seq_offset = x->lod()[0];
    auto y_seq_offset = y->lod()[0];

    // Number of scalars per row. Guard against an empty batch (zero rows),
    // which previously caused an integer division by zero.
    const auto rows = x->dims()[0];
    const std::size_t inner_size =
        rows > 0 ? static_cast<std::size_t>(x->numel() / rows) : 0;

    switch (op_type) {
      case 1:  // sum
        ApplyPerSequence(x_data,
                         y_data,
                         out_data,
                         x_seq_offset,
                         y_seq_offset,
                         inner_size,
                         [](const T& a, const T& b) { return a + b; });
        break;
      case 2:  // sub
        ApplyPerSequence(x_data,
                         y_data,
                         out_data,
                         x_seq_offset,
                         y_seq_offset,
                         inner_size,
                         [](const T& a, const T& b) { return a - b; });
        break;
      case 3:  // mul
        ApplyPerSequence(x_data,
                         y_data,
                         out_data,
                         x_seq_offset,
                         y_seq_offset,
                         inner_size,
                         [](const T& a, const T& b) { return a * b; });
        break;
      default:
        // Unknown op_type: nothing is written (same as before).
        break;
    }
  }

  virtual ~SequenceArithmeticCompute() = default;

 private:
  /// Applies `op` element-wise to each pair of corresponding sequences.
  ///
  /// @param x_data       X's data buffer.
  /// @param y_data       Y's data buffer.
  /// @param out_data     Output buffer, laid out like X.
  /// @param x_seq_offset X's level-0 LoD offsets (row indices).
  /// @param y_seq_offset Y's level-0 LoD offsets (row indices).
  /// @param inner_size   Scalars per row; offsets are scaled by it.
  /// @param op           Binary functor `op(x_elem, y_elem)`.
  ///
  /// Lengths and offsets are computed in std::size_t so that tensors with
  /// more than INT_MAX elements do not overflow.
  template <typename Offsets, typename BinaryOp>
  static void ApplyPerSequence(const T* x_data,
                               const T* y_data,
                               T* out_data,
                               const Offsets& x_seq_offset,
                               const Offsets& y_seq_offset,
                               std::size_t inner_size,
                               BinaryOp op) {
    // An empty offset vector describes zero sequences.
    if (x_seq_offset.empty()) {
      return;
    }
    const std::size_t seq_num = x_seq_offset.size() - 1;
    for (std::size_t i = 0; i < seq_num; ++i) {
      const std::size_t x_begin =
          static_cast<std::size_t>(x_seq_offset[i]) * inner_size;
      const std::size_t y_begin =
          static_cast<std::size_t>(y_seq_offset[i]) * inner_size;
      const std::size_t len_x =
          static_cast<std::size_t>(x_seq_offset[i + 1] - x_seq_offset[i]) *
          inner_size;
      const std::size_t len_y =
          static_cast<std::size_t>(y_seq_offset[i + 1] - y_seq_offset[i]) *
          inner_size;

      const T* input_x = x_data + x_begin;
      const T* input_y = y_data + y_begin;
      T* t_out = out_data + x_begin;

      const std::size_t len = std::min(len_x, len_y);
      for (std::size_t j = 0; j < len; ++j) {
        t_out[j] = op(input_x[j], input_y[j]);
      }
      // When X's sequence is longer than Y's, pass the tail of X through.
      if (len_x > len) {
        std::memcpy(t_out + len, input_x + len, sizeof(T) * (len_x - len));
      }
    }
  }
};

}  // namespace x86
}  // namespace kernels
}  // namespace lite
}  // namespace paddle