// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, typename R>
struct P {
  void operator()(T a, R b);
};

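// Computes the gradients of the dot op. The primary template is only
// declared here; it is specialized below for complex and non-complex
// element types (selected via math::EnableComplex / math::DisableComplex).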
template <typename DeviceContext, typename T, typename Enable = void>
struct DotGradFunction {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx);
};

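// Specialization for complex element types: the gradients are taken against
// the conjugated inputs, i.e. dx = conj(y) * dout and dy = conj(x) * dout.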
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
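    // GPU path: conjugate the other operand into the gradient buffer with
    // ConjFunctor, then scale it by the broadcast dout via Eigen.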
    if (1 == tensor_dout->dims().size()) {
      auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

      if (tensor_dx) {
        auto y = framework::EigenVector<T>::Flatten(*tensor_y);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dx->numel());

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_y->numel());
        math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
                                     tensor_dx->data<T>());
        for_range(functor);
        auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);

        dx.device(dev) = dx * dout.broadcast(size);
      }

      if (tensor_dy) {
        auto x = framework::EigenVector<T>::Flatten(*tensor_x);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dy->numel());

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_x->numel());
        math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
                                     tensor_dy->data<T>());
        for_range(functor);
        auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);

        dy.device(dev) = dy * dout.broadcast(size);
      }
    } else {
      auto dout = framework::EigenMatrix<T>::From(*tensor_dout);

      if (tensor_dx) {
        tensor_dx->mutable_data<T>(ctx.GetPlace());
        auto y = framework::EigenMatrix<T>::From(*tensor_y);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_y->numel());
        math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
                                     tensor_dx->data<T>());
        for_range(functor);
        auto dx = framework::EigenMatrix<T>::From(*tensor_dx);

        dx.device(dev) = dx * dout.broadcast(size);
      }

      if (tensor_dy) {
        tensor_dy->mutable_data<T>(ctx.GetPlace());
        auto x = framework::EigenMatrix<T>::From(*tensor_x);
        auto& dev_raw = ctx.template device_context<DeviceContext>();
        auto& dev = *dev_raw.eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);

        paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
                                                            tensor_x->numel());
        math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
                                     tensor_dy->data<T>());
        for_range(functor);

        auto dy = framework::EigenMatrix<T>::From(*tensor_dy);

        dy.device(dev) = dy * dout.broadcast(size);
      }
    }
#else
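    // CPU path: walk the flattened inputs and multiply each conjugated
    // element by the corresponding dout value (one dout entry per row).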
    const auto* data_dout = tensor_dout->data<T>();

    if (tensor_dx) {
      auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
      const auto* data_y = tensor_y->data<T>();
      const framework::DDim& dim = tensor_x->dims();
      size_t N = static_cast<size_t>(framework::product(dim));

      auto step = dim[dim.size() - 1];

      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
      }
    }

    if (tensor_dy) {
      auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
      const auto* data_x = tensor_x->data<T>();
      const framework::DDim& dim = tensor_y->dims();
      size_t N = static_cast<size_t>(framework::product(dim));

      auto step = dim[dim.size() - 1];

      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
      }
    }
#endif
  }
};

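// Specialization for real element types: dx = y * dout and dy = x * dout,
// with dout broadcast along the reduced (last) axis.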
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
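    // GPU path: evaluate the gradients directly as Eigen broadcast
    // expressions.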
    if (1 == tensor_dout->dims().size()) {
      auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);

      if (tensor_dx) {
        auto y = framework::EigenVector<T>::Flatten(*tensor_y);
        auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dx->numel());
        dx.device(dev) = y * dout.broadcast(size);
      }

      if (tensor_dy) {
        auto x = framework::EigenVector<T>::Flatten(*tensor_x);
        auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 1> size(tensor_dy->numel());
        dy.device(dev) = x * dout.broadcast(size);
      }
    } else {
      auto dout = framework::EigenMatrix<T>::From(*tensor_dout);

      if (tensor_dx) {
        tensor_dx->mutable_data<T>(ctx.GetPlace());
        auto y = framework::EigenMatrix<T>::From(*tensor_y);
        auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
        dx.device(dev) = y * dout.broadcast(size);
      }

      if (tensor_dy) {
        tensor_dy->mutable_data<T>(ctx.GetPlace());
        auto x = framework::EigenMatrix<T>::From(*tensor_x);
        auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
        auto& dev =
            *ctx.template device_context<DeviceContext>().eigen_device();
        Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
        dy.device(dev) = x * dout.broadcast(size);
      }
    }
#else
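    // CPU path: for every row j, scale the B trailing-axis elements of the
    // other operand by dout[j].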
    auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
               *dz = tensor_dout->data<T>();
    auto&& d = tensor_x->dims();
    auto const N = tensor_x->numel();
    auto const B = d[d.size() - 1];

    if (tensor_dx) {
      auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
      for (auto j = 0; j < N / B; ++j) {
        auto const ss = dz[j];
        for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
      }
    }

    if (tensor_dy) {
      auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
      for (auto j = 0; j < N / B; ++j) {
        auto const ss = dz[j];
        for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
      }
    }
#endif
  }
};

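// Forward kernel of the dot op: Out = sum(X * Y) over the last axis
// (a scalar for 1-D inputs, one value per row for 2-D inputs).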
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_out = ctx.Output<Tensor>("Out");
    tensor_out->mutable_data<T>(ctx.GetPlace());

#if defined(__NVCC__) || defined(__HIPCC__)
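    // GPU path: 1-D inputs reduce to a scalar via Eigen; 2-D inputs reduce
    // along the last axis.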
    if (1 == tensor_out->dims().size()) {
      auto out = framework::EigenScalar<T>::From(*tensor_out);
      auto x = framework::EigenVector<T>::Flatten(*tensor_x);
      auto y = framework::EigenVector<T>::Flatten(*tensor_y);

      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      out.device(dev) = (x * y).sum();
    } else {
      auto out = framework::EigenMatrix<T>::From(*tensor_out);
      auto x = framework::EigenMatrix<T>::From(*tensor_x);
      auto y = framework::EigenMatrix<T>::From(*tensor_y);

      auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
      out.device(dev) = (x * y).sum(Eigen::DSizes<int, 1>(1));
    }
#else
    auto const *x = tensor_x->data<T>(), *x_ = &x[0];
    auto const *y = tensor_y->data<T>(), *y_ = &y[0];
    auto* z = tensor_out->data<T>();

    // Loop over all N elements of both operands, sum-reducing every B
    // consecutive pairs along the way, where B is the size of the last
    // (innermost) axis.
    auto&& d = tensor_x->dims();
    auto const N = tensor_x->numel();
    auto const B = d[d.size() - 1];

    for (int j = 0; j < N / B; j++) {
      T ss = 0;
      for (int i = 0; i < B; i++) ss += (*x_++) * (*y_++);
      z[j] = ss;
    }
#endif
  }
};

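// Backward kernel of the dot op: allocates the requested gradient outputs
// and delegates the computation to DotGradFunction for the element type.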
template <typename DeviceContext, typename T>
class DotGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* tensor_x = ctx.Input<Tensor>("X");
    auto* tensor_y = ctx.Input<Tensor>("Y");
    auto* tensor_dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* tensor_dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* tensor_dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

    if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
    if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());

    DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
                                        tensor_dx, tensor_dy, ctx);
  }
};

}  // namespace operators
}  // namespace paddle