elementwise_op_impl.cu.h 2.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

15 16
#pragma once

17
#include "paddle/fluid/framework/phi_utils.h"
18 19
#include "paddle/fluid/framework/tensor.h"

20
// only can include the headers in paddle/top/api dirs
21
#include "paddle/phi/api/lib/utils/tensor_utils.h"
22
#include "paddle/phi/kernels/funcs/broadcast_function.h"
23

24 25 26
namespace paddle {
namespace operators {

27
using ElementwiseType = phi::ElementwiseType;
28

29
template <typename OutT, typename Functor, int NumOuts = 1>
30
void LaunchSameDimsElementwiseCudaKernel(
31 32 33 34
    const KPDevice &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs,
    Functor func) {
35 36
  std::vector<const phi::DenseTensor *> pt_inputs;
  std::vector<phi::DenseTensor *> pt_outputs;
37 38
  // TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
  // DenseTensor obj
39
  // generated by MakePhiDenseTensor can be destroyed when exits loop. *_tmp
40 41
  // can be deleted
  // when DenseTensor support copy constructor.
42 43
  std::vector<std::unique_ptr<phi::DenseTensor>> pt_inputs_tmp;
  std::vector<std::unique_ptr<phi::DenseTensor>> pt_outputs_tmp;
44 45
  for (auto in : ins) {
    pt_inputs_tmp.emplace_back(
46
        std::move(paddle::experimental::MakePhiDenseTensor(*in)));
47 48 49
  }
  for (auto out : *outs) {
    pt_outputs_tmp.emplace_back(
50
        std::move(paddle::experimental::MakePhiDenseTensor(*out)));
51 52 53 54 55 56
  }
  for (int i = 0; i < pt_inputs_tmp.size(); i++) {
    pt_inputs.push_back(pt_inputs_tmp[i].get());
  }
  for (int i = 0; i < pt_outputs_tmp.size(); i++) {
    pt_outputs.push_back(pt_outputs_tmp[i].get());
57
  }
58 59
  phi::funcs::ElementwiseKernel<OutT, Functor, NumOuts>(
      ctx, pt_inputs, &pt_outputs, func);
60 61 62 63
}

}  // namespace operators
}  // namespace paddle