// test_fused_adam_kernel.cc
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <random>
#include <type_traits>
#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/generator.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/backends/gpu/gpu_context.h"
#endif

#include "gtest/gtest.h"

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/phi/kernels/adam_kernel.h"
#include "paddle/phi/kernels/adamw_kernel.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/fused_adam_kernel.h"
#include "paddle/phi/kernels/gaussian_kernel.h"
#include "paddle/phi/kernels/legacy/reduce_max_kernel.h"

namespace phi {

template <typename T, typename Context>
auto GenerateRandomTensorVectors(
    const Context &ctx, const std::vector<std::vector<int64_t>> &shapes) {
  size_t n = shapes.size();
  std::vector<DenseTensor> tensors(n);
  for (size_t i = 0; i < n; ++i) {
47 48 49 50 51 52 53
    GaussianKernel<T, Context>(ctx,
                               shapes[i],
                               0.0f,
                               1.0f,
                               0,
                               phi::CppTypeToDataType<T>::Type(),
                               &tensors[i]);
54 55 56 57 58 59 60 61 62 63 64 65
  }
  return tensors;
}

template <typename T, typename Context>
auto GenerateConstantTensorVectors(
    const Context &ctx,
    const std::vector<std::vector<int64_t>> &shapes,
    T value) {
  size_t n = shapes.size();
  std::vector<DenseTensor> tensors(n);
  for (size_t i = 0; i < n; ++i) {
66 67
    FullKernel<T, Context>(
        ctx, shapes[i], value, phi::CppTypeToDataType<T>::Type(), &tensors[i]);
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
  }
  return tensors;
}

static auto ToConstTensorPtrVector(const std::vector<DenseTensor> &tensors) {
  std::vector<const DenseTensor *> results;
  for (const auto &t : tensors) {
    results.push_back(&t);
  }
  return results;
}

static auto ToMutableTensorPtrVector(
    std::vector<DenseTensor> &tensors) {  // NOLINT
  std::vector<DenseTensor *> results;
  for (auto &t : tensors) {
    results.push_back(&t);
  }
  return results;
}

static auto ToMetaTensorVector(const std::vector<DenseTensor> &tensors) {
  std::vector<MetaTensor> results;
  for (auto &t : tensors) {
    results.push_back(t);
  }
  return results;
}

static auto ToConstMetaTensorPtrVector(
    const std::vector<MetaTensor> &meta_tensors) {
  std::vector<const MetaTensor *> results;
  for (auto &t : meta_tensors) {
    results.push_back(&t);
  }
  return results;
}

static auto ToMutableMetaTensorPtrVector(
    std::vector<MetaTensor> &meta_tensors) {  // NOLINT
  std::vector<MetaTensor *> results;
  for (auto &t : meta_tensors) {
    results.push_back(&t);
  }
  return results;
}

// Holds one optimizer state for a list of tensor shapes. The state covers
// params, optional master params, moments, beta powers, the learning rate and
// the hyper-parameters. Update() advances it either through FusedAdamKernel
// or through the per-tensor Adam/AdamW kernels, so the two paths can be
// compared.
template <typename T, typename Context>
struct AdamInfo {
  const Context *ctx;  // Not owned; the referenced context must outlive this.
  std::vector<std::vector<int64_t>> shapes;

  std::vector<DenseTensor> params;
  // Only populated when multi_precision is true (params cast to MT).
  std::vector<DenseTensor> master_params;
  std::vector<DenseTensor> moment1s;
  std::vector<DenseTensor> moment2s;
  std::vector<DenseTensor> beta1_pows;
  std::vector<DenseTensor> beta2_pows;
  DenseTensor learning_rate;
  float beta1;
  float beta2;
  float weight_decay;
  float epsilon = 1e-6;
  bool multi_precision;
  bool use_adamw;
  int chunk_size = 4096;  // Only forwarded to the fused kernel.

