// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/amp_auto_cast.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"

namespace paddle {
namespace imperative {

class VarBase;

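// AutoCastGuard is a RAII guard that switches the tracer to the given AMP
// level on construction and restores the previous level on destruction.
// A minimal usage sketch (assuming a current tracer is available):
//
//   {
//     AutoCastGuard guard(imperative::GetCurrentTracer(), AmpLevel::O0);
//     // ops traced in this scope are not auto-casted
//   }  // the previous AMP level is restored here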
AutoCastGuard::AutoCastGuard(std::shared_ptr<Tracer> tracer, AmpLevel level)
    : tracer_(tracer) {
  pre_amp_level_ = tracer_->GetAmpLevel();

  if (pre_amp_level_ != level) {
    tracer_->SetAmpLevel(level);
  }
}

AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); }

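// Scan all registered kernels once and record every op that has no fp16
// kernel on GPU/XPU; such ops can never run in fp16 and are treated as
// unsupported during auto-casting.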
AmpOperators::AmpOperators()
    : allow_ops_(new std::unordered_set<std::string>()),
      block_ops_(new std::unordered_set<std::string>()),
      unsupported_fp16_ops_(new std::unordered_set<std::string>()) {
  auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
  auto fp16_dtype = framework::proto::VarType::FP16;
  for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) {
    bool supported = false;
    for (auto& kernel_type : it->second) {
      if ((platform::is_gpu_place(kernel_type.first.place_) ||
           platform::is_xpu_place(kernel_type.first.place_)) &&
          kernel_type.first.data_type_ == fp16_dtype) {
        supported = true;
      }
    }
    if (!supported) {
      unsupported_fp16_ops_->insert(it->first);
    }
  }
}

AmpOperators::~AmpOperators() {}

AmpOperators& AmpOperators::Instance() {
  static AmpOperators instance;
  return instance;
}

std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableAllowOps() {
  return allow_ops_;
}

std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableBlockOps() {
  return block_ops_;
}

std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedFp16Ops() {
  return unsupported_fp16_ops_;
}

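// Print the allow / block / unsupported-fp16 op sets, mainly for debugging.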
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
  os << "allow ops: ";
  auto allow_ops = ops.GetMutableAllowOps();
  std::copy((*allow_ops).begin(), (*allow_ops).end(),
            std::ostream_iterator<std::string>(os, " "));
  os << "\n";
  os << "block ops: ";
  auto block_ops = ops.GetMutableBlockOps();
  std::copy((*block_ops).begin(), (*block_ops).end(),
            std::ostream_iterator<std::string>(os, " "));
  os << "\n";
  os << "unsupported fp16 ops: ";
  auto unsupported_fp16_ops = ops.GetMutableUnsupportedFp16Ops();
  std::copy((*unsupported_fp16_ops).begin(), (*unsupported_fp16_ops).end(),
            std::ostream_iterator<std::string>(os, " "));
  return os;
}

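// Helper that returns a human-readable dtype name, used in the VLOG messages
// below.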
template <typename VarType>
inline std::string GetDtypeStr(const std::shared_ptr<VarType>& var) {
  return framework::DataTypeToString(GetDataType<VarType>(var));
}
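// A var is a casting candidate only when it lives on a GPU, CUDA-pinned, or
// XPU place and its dtype is fp32 or fp16; everything else is left untouched.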
template <typename VarType>
inline bool NeedCast(const std::shared_ptr<VarType>& var) {
  auto place = GetPlace(var);
  auto data_type = GetDataType<VarType>(var);
  if (paddle::platform::is_gpu_place(place) ||
      paddle::platform::is_cuda_pinned_place(place) ||
      paddle::platform::is_xpu_place(place)) {
    // CUDAPinnedPlace is added for VarBase created by DataLoader.
    if (data_type == paddle::framework::proto::VarType::FP32 ||
        data_type == paddle::framework::proto::VarType::FP16) {
      return true;
    }
  }
  return false;
}

// NOTE: Trace a cast op, so if a var is cast from fp32 to fp16, then its grad
// var will be cast back from fp16 to fp32 during the backward phase.
template <typename VarType>
static inline std::shared_ptr<VarType> CastToType(
    const std::shared_ptr<VarType>& var,
    const framework::proto::VarType::Type dst_type) {
  const auto& tracer = imperative::GetCurrentTracer();
  imperative::NameVarMap<VarType> ins = {{"X", {var}}};
  framework::AttributeMap attrs = {{"in_dtype", GetDataType<VarType>(var)},
                                   {"out_dtype", dst_type}};
  auto out =
      std::shared_ptr<VarType>(new VarType(tracer->GenerateUniqueName()));
  imperative::NameVarMap<VarType> outs = {{"Out", {out}}};

  {
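    // Trace the cast op with AMP disabled (O0) so that the cast itself is not
    // auto-casted again.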
    AutoCastGuard guard(tracer, AmpLevel::O0);
    tracer->TraceOp("cast", ins, outs, std::move(attrs));
  }

  return out;
}
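// Cast the var to fp16 if it needs casting and is not already fp16;
// CastToFP32 below is the symmetric fp32 variant.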
template <typename VarType>
static inline std::shared_ptr<VarType> CastToFP16(
    const std::shared_ptr<VarType>& var) {
  auto dst_type = framework::proto::VarType::FP16;
  if (NeedCast(var) && (GetDataType<VarType>(var) != dst_type)) {
    return CastToType(var, dst_type);
  }
  return var;
}

template <typename VarType>
static inline std::shared_ptr<VarType> CastToFP32(
    const std::shared_ptr<VarType>& var) {
  auto dst_type = framework::proto::VarType::FP32;
  if (NeedCast(var) && (GetDataType<VarType>(var) != dst_type)) {
    return CastToType(var, dst_type);
  }
  return var;
}

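// Promotion rule: if any input is fp32 the op is promoted to fp32, otherwise
// it stays in fp16; moving_average_abs_max_scale is special-cased below.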
template <typename VarType>
static inline framework::proto::VarType::Type GetPromoteType(
    const std::string& op_type, const NameVarMap<VarType>& ins) {
  auto dst_type = framework::proto::VarType::FP16;
  for (const auto& pair : ins) {
    for (const auto& var : pair.second) {
      if (GetDataType<VarType>(var) == framework::proto::VarType::FP32) {
        dst_type = GetDataType<VarType>(var);
        break;
      }
    }
  }

  // NOTE(juncai): moving_average_abs_max_scale only considers the
  // dtype of input(X).
  if (op_type == "moving_average_abs_max_scale") {
    for (const auto& pair : ins) {
      if (pair.first == "X" &&
          GetDataType<VarType>(pair.second.front()) ==
              framework::proto::VarType::FP16) {
        dst_type = framework::proto::VarType::FP16;
      }
    }
  }

  return dst_type;
}

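// AMP O1 auto-casting, typically invoked by the tracer before an op runs:
// allow-listed ops get fp16 inputs, block-listed ops get fp32 inputs, and all
// remaining ops are promoted according to GetPromoteType. A rough call sketch
// (the op name and input map are illustrative, not taken from this file):
//
//   auto amp_ins = AutoCastInputs<VarBase>("matmul_v2", ins);
//   tracer->TraceOp("matmul_v2", amp_ins, outs, attrs);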
template <typename VarType>
NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
                                   const NameVarMap<VarType>& ins) {
  NameVarMap<VarType> new_ins(ins);
  if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
    for (auto& pair : new_ins) {
      // NOTE(zhiqiu): batch_norm and layer_norm only support fp16 for input X.
      if ((op_type == "batch_norm" || op_type == "layer_norm" ||
           op_type == "sync_batch_norm") &&
          pair.first != "X") {
        continue;
      }

      if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
        if (pair.first == "LnScale" || pair.first == "LnBias" ||
            pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
            pair.first == "Ln1Scale" || pair.first == "Ln1Bias") {
          continue;
        }
      }

      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float16";
      for (auto& var : pair.second) {
        var = CastToFP16<VarType>(var);
      }
    }
    return new_ins;
  } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
    for (auto& pair : new_ins) {
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float";
      for (auto& var : pair.second) {
        var = CastToFP32<VarType>(var);
      }
    }
    return new_ins;
  } else {
    auto dst_type = GetPromoteType<VarType>(op_type, ins);

    // NOTE(zhiqiu): if the op has no fp16 kernel, fall back to fp32.
    if (dst_type == framework::proto::VarType::FP16 &&
        AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(
            op_type)) {
      dst_type = framework::proto::VarType::FP32;
    }
    for (auto& pair : new_ins) {
      // NOTE(zhiqiu): batch_norm and layer_norm only support fp16 for input X.
      if ((op_type == "batch_norm" || op_type == "layer_norm" ||
           op_type == "sync_batch_norm") &&
          pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
        continue;
      }
      if ((op_type == "fused_attention" || op_type == "fused_feedforward") &&
          dst_type == framework::proto::VarType::FP32) {
        if (pair.first != "LnScale" && pair.first != "LnBias" &&
            pair.first != "Ln2Scale" && pair.first != "Ln2Bias" &&
            pair.first != "Ln1Scale" && pair.first != "Ln1Bias") {
          continue;
        }
      }
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to "
              << framework::DataTypeToString(dst_type);
      for (auto& var : pair.second) {
        var = (dst_type == framework::proto::VarType::FP32
                   ? CastToFP32<VarType>(var)
                   : CastToFP16<VarType>(var));
      }
    }
    return new_ins;
  }
  return new_ins;
}
template NameVarMap<VarBase> AutoCastInputs<VarBase>(
    const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerTensor> AutoCastInputs<egr::EagerTensor>(
    const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins);
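// Pure fp16 (AMP O2) casting: cast every input to fp16 unless the op is
// block-listed or has no fp16 kernel (then fp32 is used instead); a few
// special inputs (e.g. LayerNorm scales/biases, non-X inputs of batch_norm)
// are also skipped below.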
template <typename VarType>
NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
                                       const NameVarMap<VarType>& ins) {
  NameVarMap<VarType> new_ins(ins);
  auto dst_type = framework::proto::VarType::FP16;
  if (AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(op_type) ||
      AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
    dst_type = framework::proto::VarType::FP32;
  }
  for (auto& pair : new_ins) {
    // NOTE: The run_program OP has only a FP32 kernel. In dy2stat pure fp16
    // training, the inputs of the run_program OP have already been cast
    // correctly, so casting is skipped for it here.
    if (op_type == "run_program") {
      continue;
    }

    if ((op_type == "batch_norm" || op_type == "layer_norm" ||
         op_type == "sync_batch_norm") &&
        pair.first != "X") {
      continue;
    }
    if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
      if (pair.first == "LnScale" || pair.first == "LnBias" ||
          pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
          pair.first == "Ln1Scale" || pair.first == "Ln1Bias") {
        continue;
      }
    }
    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
            << GetDtypeStr(*pair.second.cbegin()) << " to "
            << framework::DataTypeToString(dst_type);
    for (auto& var : pair.second) {
      var = (dst_type == framework::proto::VarType::FP32
                 ? CastToFP32<VarType>(var)
                 : CastToFP16<VarType>(var));
    }
  }
  return new_ins;
}
template NameVarMap<VarBase> CastPureFp16Inputs<VarBase>(
    const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerTensor> CastPureFp16Inputs<egr::EagerTensor>(
    const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins);
}  // namespace imperative
}  // namespace paddle