// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/amp_auto_cast.h"

#include <memory>
#include <string>

#include "paddle/fluid/imperative/tracer.h"

namespace paddle {
namespace imperative {

class VarBase;

27 28
AmpOperators::AmpOperators()
    : allow_ops_(new std::unordered_set<std::string>()),
29 30 31 32 33 34 35
      block_ops_(new std::unordered_set<std::string>()),
      unsupported_fp16_ops_(new std::unordered_set<std::string>()) {
  auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
  auto fp16_dtype = framework::proto::VarType::FP16;
  for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) {
    bool supported = false;
    for (auto& kernel_type : it->second) {
T
taixiurong 已提交
36 37
      if ((platform::is_gpu_place(kernel_type.first.place_) ||
           platform::is_xpu_place(kernel_type.first.place_)) &&
38 39 40 41 42 43 44 45 46 47
          kernel_type.first.data_type_ == fp16_dtype) {
        supported = true;
      }
    }
    if (!supported) {
      unsupported_fp16_ops_->insert(it->first);
    }
  }
}

48 49 50 51 52 53 54
AmpOperators::~AmpOperators() {}

// Returns the process-wide singleton. The function-local static is
// constructed lazily on first call and is thread-safe under C++11
// "magic statics".
AmpOperators& AmpOperators::Instance() {
  static AmpOperators instance;
  return instance;
}

// Returns a shared, mutable handle to the set of op names AMP always runs in
// fp16; mutations through the handle are seen by AutoCastInputs.
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableAllowOps() {
  return allow_ops_;
}

// Returns a shared, mutable handle to the set of op names AMP always runs in
// fp32; mutations through the handle are seen by AutoCastInputs.
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableBlockOps() {
  return block_ops_;
}

// Returns a shared, mutable handle to the set of ops with no registered fp16
// kernel (populated in the constructor); such ops fall back to fp32.
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedFp16Ops() {
  return unsupported_fp16_ops_;
}

70 71 72 73 74
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
  os << "allow ops: ";
  auto allow_ops = ops.GetMutableAllowOps();
  std::copy((*allow_ops).begin(), (*allow_ops).end(),
            std::ostream_iterator<std::string>(os, " "));
75
  os << "\n";
76 77 78 79
  os << "block ops: ";
  auto block_ops = ops.GetMutableBlockOps();
  std::copy((*block_ops).begin(), (*block_ops).end(),
            std::ostream_iterator<std::string>(os, " "));
80 81 82 83 84
  os << "\n";
  os << "unsupported fp16 ops: ";
  auto unsupported_fp16_ops = ops.GetMutableUnsupportedFp16Ops();
  std::copy((*unsupported_fp16_ops).begin(), (*unsupported_fp16_ops).end(),
            std::ostream_iterator<std::string>(os, " "));
85 86 87
  return os;
}

// Returns the human-readable name of `var`'s data type (used in VLOG output).
inline std::string GetDtypeStr(
    const std::shared_ptr<imperative::VarBase>& var) {
  return framework::DataTypeToString(var->DataType());
}

inline bool NeedCast(const std::shared_ptr<VarBase>& var) {
L
Leo Chen 已提交
94
  if (platform::is_gpu_place(var->Place()) ||
T
taixiurong 已提交
95 96
      platform::is_cuda_pinned_place(var->Place()) ||
      platform::is_xpu_place(var->Place())) {
L
Leo Chen 已提交
97 98 99 100 101
    // CudaPinndePlace is added for varbase created by dataloader
    if (var->DataType() == framework::proto::VarType::FP32 ||
        var->DataType() == framework::proto::VarType::FP16) {
      return true;
    }
102
  }
L
Leo Chen 已提交
103
  return false;
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
}

// NOTE: Trace a cast op, so if a var is casted from fp32 to fp16, then the grad
// var will be cast back from fp16 to fp32 during backward phase.
//
// Builds a "cast" op with input X = `var` and a freshly named output, traces
// it through the current tracer, and returns the output var.
static inline std::shared_ptr<imperative::VarBase> CastToType(
    const std::shared_ptr<VarBase>& var,
    const framework::proto::VarType::Type dst_type) {
  const auto& tracer = imperative::GetCurrentTracer();
  imperative::NameVarBaseMap ins = {{"X", {var}}};
  framework::AttributeMap attrs = {{"in_dtype", var->DataType()},
                                   {"out_dtype", dst_type}};
  // The output varbase gets a fresh unique name from the tracer.
  auto out = std::shared_ptr<imperative::VarBase>(
      new imperative::VarBase(tracer->GenerateUniqueName()));
  imperative::NameVarBaseMap outs = {{"Out", {out}}};

  {
    // NOTE(review): guard level 0 appears to switch auto-cast off while the
    // cast op itself is traced (so its inputs are not themselves auto-cast) —
    // confirm against AutoCastGuard's definition.
    AutoCastGuard guard(tracer, 0);
    tracer->TraceOp("cast", ins, outs, std::move(attrs));
  }

  return out;
}

