fused_dropout_test.h 6.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/memory.h"
26
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
27
#include "paddle/fluid/string/printf.h"
28
#include "paddle/pten/kernels/funcs/math_function.h"
29 30 31 32 33 34

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace memory = paddle::memory;

USE_OP(dropout);
35 36 37 38 39 40
USE_OP(layer_norm);

template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124

/**
 * @brief call paddle dropout op
 */
template <typename T>
void Dropout(const std::vector<T> &x, const framework::DDim &x_dim,
             std::vector<T> *out, std::vector<uint8_t> *mask,
             const platform::CUDADeviceContext &ctx, uint64_t seed,
             float dropout_prob, bool is_upscale_in_train, bool is_test) {
  framework::Scope scope;
  auto var_x = scope.Var("X");
  auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
  framework::TensorFromVector(x, ctx, tensor_x);
  tensor_x->Resize(x_dim);

  auto var_out = scope.Var("Out");
  auto tensor_out = var_out->GetMutable<framework::LoDTensor>();

  auto var_mask = scope.Var("Mask");
  auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();

  framework::AttributeMap attrs;
  attrs.insert({"fix_seed", 1});
  attrs.insert({"seed", static_cast<int>(seed)});
  attrs.insert({"dropout_prob", dropout_prob});
  if (is_upscale_in_train) {
    attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
  }

  if (is_test) {
    attrs.insert({"is_test", true});
  }

  auto op = framework::OpRegistry::CreateOp(
      "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs);
  op->Run(scope, ctx.GetPlace());

  framework::TensorToVector<T>(*tensor_out, ctx, out);
  if (!is_test) {
    framework::TensorToVector<uint8_t>(*tensor_mask, ctx, mask);
  }
  ctx.Wait();
}

/**
 * @brief call paddle dropout_grad op
 */
template <typename T>
void DropoutGrad(std::vector<T> *dx, const framework::DDim &x_dim,
                 const std::vector<T> &dout, const std::vector<uint8_t> &mask,
                 const platform::CUDADeviceContext &ctx, float dropout_prob,
                 bool is_upscale_in_train) {
  framework::Scope scope;
  const size_t n = x_dim[0] * x_dim[1];
  auto var_out = scope.Var("DOut");
  auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
  framework::TensorFromVector(dout, ctx, tensor_out);
  tensor_out->Resize(x_dim);

  auto var_mask = scope.Var("Mask");
  auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
  framework::TensorFromVector(mask, ctx, tensor_mask);
  tensor_mask->Resize(x_dim);

  auto var_dx = scope.Var("DX");
  auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();

  framework::AttributeMap attrs;
  attrs.insert({"dropout_prob", dropout_prob});
  attrs.insert({"is_test", false});
  if (is_upscale_in_train) {
    attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
  } else {
    attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")});
  }

  auto op = framework::OpRegistry::CreateOp(
      "dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}},
      {{"X@GRAD", {"DX"}}}, attrs);
  op->Run(scope, ctx.GetPlace());

  framework::TensorToVector(*tensor_dx, ctx, dx);
  ctx.Wait();
}
125

126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
/**
 * @brief call paddle layer_norm op
 */
template <typename T>
void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,
               const std::vector<LayerNormParamType<T>> &bias,
               const std::vector<T> &x,
               std::vector<LayerNormParamType<T>> *means,
               std::vector<LayerNormParamType<T>> *vars, std::vector<T> *y,
               const float epsilon, const int rows, const int cols,
               const platform::CUDADeviceContext &ctx) {
  framework::Scope scope;
  auto place = ctx.GetPlace();
  if (scale.size() > 0) {
    auto var_scale = scope.Var("Scale");
    auto tensor_scale = var_scale->GetMutable<framework::LoDTensor>();
    framework::TensorFromVector(scale, ctx, tensor_scale);
    tensor_scale->Resize({cols});
  }

  if (bias.size() > 0) {
    auto var_bias = scope.Var("Bias");
    auto tensor_bias = var_bias->GetMutable<framework::LoDTensor>();
    framework::TensorFromVector(bias, ctx, tensor_bias);
    tensor_bias->Resize({cols});
  }

  auto var_x = scope.Var("X");
  auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
  framework::TensorFromVector(x, ctx, tensor_x);
  tensor_x->Resize({rows, cols});

  auto var_y = scope.Var("Y");
  auto tensor_y = var_y->GetMutable<framework::LoDTensor>();

  auto var_mean = scope.Var("Mean");
  auto tensor_mean = var_mean->GetMutable<framework::LoDTensor>();

  auto var_variance = scope.Var("Variance");
  auto tensor_variance = var_variance->GetMutable<framework::LoDTensor>();

  framework::AttributeMap attrs;
  attrs.insert({"epsilon", epsilon});

  auto op = framework::OpRegistry::CreateOp(
      "layer_norm", {{"X", {"X"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}},
      {{"Y", {"Y"}}, {"Mean", {"Mean"}}, {"Variance", {"Variance"}}}, attrs);
  op->Run(scope, place);
  framework::TensorToVector(*tensor_y, ctx, y);
  framework::TensorToVector(*tensor_mean, ctx, means);
  framework::TensorToVector(*tensor_variance, ctx, vars);
  ctx.Wait();
}

180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
template <typename T>
inline void ReduceSum(const std::vector<T> &dout, std::vector<T> *dbias,
                      const int rows, const int cols) {
  for (int j = 0; j < cols; j++) {
    std::vector<T> tmp_dbias(rows);
    for (int i = 0; i < rows; i++) {
      tmp_dbias[i] = dout[i * cols + j];
    }
    int tmp_rows = rows / 2;
    while (tmp_rows) {
      for (int i = 0; i < tmp_rows; i++) {
        tmp_dbias[i] += tmp_dbias[i + tmp_rows];
      }
      tmp_rows /= 2;
    }
    (*dbias)[j] = tmp_dbias[0];
  }
}