  // Type used for moments, beta powers, the learning rate and master params.
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  // Allocates all state on `ctx_ref`:
  // - params come from GenerateRandomTensorVectors;
  // - moments are 0;
  // - the learning rate is a single [1] tensor holding 1e-3;
  // - the beta powers are per-param [1] tensors initialised to beta1 / beta2.
  AdamInfo(const Context &ctx_ref,
           const std::vector<std::vector<int64_t>> &shapes,
           float beta1,
           float beta2,
           float weight_decay,
           bool multi_precision,
           bool use_adamw)
      : ctx(&ctx_ref),
        shapes(shapes),
        beta1(beta1),
        beta2(beta2),
        weight_decay(weight_decay),
        multi_precision(multi_precision),
        use_adamw(use_adamw) {
    std::vector<std::vector<int64_t>> one_shapes(shapes.size(),
                                                 std::vector<int64_t>(1, 1));
    // Build the single learning-rate shape directly. Slicing one_shapes would
    // be UB (begin() + 1 past end) when `shapes` is empty.
    std::vector<std::vector<int64_t>> learning_rate_shapes(
        1, std::vector<int64_t>(1, 1));

    params = GenerateRandomTensorVectors<T, Context>(*ctx, shapes);
    learning_rate = GenerateConstantTensorVectors<MT, Context>(
        *ctx, learning_rate_shapes, 1e-3)[0];
    moment1s = GenerateConstantTensorVectors<MT, Context>(*ctx, shapes, 0);
    moment2s = GenerateConstantTensorVectors<MT, Context>(*ctx, shapes, 0);

    if (multi_precision) {
      master_params.resize(shapes.size());
      for (size_t i = 0; i < shapes.size(); ++i) {
        master_params[i] = Cast<T, Context>(
            *ctx, params[i], phi::CppTypeToDataType<MT>::Type());
      }
    }

    beta1_pows =
        GenerateConstantTensorVectors<MT, Context>(*ctx, one_shapes, beta1);
    beta2_pows =
        GenerateConstantTensorVectors<MT, Context>(*ctx, one_shapes, beta2);
  }

  // Applies one optimizer step using `grads`, one gradient per param.
  // With use_fused, all params go through FusedAdamKernel at once. Otherwise
  // each param is updated individually by AdamwDenseKernel (use_adamw) or
  // AdamDenseKernel.
  void Update(bool use_fused, const std::vector<DenseTensor> &grads) {
    if (use_fused) {
      UpdateWithFusedAdam(grads);
    } else {
      for (size_t j = 0; j < params.size(); ++j) {
        if (use_adamw) {
          UpdateWithAdamWBaseline(grads, j);
        } else {
          UpdateWithAdamBaseline(grads, j);
        }
      }
    }
  }

  // Returns an independent copy of `other` on the same context.
  // The constructor first allocates fresh state; this includes newly
  // generated random params. That state is then overwritten with copies of
  // `other`'s tensors. Copies are issued non-blocking; the Wait() at the end
  // synchronises the context before the copy is returned.
  static AdamInfo<T, Context> DeepCopy(const AdamInfo &other) {
    AdamInfo copied(*other.ctx,
                    other.shapes,
                    other.beta1,
                    other.beta2,
                    other.weight_decay,
                    other.multi_precision,
                    other.use_adamw);
    auto copy_tensor = [&other](const DenseTensor &x, DenseTensor *y) {
      Copy<Context>(*other.ctx, x, x.place(), false, y);
    };

    auto copy_tensors = [&other](const std::vector<DenseTensor> &xs,
                                 std::vector<DenseTensor> *ys) {
      // Guard against a size mismatch instead of writing past the end.
      ys->resize(xs.size());
      for (size_t i = 0; i < xs.size(); ++i) {
        Copy<Context>(*other.ctx, xs[i], xs[i].place(), false, &((*ys)[i]));
      }
    };

    copy_tensors(other.params, &copied.params);
    copy_tensors(other.master_params, &copied.master_params);
    copy_tensors(other.moment1s, &copied.moment1s);
    copy_tensors(other.moment2s, &copied.moment2s);
    copy_tensors(other.beta1_pows, &copied.beta1_pows);
    copy_tensors(other.beta2_pows, &copied.beta2_pows);
    copy_tensor(other.learning_rate, &copied.learning_rate);
    copied.epsilon = other.epsilon;
    copied.chunk_size = other.chunk_size;
    other.ctx->Wait();
    return copied;
  }

