Unverified commit d2b31a14, authored by Leo Chen, committed by GitHub

[AMP] Autocast to fp32 for ops that have no fp16 kernel (#32543)

* skip ops that have no fp16 kernel

* add unit test
Parent commit: 756f4639
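A minimal sketch of the behavior this commit targets, adapted from the unit test added below: under dygraph `amp_guard`, an op with an fp16 CUDA kernel (e.g. `conv2d`) is autocast to fp16, while an op without one (here `expand_as`, which maps to `expand_as_v2`) now falls back to fp32 instead of failing. The snippet assumes a CUDA build of PaddlePaddle from this era (~2.1) that exposes the `fluid.dygraph.amp_guard` API.

```python
# Hedged sketch of the intended behavior, adapted from the unit test in this commit.
# Assumes a CUDA build of PaddlePaddle (~2.1) exposing fluid.dygraph.amp_guard.
import numpy as np
import paddle
import paddle.fluid as fluid

data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
    conv2d = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
    x = fluid.dygraph.to_variable(data)
    with fluid.dygraph.amp_guard(True):
        out_fp16 = conv2d(x)                             # conv2d has an fp16 kernel -> cast to fp16
        out_fp32 = paddle.expand_as(out_fp16, out_fp16)  # expand_as_v2 has no fp16 kernel -> stays fp32
    print(out_fp16.dtype)  # expected: VarType.FP16
    print(out_fp32.dtype)  # expected: VarType.FP32
```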
......@@ -26,7 +26,24 @@ class VarBase;
AmpOperators::AmpOperators()
: allow_ops_(new std::unordered_set<std::string>()),
-      block_ops_(new std::unordered_set<std::string>()) {}
+      block_ops_(new std::unordered_set<std::string>()),
+      unsupported_fp16_ops_(new std::unordered_set<std::string>()) {
auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto fp16_dtype = framework::proto::VarType::FP16;
for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) {
bool supported = false;
for (auto& kernel_type : it->second) {
if (platform::is_gpu_place(kernel_type.first.place_) &&
kernel_type.first.data_type_ == fp16_dtype) {
supported = true;
}
}
if (!supported) {
unsupported_fp16_ops_->insert(it->first);
}
}
}
AmpOperators::~AmpOperators() {}
AmpOperators& AmpOperators::Instance() {
......@@ -44,16 +61,26 @@ AmpOperators::GetMutableBlockOps() {
return block_ops_;
}
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedFp16Ops() {
return unsupported_fp16_ops_;
}
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps();
std::copy((*allow_ops).begin(), (*allow_ops).end(),
std::ostream_iterator<std::string>(os, " "));
os << "; ";
os << "\n";
os << "block ops: ";
auto block_ops = ops.GetMutableBlockOps();
std::copy((*block_ops).begin(), (*block_ops).end(),
std::ostream_iterator<std::string>(os, " "));
os << "\n";
os << "unsupported fp16 ops: ";
auto unsupported_fp16_ops = ops.GetMutableUnsupportedFp16Ops();
std::copy((*unsupported_fp16_ops).begin(), (*unsupported_fp16_ops).end(),
std::ostream_iterator<std::string>(os, " "));
return os;
}
......@@ -156,6 +183,12 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
return new_ins;
} else {
auto dst_type = GetPromoteType(ins);
// NOTE(zhiqiu): if the op has no fp16 kernel, fall back to fp32.
if (dst_type == framework::proto::VarType::FP16 &&
AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(
op_type)) {
dst_type = framework::proto::VarType::FP32;
}
for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm only support fp16 for the input x.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
......
......@@ -40,6 +40,9 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps();
std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedFp16Ops();
private:
AmpOperators(); // forbid calling default constructor
......@@ -50,6 +53,9 @@ class AmpOperators {
// The set of ops that support fp16 calculation and are considered numerically
// dangerous and whose effects may also be observed in downstream ops.
std::shared_ptr<std::unordered_set<std::string>> block_ops_;
// The set of ops that have no fp16 CUDA kernel.
std::shared_ptr<std::unordered_set<std::string>> unsupported_fp16_ops_;
};
std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
......
......@@ -1488,7 +1488,7 @@ void BindImperative(py::module *m_ptr) {
allow_ops);
imperative::AmpOperators::Instance().GetMutableBlockOps()->swap(
block_ops);
-  VLOG(4) << "AMP operators changed, "
+  VLOG(5) << "AMP operators changed, "
<< imperative::AmpOperators::Instance();
})
.def("_get_amp_op_list",
......
......@@ -106,6 +106,20 @@ class TestAutoCast(unittest.TestCase):
        self.assertRaises(ValueError, func)

    def test_amp_guard_unsupported_fp16_op(self):
        data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
        with fluid.dygraph.guard():
            conv2d = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
            data = fluid.dygraph.to_variable(data)
            with fluid.dygraph.amp_guard(True):
                out_fp16 = conv2d(data)
                out_fp32 = paddle.expand_as(
                    out_fp16, out_fp16)  # expand_as_v2 has no fp16 kernel

        self.assertTrue(data.dtype == fluid.core.VarDesc.VarType.FP32)
        self.assertTrue(out_fp16.dtype == fluid.core.VarDesc.VarType.FP16)
        self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32)


class TestAmpScaler(unittest.TestCase):
    def test_scale(self):
......