// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"

// Only headers under the paddle/pten/api dirs may be included here.
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/linalg.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, typename R>
struct P {
  void operator()(T a, R b);
};
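// DotGradFunction computes the gradients of the dot product
// Out = sum(X * Y) along the last dimension:
//   dX = dOut * conj(Y),  dY = dOut * conj(X)
// The conjugation only applies to the complex specialization below; for
// real element types it reduces to dX = dOut * Y and dY = dOut * X.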

template <typename DeviceContext, typename T, typename Enable = void>
struct DotGradFunction {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx);
};
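// Specialization for complex element types: ConjFunctor first writes the
// conjugate of the other operand into the gradient buffer, which is then
// scaled by the (broadcast) dOut.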

template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
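    // CUDA/HIP build: use Eigen device expressions. In the 1-D case dOut
    // holds a single value that is broadcast over the whole vector; in the
    // 2-D (batched) case dOut is broadcast along the last dimension.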
    if (1 == tensor_dout->dims().size()) {
      auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

      if (tensor_dx) {
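        // dX = conj(Y) * dOut: write conj(Y) into dX via ConjFunctor, then
        // scale it by the broadcast dOut.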
        auto y = framework::EigenVector<T>::Flatten(*tensor_y);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dx->numel());

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_y->numel());
        math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
                                     tensor_dx->data<T>());
        for_range(functor);
        auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);

        dx.device(dev) = dx * dout.broadcast(size);
      }

      if (tensor_dy) {
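        // dY = conj(X) * dOut, analogous to the dX branch above.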
        auto x = framework::EigenVector<T>::Flatten(*tensor_x);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dy->numel());

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_x->numel());
        math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
                                     tensor_dy->data<T>());
        for_range(functor);
        auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);

        dy.device(dev) = dy * dout.broadcast(size);
      }
    } else {
      auto dout = framework::EigenMatrix<T>::From(*tensor_dout);

      if (tensor_dx) {
        tensor_dx->mutable_data<T>(ctx.GetPlace());
        auto y = framework::EigenMatrix<T>::From(*tensor_y);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_y->numel());
        math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
                                     tensor_dx->data<T>());
        for_range(functor);
        auto dx = framework::EigenMatrix<T>::From(*tensor_dx);

        dx.device(dev) = dx * dout.broadcast(size);
      }

      if (tensor_dy) {
        tensor_dy->mutable_data<T>(ctx.GetPlace());
        auto x = framework::EigenMatrix<T>::From(*tensor_x);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_x->numel());
        math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
                                     tensor_dy->data<T>());
        for_range(functor);

        auto dy = framework::EigenMatrix<T>::From(*tensor_dy);

        dy.device(dev) = dy * dout.broadcast(size);
      }
    }
#else
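    // CPU fallback: walk the flattened tensors. `step` is the length of the
    // last dimension, so the dOut index `s` advances once per dot product,
    // and the conjugate is formed explicitly as T(real, -imag).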
    const auto* data_dout = tensor_dout->data<T>();

    if (tensor_dx) {
      auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
      const auto* data_y = tensor_y->data<T>();
      const framework::DDim& dim = tensor_x->dims();
      size_t N = static_cast<size_t>(framework::product(dim));

      auto step = dim[dim.size() - 1];

      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
      }
    }

    if (tensor_dy) {
      auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
      const auto* data_x = tensor_x->data<T>();
      const framework::DDim& dim = tensor_y->dims();
      size_t N = static_cast<size_t>(framework::product(dim));

      auto step = dim[dim.size() - 1];

      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
      }
    }
#endif
  }
};
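// Specialization for real element types: no conjugation is needed, so
// dX = dOut * Y and dY = dOut * X.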

template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
    if (1 == tensor_dout->dims().size()) {
      auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

      if (tensor_dx) {
        auto y = framework::EigenVector<T>::Flatten(*tensor_y);
        auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dx->numel());
        dx.device(dev) = y * dout.broadcast(size);
      }

      if (tensor_dy) {
        auto x = framework::EigenVector<T>::Flatten(*tensor_x);
        auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dy->numel());
        dy.device(dev) = x * dout.broadcast(size);
      }
    } else {
      auto dout = framework::EigenMatrix<T>::From(*tensor_dout);

      if (tensor_dx) {
        tensor_dx->mutable_data<T>(ctx.GetPlace());
        auto y = framework::EigenMatrix<T>::From(*tensor_y);
        auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
        dx.device(dev) = y * dout.broadcast(size);
      }

      if (tensor_dy) {
        tensor_dy->mutable_data<T>(ctx.GetPlace());
        auto x = framework::EigenMatrix<T>::From(*tensor_x);
        auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
        dy.device(dev) = x * dout.broadcast(size);
      }
    }
#else
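    // CPU fallback: N is the total element count and B the last-dimension
    // length, so there are N / B independent dot products; each dOut value
    // dz[j] scales one row of the other operand, e.g. for x = [1, 2, 3],
    // y = [4, 5, 6], dout = [d]: dx = [4d, 5d, 6d] and dy = [d, 2d, 3d].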
    auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
               *dz = tensor_dout->data<T>();
    auto&& d = tensor_x->dims();
    auto const N = tensor_x->numel();
    auto const B = d[d.size() - 1];

    if (tensor_dx) {
      auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
      for (auto j = 0; j < N / B; ++j) {
        auto const ss = dz[j];
        for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
      }
    }

    if (tensor_dy) {
      auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
      for (auto j = 0; j < N / B; ++j) {
        auto const ss = dz[j];
        for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
      }
    }
#endif
  }
};

// See Note [ Why still keep the original kernel implementation? ]
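// Forward kernel: wraps the fluid Tensors as pten DenseTensors and
// dispatches to the new pten::Dot kernel.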
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* out = ctx.Output<Tensor>("Out");
    auto& dev_ctx = ctx.device_context<DeviceContext>();
    out->mutable_data<T>(x->place());

    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);

    // call new kernel
    pten::Dot<T>(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get());
  }
};
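// Backward kernel: allocates whichever gradients were requested and
// delegates the actual computation to DotGradFunction.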

template <typename DeviceContext, typename T>
class DotGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* tensor_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* tensor_dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

    if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
    if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());

    DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
                                        tensor_dx, tensor_dy, ctx);
  }
};

}  // namespace operators
}  // namespace paddle