 private:
  // One fused step over all params. First runs FusedAdamInferMeta on
  // MetaTensor views of the state, then FusedAdamKernel, which writes its
  // outputs back into the same member tensors.
  // NOTE(review): the MetaTensor()/paddle::none argument is assumed to be the
  // optional skip_update input, and the trailing `false` use_global_beta_pow
  // — confirm against fused_adam_kernel.h.
  void UpdateWithFusedAdam(const std::vector<DenseTensor> &grads) {
    auto param_metas = ToMetaTensorVector(params);
    auto grad_metas = ToMetaTensorVector(grads);
    auto master_param_metas = ToMetaTensorVector(master_params);
    auto moment1_metas = ToMetaTensorVector(moment1s);
    auto moment2_metas = ToMetaTensorVector(moment2s);
    auto beta1_pow_metas = ToMetaTensorVector(beta1_pows);
    auto beta2_pow_metas = ToMetaTensorVector(beta2_pows);

    FusedAdamInferMeta(ToConstMetaTensorPtrVector(param_metas),
                       ToConstMetaTensorPtrVector(grad_metas),
                       learning_rate,
                       ToConstMetaTensorPtrVector(moment1_metas),
                       ToConstMetaTensorPtrVector(moment2_metas),
                       ToConstMetaTensorPtrVector(beta1_pow_metas),
                       ToConstMetaTensorPtrVector(beta2_pow_metas),
                       multi_precision
                           ? paddle::make_optional(
                                 ToConstMetaTensorPtrVector(master_param_metas))
                           : paddle::none,
                       MetaTensor(),
                       beta1,
                       beta2,
                       epsilon,
                       chunk_size,
                       weight_decay,
                       use_adamw,
                       multi_precision,
                       false,
                       ToMutableMetaTensorPtrVector(param_metas),
                       ToMutableMetaTensorPtrVector(moment1_metas),
                       ToMutableMetaTensorPtrVector(moment2_metas),
                       ToMutableMetaTensorPtrVector(beta1_pow_metas),
                       ToMutableMetaTensorPtrVector(beta2_pow_metas),
                       ToMutableMetaTensorPtrVector(master_param_metas));

    FusedAdamKernel<T, Context>(
        *ctx,
        ToConstTensorPtrVector(params),
        ToConstTensorPtrVector(grads),
        learning_rate,
        ToConstTensorPtrVector(moment1s),
        ToConstTensorPtrVector(moment2s),
        ToConstTensorPtrVector(beta1_pows),
        ToConstTensorPtrVector(beta2_pows),
        multi_precision
            ? paddle::make_optional(ToConstTensorPtrVector(master_params))
            : paddle::none,
        paddle::none,
        beta1,
        beta2,
        epsilon,
        chunk_size,
        weight_decay,
        use_adamw,
        multi_precision,
        false,
        ToMutableTensorPtrVector(params),
        ToMutableTensorPtrVector(moment1s),
        ToMutableTensorPtrVector(moment2s),
        ToMutableTensorPtrVector(beta1_pows),
        ToMutableTensorPtrVector(beta2_pows),
        ToMutableTensorPtrVector(master_params));
  }