// Returns `var` cast to fp16 via a traced cast op, or `var` itself when it is
// not an AMP candidate (see NeedCast) or already holds fp16 data.
static inline std::shared_ptr<imperative::VarBase> CastToFP16(
    const std::shared_ptr<VarBase>& var) {
  constexpr auto dst_type = framework::proto::VarType::FP16;
  if (!NeedCast(var) || var->DataType() == dst_type) {
    return var;
  }
  return CastToType(var, dst_type);
}

// Returns `var` cast to fp32 via a traced cast op, or `var` itself when it is
// not an AMP candidate (see NeedCast) or already holds fp32 data.
static inline std::shared_ptr<imperative::VarBase> CastToFP32(
    const std::shared_ptr<VarBase>& var) {
  constexpr auto dst_type = framework::proto::VarType::FP32;
  if (!NeedCast(var) || var->DataType() == dst_type) {
    return var;
  }
  return CastToType(var, dst_type);
}

// Decides the common dtype for an op outside both AMP lists: fp32 if any
// input var holds fp32 data, otherwise fp16.
static inline framework::proto::VarType::Type GetPromoteType(
    const std::string& op_type, const NameVarBaseMap& ins) {
  auto dst_type = framework::proto::VarType::FP16;
  for (const auto& name_and_vars : ins) {
    for (const auto& var : name_and_vars.second) {
      if (var->DataType() == framework::proto::VarType::FP32) {
        dst_type = framework::proto::VarType::FP32;
        break;
      }
    }
  }

  // NOTE(juncai): moving_average_abs_max_scale only considers the
  // dtype of input(X)
  if (op_type == "moving_average_abs_max_scale") {
    for (const auto& name_and_vars : ins) {
      if (name_and_vars.first == "X" &&
          name_and_vars.second.front()->DataType() ==
              framework::proto::VarType::FP16) {
        dst_type = framework::proto::VarType::FP16;
      }
    }
  }

  return dst_type;
}

// Casts the inputs of `op_type` per the AMP lists: allow-listed ops run in
// fp16, block-listed ops in fp32, and every other op is promoted to a common
// dtype (see GetPromoteType), falling back to fp32 when the op has no fp16
// kernel. Returns a new input map; `ins` itself is not modified.
//
// Fix: the original ended with an unreachable `return new_ins;` after an
// exhaustive if/else-if/else in which every branch returned; restructured to
// a single exit point with identical behavior.
NameVarBaseMap AutoCastInputs(const std::string& op_type,
                              const NameVarBaseMap& ins) {
  NameVarBaseMap new_ins(ins);
  if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
    for (auto& pair : new_ins) {
      // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
      if ((op_type == "batch_norm" || op_type == "layer_norm" ||
           op_type == "sync_batch_norm") &&
          pair.first != "X") {
        continue;
      }

      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float16";
      for (auto& var : pair.second) {
        var = CastToFP16(var);
      }
    }
  } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
    for (auto& pair : new_ins) {
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float";
      for (auto& var : pair.second) {
        var = CastToFP32(var);
      }
    }
  } else {
    auto dst_type = GetPromoteType(op_type, ins);

    // NOTE(zhiqiu): if the op has no fp16 kernel, fall back to fp32.
    if (dst_type == framework::proto::VarType::FP16 &&
        AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(
            op_type)) {
      dst_type = framework::proto::VarType::FP32;
    }
    for (auto& pair : new_ins) {
      // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
      if ((op_type == "batch_norm" || op_type == "layer_norm" ||
           op_type == "sync_batch_norm") &&
          pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
        continue;
      }
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to "
              << framework::DataTypeToString(dst_type);
      for (auto& var : pair.second) {
        var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var)
                                                           : CastToFP16(var));
      }
    }
  }
  return new_ins;
}

228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
                                  const NameVarBaseMap& ins) {
  NameVarBaseMap new_ins(ins);
  auto dst_type = framework::proto::VarType::FP16;
  if (AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(op_type) ||
      AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
    dst_type = framework::proto::VarType::FP32;
  }
  for (auto& pair : new_ins) {
    if ((op_type == "batch_norm" || op_type == "layer_norm" ||
         op_type == "sync_batch_norm") &&
        pair.first != "X") {
      continue;
    }
    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
            << GetDtypeStr(*pair.second.cbegin()) << " to "
            << framework::DataTypeToString(dst_type);
    for (auto& var : pair.second) {
      var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var)
                                                         : CastToFP16(var));
    }
  }
  return new_ins;
}

253 254
}  // namespace imperative
}  // namespace paddle