// elementwise_op_impl.cu.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"

// only can include the headers in paddle/top/api dirs
#include "paddle/phi/kernels/funcs/broadcast_function.h"

namespace paddle {
namespace operators {

26
template <typename OutT, typename Functor, int NumOuts = 1>
27
void LaunchSameDimsElementwiseCudaKernel(
28
    const KPDevice &ctx,
29 30
    const std::vector<const phi::DenseTensor *> &ins,
    std::vector<phi::DenseTensor *> *outs,
31
    Functor func) {
32 33
  std::vector<const phi::DenseTensor *> pt_inputs;
  std::vector<phi::DenseTensor *> pt_outputs;
34 35
  // TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
  // DenseTensor obj
36
  // generated by MakePhiDenseTensor can be destroyed when exits loop. *_tmp
37 38
  // can be deleted
  // when DenseTensor support copy constructor.
39 40
  std::vector<std::unique_ptr<phi::DenseTensor>> pt_inputs_tmp;
  std::vector<std::unique_ptr<phi::DenseTensor>> pt_outputs_tmp;
41 42
  for (auto in : ins) {
    pt_inputs_tmp.emplace_back(
H
Huang Jiyi 已提交
43
        std::move(std::make_unique<phi::DenseTensor>(*in)));
44 45 46
  }
  for (auto out : *outs) {
    pt_outputs_tmp.emplace_back(
H
Huang Jiyi 已提交
47
        std::move(std::make_unique<phi::DenseTensor>(*out)));
48 49 50 51 52 53
  }
  for (int i = 0; i < pt_inputs_tmp.size(); i++) {
    pt_inputs.push_back(pt_inputs_tmp[i].get());
  }
  for (int i = 0; i < pt_outputs_tmp.size(); i++) {
    pt_outputs.push_back(pt_outputs_tmp[i].get());
54
  }
55 56
  phi::funcs::ElementwiseKernel<OutT, Functor, NumOuts>(
      ctx, pt_inputs, &pt_outputs, func);
57 58 59 60
}

}  // namespace operators
}  // namespace paddle