  // Reference AdamW step for param `idx`, updated in place.
  // NOTE(review): the positional args after epsilon are assumed to be
  // lr_ratio=1.0, coeff=weight_decay, with_decay=true, lazy_mode=false,
  // min_row_size_to_use_multithread=1000, multi_precision,
  // use_global_beta_pow=false — confirm against adamw_kernel.h.
  void UpdateWithAdamWBaseline(const std::vector<DenseTensor> &grads,
                               size_t idx) {
    AdamwDenseKernel<T, Context>(
        *ctx,
        params[idx],
        grads[idx],
        learning_rate,
        moment1s[idx],
        moment2s[idx],
        beta1_pows[idx],
        beta2_pows[idx],
        multi_precision ? paddle::make_optional(master_params[idx])
                        : paddle::none,
        paddle::none,
        beta1,
        beta2,
        epsilon,
        1.0,
        weight_decay,
        true,
        false,
        1000,
        multi_precision,
        false,
        &params[idx],
        &moment1s[idx],
        &moment2s[idx],
        &beta1_pows[idx],
        &beta2_pows[idx],
        multi_precision ? &master_params[idx] : nullptr);
  }

  // Reference Adam step for param `idx`, updated in place.
  // NOTE(review): the positional args after epsilon are assumed to be
  // lazy_mode=false, min_row_size_to_use_multithread=1000, multi_precision,
  // use_global_beta_pow=false — confirm against adam_kernel.h.
  void UpdateWithAdamBaseline(const std::vector<DenseTensor> &grads,
                              size_t idx) {
    AdamDenseKernel<T, Context>(
        *ctx,
        params[idx],
        grads[idx],
        learning_rate,
        moment1s[idx],
        moment2s[idx],
        beta1_pows[idx],
        beta2_pows[idx],
        multi_precision ? paddle::make_optional(master_params[idx])
                        : paddle::none,
        paddle::none,
        beta1,
        beta2,
        epsilon,
        false,
        1000,
        multi_precision,
        false,
        &params[idx],
        &moment1s[idx],
        &moment2s[idx],
        &beta1_pows[idx],
        &beta2_pows[idx],
        multi_precision ? &master_params[idx] : nullptr);
  }
};

template <typename T, typename Context>
auto MaxDiff(const Context &ctx,
             const DenseTensor &x_t,
             const DenseTensor &y_t) {
  using MT = typename AdamInfo<T, Context>::MT;
355
  auto mp_dtype = phi::CppTypeToDataType<MT>::Type();
356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391
  auto x = Cast<T, Context>(ctx, x_t, mp_dtype);
  auto y = Cast<T, Context>(ctx, y_t, mp_dtype);

  EXPECT_EQ(x.dims(), y.dims());
  DenseTensor diff, diff_reduced, diff_reduced_cpu;

  diff.Resize(x.dims());
  ctx.template Alloc<MT>(&diff);
  SubtractKernel<MT, Context>(ctx, x, y, &diff);
  AbsKernel<MT, Context>(ctx, diff, &diff);

  diff_reduced.Resize({1});
  ctx.template Alloc<MT>(&diff_reduced);
  MaxRawKernel<MT, Context>(
      ctx, diff, vectorize<int64_t>(x.dims()), false, true, &diff_reduced);

  diff_reduced_cpu.Resize(diff_reduced.dims());
  ctx.template HostAlloc<MT>(&diff_reduced_cpu);
  Copy<Context>(ctx, diff_reduced, CPUPlace(), true, &diff_reduced_cpu);
  EXPECT_EQ(diff_reduced_cpu.place(), CPUPlace());
  return diff_reduced_cpu.data<MT>()[0];
}

template <typename T, typename Context>
auto MaxDiff(const Context &ctx,
             const std::vector<DenseTensor> &xs,
             const std::vector<DenseTensor> &ys) {
  using MT = typename AdamInfo<T, Context>::MT;
  MT diff = 0;
  for (size_t i = 0; i < xs.size(); ++i) {
    diff = std::max<MT>(diff, MaxDiff<T, Context>(ctx, xs[i], ys[i]));
  }
  return diff;
}

template <typename T, typename PlaceType>
392 393 394 395 396 397 398 399 400
void TestFusedAdamBase(const std::vector<std::vector<int64_t>> &shapes,
                       float atol,
                       bool use_adamw,
                       bool multi_precision = false,
                       float beta1 = 0.9,
                       float beta2 = 0.99,
                       float weight_decay = 0.1,
                       size_t steps = 5,
                       uint64_t seed = 10) {
401
  const auto &ctx = *phi::DeviceContextPool::Instance().GetByPlace(PlaceType());
402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443
  using Context = typename std::remove_const<
      typename std::remove_pointer<decltype(&ctx)>::type>::type;
  ctx.GetGenerator()->SetCurrentSeed(seed);
  AdamInfo<T, Context> info1(
      ctx, shapes, beta1, beta2, weight_decay, multi_precision, use_adamw);
  auto info2 = AdamInfo<T, Context>::DeepCopy(info1);

  for (size_t i = 0; i < steps; ++i) {
    auto grads = GenerateRandomTensorVectors<T>(ctx, shapes);
    info1.Update(false, grads);
    info2.Update(true, grads);
  }

  using MT = typename decltype(info1)::MT;

#define PD_ADAM_TEST_COMP(__field, __dtype)                          \
  do {                                                               \
    MT __diff = MaxDiff<__dtype>(ctx, info1.__field, info2.__field); \
    EXPECT_LE(__diff, static_cast<MT>(atol))                         \
        << #__field << " has diff when use_adamw = " << use_adamw    \
        << " , multi_precision = " << multi_precision;               \
  } while (0)

  PD_ADAM_TEST_COMP(beta1_pows, MT);
  PD_ADAM_TEST_COMP(beta2_pows, MT);
  PD_ADAM_TEST_COMP(params, T);
  PD_ADAM_TEST_COMP(master_params, MT);
  PD_ADAM_TEST_COMP(moment1s, MT);
  PD_ADAM_TEST_COMP(moment2s, MT);
}

// Returns `n` one-dimensional shapes. Each shape's single extent is drawn
// uniformly from the inclusive range [low, high]; requires low <= high.
//
// `seed` defaults to a fresh value from std::random_device, which keeps the
// previous behaviour for existing callers. Pass a fixed seed to reproduce a
// failing run deterministically.
static auto GenerateRandomShapes(
    size_t n,
    uint64_t low,
    uint64_t high,
    std::random_device::result_type seed = std::random_device{}()) {
  std::default_random_engine engine(seed);
  std::uniform_int_distribution<uint64_t> dist(low, high);
  std::vector<std::vector<int64_t>> shapes(n);
  for (auto &shape : shapes) {
    // Make the unsigned -> signed conversion explicit; test extents are small.
    shape.push_back(static_cast<int64_t>(dist(engine)));
  }
  return shapes;
}

444
TEST(fused_adam, test_fp32_cpu) {
445 446 447
  auto shapes = GenerateRandomShapes(30, 10, 20);
  float atol = 0.0f;
  for (auto use_adamw : {false, true}) {
448
    TestFusedAdamBase<float, CPUPlace>(shapes, atol, use_adamw);
449 450 451 452
  }
}

#ifdef PADDLE_WITH_CUDA
// fp32 on GPU: 40 one-dimensional tensors with 0..(2 << 18) elements each.
// The fused result must match the per-tensor kernels exactly.
TEST(fused_adam, test_fp32_gpu) {
  const auto shapes = GenerateRandomShapes(40, 0, 2 << 18);
  constexpr float kTolerance = 0.0f;
  for (const bool use_adamw : {false, true}) {
    TestFusedAdamBase<float, GPUPlace>(shapes, kTolerance, use_adamw);
  }
}

// fp16 on GPU with multi_precision enabled (master params kept in MT).
// Uses a loose absolute tolerance of 5e-3.
TEST(fused_adam, test_fp16_gpu) {
  const auto shapes = GenerateRandomShapes(40, 0, 2 << 18);
  constexpr float kTolerance = 5e-3f;
  const bool multi_precision = true;
  for (const bool use_adamw : {false, true}) {
    TestFusedAdamBase<dtype::float16, GPUPlace>(
        shapes, kTolerance, use_adamw, multi_precision);
  }
}
#endif

}  // namespace phi