// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"

namespace egr {

static inline paddle::experimental::DataType GetPromoteType(
    const std::string& op_name,
    const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                               kSlotSmallVectorSize>& amp_tensors_vector,
    const paddle::experimental::DataType& amp_dtype) {
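  // Pick the dtype the op's inputs should be promoted to: start from the
  // configured AMP dtype and fall back to FP32 when the inputs (subject to
  // the per-op special cases below) require full precision.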
  auto dst_type = amp_dtype;
  if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
      "float16") {
    if (op_name == "batch_norm" || op_name == "layer_norm" ||
        op_name == "sync_batch_norm") {
      if (amp_tensors_vector[0][0].dtype() ==
          paddle::experimental::DataType::FLOAT32) {
        dst_type = paddle::experimental::DataType::FLOAT32;
      }
    } else if (op_name == "fused_attention") {
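      // For fused_attention, promote to FP32 if any input slot other than
      // 3, 4, 9 and 10 holds an FP32 tensor; those four slots are excluded
      // from the promotion decision.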
      for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
        if (i != 3 && i != 4 && i != 9 && i != 10) {
          if (amp_tensors_vector[i][0].dtype() ==
              paddle::experimental::DataType::FLOAT32) {
            dst_type = paddle::experimental::DataType::FLOAT32;
            break;
          }
        }
      }
    } else if (op_name == "fused_feedforward") {
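      // For fused_feedforward, the same rule applies but slots 7, 8, 9 and
      // 10 are excluded from the promotion decision.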
      for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
        if (i != 7 && i != 8 && i != 9 && i != 10) {
          if (amp_tensors_vector[i][0].dtype() ==
              paddle::experimental::DataType::FLOAT32) {
            dst_type = paddle::experimental::DataType::FLOAT32;
            break;
          }
        }
      }
    } else {
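      // Default rule: promote the op to FP32 as soon as any FP32 input
      // tensor is found.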
      for (const auto& tensors : amp_tensors_vector) {
        for (const auto& tensor : tensors) {
          if (tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
            dst_type = tensor.dtype();
            break;
          }
        }
      }
    }
  } else {
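    // For AMP dtypes other than float16 (e.g. bfloat16), apply the default
    // rule to every op: promote to FP32 if any input tensor is FP32.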
    for (const auto& tensors : amp_tensors_vector) {
      for (const auto& tensor : tensors) {
        if (tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
          dst_type = tensor.dtype();
          break;
        }
      }
    }
  }
  // NOTE(juncai): moving_average_abs_max_scale only considers the dtype of
  // input(X)
  if (op_name == "moving_average_abs_max_scale") {
    if (amp_tensors_vector[0][0].dtype() ==
        paddle::experimental::DataType::FLOAT16) {
      dst_type = paddle::experimental::DataType::FLOAT16;
    }
  }
  return dst_type;
}

inline paddle::experimental::DataType GetAmpDestDtype(
    const std::string& op_name,
    const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                               kSlotSmallVectorSize>& amp_tensors_vector) {
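  // Decide the destination dtype for an op under the current AMP settings:
  // O1 consults the allow/block/unsupported op lists and falls back to
  // GetPromoteType, while O2 casts to the low-precision dtype unless the op
  // is blocked or unsupported.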
  auto amp_dtype =
      egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype();
  auto amp_level = egr::Controller::Instance().GetAMPLevel();
  VLOG(6) << "AMP GetAmpDestDtype:"
          << " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level("
          << static_cast<int>(amp_level) << ").";
  if (amp_dtype == "float16") {
    if (amp_level == paddle::imperative::AmpLevel::O1) {
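      // O1: ops on the allow list run in FP16, ops on the block list or the
      // unsupported-FP16 list run in FP32, and everything else is promoted
      // based on its input dtypes.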
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableAllowOps()
              ->count(op_name)) {
        return paddle::experimental::DataType::FLOAT16;
      } else if (paddle::imperative::AmpOperators::Instance()
                     .GetMutableBlockOps()
                     ->count(op_name) ||
                 paddle::imperative::AmpOperators::Instance()
                     .GetMutableUnsupportedFp16Ops()
                     ->count(op_name)) {
        return paddle::experimental::DataType::FLOAT32;
      } else {
        auto dst_type = GetPromoteType(op_name,
                                       amp_tensors_vector,
                                       paddle::experimental::DataType::FLOAT16);
        if (dst_type == paddle::experimental::DataType::FLOAT16 &&
            paddle::imperative::AmpOperators::Instance()
                .GetMutableUnsupportedFp16Ops()
                ->count(op_name)) {
          dst_type = paddle::experimental::DataType::FLOAT32;
        }
        return dst_type;
      }
    } else if (amp_level == paddle::imperative::AmpLevel::O2) {
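      // O2: cast to FP16 by default, except for ops that are blocked or do
      // not support FP16.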
      auto dst_type = paddle::experimental::DataType::FLOAT16;
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableUnsupportedFp16Ops()
              ->count(op_name) ||
          paddle::imperative::AmpOperators::Instance()
              .GetMutableBlockOps()
              ->count(op_name)) {
        dst_type = paddle::experimental::DataType::FLOAT32;
      }
      return dst_type;
    }
  } else if (amp_dtype == "bfloat16") {
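    // The bfloat16 path mirrors the float16 logic above, using the
    // BF16-specific unsupported-op list.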
    if (amp_level == paddle::imperative::AmpLevel::O1) {
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableAllowOps()
              ->count(op_name)) {
        return paddle::experimental::DataType::BFLOAT16;
      } else if (paddle::imperative::AmpOperators::Instance()
                     .GetMutableBlockOps()
                     ->count(op_name)) {
        return paddle::experimental::DataType::FLOAT32;
      } else {
        auto dst_type =
            GetPromoteType(op_name,
                           amp_tensors_vector,
                           paddle::experimental::DataType::BFLOAT16);
        if (dst_type == paddle::experimental::DataType::BFLOAT16 &&
            paddle::imperative::AmpOperators::Instance()
                .GetMutableUnsupportedBf16Ops()
                ->count(op_name)) {
          dst_type = paddle::experimental::DataType::FLOAT32;
        }
        return dst_type;
      }
    } else if (amp_level == paddle::imperative::AmpLevel::O2) {
      auto dst_type = paddle::experimental::DataType::BFLOAT16;
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableUnsupportedBf16Ops()
              ->count(op_name) ||
          paddle::imperative::AmpOperators::Instance()
              .GetMutableBlockOps()
              ->count(op_name)) {
        dst_type = paddle::experimental::DataType::FLOAT32;
      }
      return dst_type;
    }
  }
  return paddle::experimental::DataType::FLOAT32;
}
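
// A minimal, illustrative call site for GetAmpDestDtype. The op name and the
// input tensors x and y below are hypothetical; real call sites build
// amp_tensors_vector from the op's actual inputs and may differ:
//
//   paddle::small_vector<std::vector<paddle::experimental::Tensor>,
//                        kSlotSmallVectorSize>
//       amp_tensors_vector = {{x}, {y}};
//   auto amp_dst_dtype =
//       egr::GetAmpDestDtype("matmul_v2", amp_tensors_vector);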

}  // namespace egr