// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using complex64 = platform::complex64;
using complex128 = platform::complex128;

template <typename T, typename R>
struct P {
  void operator()(T a, R b);
};

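// Computes the gradient of the dot product with respect to X and Y.
// The third template parameter selects the complex or real specialization below.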
template <typename DeviceContext, typename T, typename Enable = void>
struct DotGradFunction {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx);
};

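// Complex specialization: dX = conj(Y) * dOut and dY = conj(X) * dOut,
// with dOut broadcast along the last axis.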
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
    if (1 == tensor_dout->dims().size()) {
      auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

      if (tensor_dx) {
        auto y = framework::EigenVector<T>::Flatten(*tensor_y);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dx->numel());

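        // Write conj(y) into dx, then scale each element by the broadcast dout.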
        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_y->numel());
        math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
                                     tensor_dx->data<T>());
        for_range(functor);
        auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);

        dx.device(dev) = dx * dout.broadcast(size);
      }

      if (tensor_dy) {
        auto x = framework::EigenVector<T>::Flatten(*tensor_x);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dy->numel());

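        // Write conj(x) into dy, then scale each element by the broadcast dout.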
        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_y->numel());
        math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
                                     tensor_dy->data<T>());
        for_range(functor);
        auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);

        dy.device(dev) = dy * dout.broadcast(size);
      }
    } else {
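      // Batched (2-D) case: dout holds one value per row, broadcast across
      // the last axis of each operand.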
      auto dout = framework::EigenMatrix<T>::From(*tensor_dout);

      if (tensor_dx) {
        tensor_dx->mutable_data<T>(ctx.GetPlace());
        auto y = framework::EigenMatrix<T>::From(*tensor_y);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_y->numel());
        math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
                                     tensor_dx->data<T>());
        for_range(functor);
        auto dx = framework::EigenMatrix<T>::From(*tensor_dx);

        dx.device(dev) = dx * dout.broadcast(size);
      }

      if (tensor_dy) {
        tensor_dy->mutable_data<T>(ctx.GetPlace());
        auto x = framework::EigenMatrix<T>::From(*tensor_x);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_x->numel());
        math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
                                     tensor_dy->data<T>());
        for_range(functor);

        auto dy = framework::EigenMatrix<T>::From(*tensor_dy);

        dy.device(dev) = dy * dout.broadcast(size);
      }
    }
#else
    const auto* data_dout = tensor_dout->data<T>();

    if (tensor_dx) {
      auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
      const auto* data_y = tensor_y->data<T>();
      const framework::DDim& dim = tensor_x->dims();
      size_t N = static_cast<size_t>(framework::product(dim));

      auto step = dim[dim.size() - 1];

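      // s indexes the dout value of the row (of `step` elements) that i belongs to.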
      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
      }
    }

    if (tensor_dy) {
      auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
      const auto* data_x = tensor_x->data<T>();
      const framework::DDim& dim = tensor_y->dims();
      size_t N = static_cast<size_t>(framework::product(dim));

      auto step = dim[dim.size() - 1];

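      // Same row-wise indexing as above, this time with conj(x).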
      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
      }
    }
#endif
  }
};

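// Real-valued specialization: dX = Y * dOut and dY = X * dOut,
// with dOut broadcast along the last axis.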
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
    if (1 == tensor_dout->dims().size()) {
      auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

      if (tensor_dx) {
        auto y = framework::EigenVector<T>::Flatten(*tensor_y);
        auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dx->numel());
        dx.device(dev) = y * dout.broadcast(size);
      }

      if (tensor_dy) {
        auto x = framework::EigenVector<T>::Flatten(*tensor_x);
        auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dy->numel());
        dy.device(dev) = x * dout.broadcast(size);
      }
    } else {
      auto dout = framework::EigenMatrix<T>::From(*tensor_dout);

      if (tensor_dx) {
        tensor_dx->mutable_data<T>(ctx.GetPlace());
        auto y = framework::EigenMatrix<T>::From(*tensor_y);
        auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
        dx.device(dev) = y * dout.broadcast(size);
      }

      if (tensor_dy) {
        tensor_dy->mutable_data<T>(ctx.GetPlace());
        auto x = framework::EigenMatrix<T>::From(*tensor_x);
        auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
        dy.device(dev) = x * dout.broadcast(size);
      }
    }
#else
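    // CPU path: the N elements form rows of B (the last axis); each row of the
    // other operand is scaled by the dout value of that row.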
    auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
               *dz = tensor_dout->data<T>();
    auto&& d = tensor_x->dims();
    auto const N = tensor_x->numel();
    auto const B = d[d.size() - 1];

    if (tensor_dx) {
      auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
      for (auto j = 0; j < N / B; ++j) {
        auto const ss = dz[j];
        for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
      }
    }

    if (tensor_dy) {
      auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
      for (auto j = 0; j < N / B; ++j) {
        auto const ss = dz[j];
        for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
      }
    }
#endif
  }
};

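// Forward kernel: Out = sum(X * Y) over the last axis (a scalar for 1-D inputs,
// one value per row for 2-D inputs).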
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_out = ctx.Output<Tensor>("Out");
    tensor_out->mutable_data<T>(ctx.GetPlace());

#if defined(__NVCC__) || defined(__HIPCC__)
    if (1 == tensor_out->dims().size()) {
      auto out = framework::EigenScalar<T>::From(*tensor_out);
      auto x = framework::EigenVector<T>::Flatten(*tensor_x);
      auto y = framework::EigenVector<T>::Flatten(*tensor_y);

      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      out.device(dev) = (x * y).sum();
    } else {
      auto out = framework::EigenMatrix<T>::From(*tensor_out);
      auto x = framework::EigenMatrix<T>::From(*tensor_x);
      auto y = framework::EigenMatrix<T>::From(*tensor_y);

      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      out.device(dev) = (x * y).sum(Eigen::DSizes<int, 1>(1));
    }
#else
    auto const *x = tensor_x->data<T>(), *x_ = &x[0];
    auto const *y = tensor_y->data<T>(), *y_ = &y[0];
    auto* z = tensor_out->data<T>();

    // Loop over the total N elements of both operands, sum-reducing every
    // B consecutive pairs, where B is the size of the innermost (last) axis.
    auto&& d = tensor_x->dims();
    auto const N = tensor_x->numel();
    auto const B = d[d.size() - 1];

    for (int j = 0; j < N / B; j++) {
      T ss = 0;
      for (int i = 0; i < B; i++) ss += (*x_++) * (*y_++);
      z[j] = ss;
    }
#endif
  }
};

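// Backward kernel: allocates the requested gradients and delegates the
// computation to DotGradFunction.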
template <typename DeviceContext, typename T>
class DotGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* tensor_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* tensor_dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

    if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
    if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());

    DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
                                        tensor_dx, tensor_dy, ctx);
  }
};

}  // namespace operators
}  // namespace paddle