amp_utils.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"

namespace egr {

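// Decides the dtype an op should be promoted to under AMP: keeps amp_dtype
// unless some relevant FP32 input forces a fallback to FP32, with special
// cases for the norm ops, fused_attention, fused_feedforward, and
// moving_average_abs_max_scale.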
static inline phi::DataType GetPromoteType(
    const std::string& op_name,
    const paddle::small_vector<std::vector<paddle::Tensor>,
                               kSlotSmallVectorSize>& amp_tensors_vector,
    const phi::DataType& amp_dtype) {
  auto dst_type = amp_dtype;
  if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
      "float16") {
    if (op_name == "batch_norm" || op_name == "layer_norm" ||
        op_name == "sync_batch_norm") {
      if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
        dst_type = phi::DataType::FLOAT32;
      }
    } else if (op_name == "fused_attention") {
      // Promote to FP32 if any input other than those at indices 3, 4, 9,
      // and 10 is FP32.
      for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
        if (i != 3 && i != 4 && i != 9 && i != 10) {
          if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
            dst_type = phi::DataType::FLOAT32;
41 42 43 44
            break;
          }
        }
      }
    } else if (op_name == "fused_feedforward") {
      // Promote to FP32 if any input other than those at indices 7, 8, 9,
      // and 10 is FP32.
      for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
        if (i != 7 && i != 8 && i != 9 && i != 10) {
          if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
            dst_type = phi::DataType::FLOAT32;
            break;
          }
        }
      }
    } else {
      for (const auto& tensors : amp_tensors_vector) {
        for (const auto& tensor : tensors) {
          if (tensor.dtype() == phi::DataType::FLOAT32) {
            dst_type = tensor.dtype();
            break;
          }
        }
      }
    }
  } else {
    for (const auto& tensors : amp_tensors_vector) {
      for (const auto& tensor : tensors) {
        if (tensor.dtype() == phi::DataType::FLOAT32) {
          dst_type = tensor.dtype();
          break;
        }
      }
    }
  }
  // NOTE(juncai): moving_average_abs_max_scale only considers the dtype of
  // input(X).
  if (op_name == "moving_average_abs_max_scale") {
    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT16) {
      dst_type = phi::DataType::FLOAT16;
    }
  }
  return dst_type;
}

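// Keeps the chosen AMP dtype only if at least one input tensor lives on a
// place that supports low-precision kernels (GPU, CUDA-pinned memory, XPU,
// or a custom device); otherwise falls back to FP32.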
inline phi::DataType GetDtypeWithPlace(
    const std::string& op_name,
    const paddle::small_vector<std::vector<paddle::Tensor>,
                               kSlotSmallVectorSize>& amp_tensors_vector,
    const phi::DataType amp_dtype) {
  if (amp_dtype == phi::DataType::FLOAT32) {
    return amp_dtype;
  }
  bool is_right_place = false;
  for (const auto& tensors : amp_tensors_vector) {
    for (const auto& tensor : tensors) {
      auto place = tensor.place();
      is_right_place = (paddle::platform::is_gpu_place(place) ||
                        paddle::platform::is_cuda_pinned_place(place) ||
                        paddle::platform::is_xpu_place(place) ||
                        paddle::platform::is_custom_place(place));
      if (is_right_place) {
        break;
      }
    }
    // Stop once a tensor on a supported device is found, so that a later
    // tensor group cannot reset the flag.
    if (is_right_place) {
      break;
    }
  }

  if (!is_right_place) {
    VLOG(6) << "Change " << op_name << "'s AMP type from " << amp_dtype
            << " to FP32";
    return phi::DataType::FLOAT32;
  }
  return amp_dtype;
}

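// Main entry point for eager-mode AMP: combines the AMP level, the
// allow/block op lists, the promote rule, the per-dtype unsupported-op list,
// and the device check to decide which dtype op_name should run in.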
inline phi::DataType GetAmpDestDtype(
    const std::string& op_name,
    const paddle::small_vector<std::vector<paddle::Tensor>,
                               kSlotSmallVectorSize>& amp_tensors_vector) {
  auto amp_level = egr::Controller::Instance().GetAMPLevel();
  auto amp_setting_dtype =
      egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
  auto dst_type = amp_setting_dtype;

  bool use_promote = true;
  if (amp_level == paddle::imperative::AmpLevel::O2) {
    use_promote =
        egr::Controller::Instance().GetCurrentTracer()->GetUsePromote();
  }

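  // With promotion enabled, the allow list pins an op to the low-precision
  // dtype and the block list pins it to FP32; all other ops fall through to
  // the promote rule (or directly to FP32 under the OD level).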
  if (use_promote) {
    if (paddle::imperative::AmpOperators::Instance()
            .GetMutableAllowOps()
            ->count(op_name)) {
      dst_type = amp_setting_dtype;
    } else if (paddle::imperative::AmpOperators::Instance()
                   .GetMutableBlockOps()
                   ->count(op_name)) {
      dst_type = phi::DataType::FLOAT32;
    } else {
      if (amp_level == paddle::imperative::AmpLevel::OD) {
        dst_type = phi::DataType::FLOAT32;
      } else {
        dst_type =
            GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
      }
    }
  } else {
    // use_promote can be set to false only for O2 training.
    if (paddle::imperative::AmpOperators::Instance()
            .GetMutableBlockOps()
            ->count(op_name)) {
      dst_type = phi::DataType::FLOAT32;
    }
  }

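  // Even after promotion, an op listed as unsupported for the chosen
  // low-precision dtype is forced back to FP32.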
  if (dst_type == amp_setting_dtype &&
      (paddle::imperative::AmpOperators::Instance()
           .GetMutableUnsupportedOps(amp_setting_dtype)
           ->count(op_name))) {
    dst_type = phi::DataType::FLOAT32;
  }

  dst_type = GetDtypeWithPlace(op_name, amp_tensors_vector, dst_type);
  VLOG(6) << "AMP GetAmpDestDtype:"
          << " op(" << op_name << ") amp_dtype(" << dst_type << ") amp_level("
          << static_cast<int>(amp_level) << ").";
  return dst_type;
}

}  // namespace egr
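
// Usage sketch (hypothetical call site; in practice the auto-generated eager
// op wrappers call GetAmpDestDtype before casting their inputs):
//
//   paddle::small_vector<std::vector<paddle::Tensor>, egr::kSlotSmallVectorSize>
//       amp_tensors_vector = {{x}, {y}};
//   auto dst_dtype = egr::GetAmpDestDtype("matmul_v2", amp_tensors_vector);
//   // ... cast x and y to dst_dtype, then dispatch the kernel.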