/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <type_traits>

#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/hostdevice.h"

#if defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

#include "paddle/fluid/platform/details/cuda_transform_iterator_cast.h"
#endif

namespace paddle {
namespace platform {
34

35 36 37 38 39 40 41 42 43 44 45 46 47
// Transform applys a unary or a binary functor on each element in a
// range defined by a pair of iterators.
//
// - The specialization for CPU calls std::transform.
// - The specialization for CUDA calls thrust::tranform.
//
// NOTE: We need to define InputIter and OutputIter defined as
//       different types, because the InputIter points op's inputs and
//       OutputIter pints to op's outputs.
//
// NOTE: We don't assume that InputIter to be const InputType* and
//       OutputIter to be OutputType*, because we might use a iterator
//       class, paddle::fluid::operators::RowwiseTRansformIterator.
Q
QI JUN 已提交
48
template <typename DeviceContext>
49
struct Transform {
50
  // The unary version.
51
  template <typename InputIter, typename OutputIter, typename UnaryOperation>
52 53 54 55 56
  void operator()(const DeviceContext& context,
                  InputIter first,
                  InputIter last,
                  OutputIter result,
                  UnaryOperation op);
57

58
  // The binary version.
59 60 61
  template <typename InputIter1,
            typename InputIter2,
            typename OutputIter,
62
            typename BinaryOperation>
63 64 65 66 67
  void operator()(const DeviceContext& context,
                  InputIter1 first1,
                  InputIter1 last1,
                  InputIter2 first2,
                  OutputIter result,
68 69 70
                  BinaryOperation op);
};

71
// NOTE: After the phi kernel is migrated, it needs to be deleted.
72

W
Wilber 已提交
73
template <>
74
struct Transform<phi::CPUContext> {
W
Wilber 已提交
75
  template <typename InputIter, typename OutputIter, typename UnaryOperation>
76 77 78 79 80
  void operator()(const phi::CPUContext& context,
                  InputIter first,
                  InputIter last,
                  OutputIter result,
                  UnaryOperation op) {
W
Wilber 已提交
81 82 83
    std::transform(first, last, result, op);
  }

84 85 86
  template <typename InputIter1,
            typename InputIter2,
            typename OutputIter,
W
Wilber 已提交
87
            typename BinaryOperation>
88 89 90 91 92
  void operator()(const phi::CPUContext& context,
                  InputIter1 first1,
                  InputIter1 last1,
                  InputIter2 first2,
                  OutputIter result,
W
Wilber 已提交
93 94 95 96 97
                  BinaryOperation op) {
    std::transform(first1, last1, first2, result, op);
  }
};

98
#if defined(__NVCC__) || defined(__HIPCC__)
99 100

template <>
101
struct Transform<phi::GPUContext> {
102
  template <typename InputIter, typename OutputIter, typename UnaryOperation>
103 104 105 106 107
  void operator()(const phi::GPUContext& context,
                  InputIter first,
                  InputIter last,
                  OutputIter result,
                  UnaryOperation op) {
108
    auto place = context.GetPlace();
109 110
    PADDLE_ENFORCE_EQ(is_gpu_place(place),
                      true,
111 112 113 114 115 116
                      platform::errors::PreconditionNotMet(
                          "The CUDA Transform must be used in GPU place."));
#ifdef __HIPCC__
    thrust::transform(thrust::hip::par.on(context.stream()),
                      details::CastToCUDATransformIterator(first),
                      details::CastToCUDATransformIterator(last),
117 118
                      details::CastToCUDATransformIterator(result),
                      op);
119 120 121 122
#else
    thrust::transform(thrust::cuda::par.on(context.stream()),
                      details::CastToCUDATransformIterator(first),
                      details::CastToCUDATransformIterator(last),
123 124
                      details::CastToCUDATransformIterator(result),
                      op);
125 126 127
#endif
  }

128 129 130
  template <typename InputIter1,
            typename InputIter2,
            typename OutputIter,
131
            typename BinaryOperation>
132 133 134 135 136
  void operator()(const phi::GPUContext& context,
                  InputIter1 first1,
                  InputIter1 last1,
                  InputIter2 first2,
                  OutputIter result,
137
                  BinaryOperation op) {
138
    auto place = context.GetPlace();
139 140
    PADDLE_ENFORCE_EQ(is_gpu_place(place),
                      true,
G
GaoWei8 已提交
141 142
                      platform::errors::PreconditionNotMet(
                          "The CUDA Transform must be used in GPU place."));
143 144 145 146 147
#ifdef __HIPCC__
    thrust::transform(thrust::hip::par.on(context.stream()),
                      details::CastToCUDATransformIterator(first1),
                      details::CastToCUDATransformIterator(last1),
                      details::CastToCUDATransformIterator(first2),
148 149
                      details::CastToCUDATransformIterator(result),
                      op);
150
#else
Q
QI JUN 已提交
151
    thrust::transform(thrust::cuda::par.on(context.stream()),
152 153 154
                      details::CastToCUDATransformIterator(first1),
                      details::CastToCUDATransformIterator(last1),
                      details::CastToCUDATransformIterator(first2),
155 156
                      details::CastToCUDATransformIterator(result),
                      op);
157
#endif
Y
Yu Yang 已提交
158 159
  }
};
160
#endif

}  // namespace platform
}  // namespace paddle