From 7043b8cfc67989720e4fb53bcb43fa20ea98ca73 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 19 Jan 2021 14:59:55 +0800 Subject: [PATCH] support layer_norm fp16 in dygraph amp (#30430) * support layer_norm fp16 in dygraph amp * add ut * refine code --- paddle/fluid/imperative/amp_auto_cast.cc | 71 +++++++++++-------- paddle/fluid/imperative/amp_auto_cast.h | 6 +- paddle/fluid/pybind/imperative.cc | 39 +++++----- .../test_imperative_auto_mixed_precision.py | 16 +++++ 4 files changed, 84 insertions(+), 48 deletions(-) diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index d0f3efcdf6..25580a8381 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/imperative/amp_auto_cast.h" +#include #include #include #include @@ -35,14 +36,29 @@ AmpOperators& AmpOperators::Instance() { return instance; } -std::shared_ptr> AmpOperators::GetAllowOps() { +std::shared_ptr> +AmpOperators::GetMutableAllowOps() { return allow_ops_; } -std::shared_ptr> AmpOperators::GetBlockOps() { +std::shared_ptr> +AmpOperators::GetMutableBlockOps() { return block_ops_; } +std::ostream& operator<<(std::ostream& os, AmpOperators& ops) { + os << "allow ops: "; + auto allow_ops = ops.GetMutableAllowOps(); + std::copy((*allow_ops).begin(), (*allow_ops).end(), + std::ostream_iterator(os, " ")); + os << "; "; + os << "block ops: "; + auto block_ops = ops.GetMutableBlockOps(); + std::copy((*block_ops).begin(), (*block_ops).end(), + std::ostream_iterator(os, " ")); + return os; +} + inline std::string GetDtypeStr( const std::shared_ptr& var) { return framework::DataTypeToString(var->DataType()); @@ -115,51 +131,50 @@ static inline framework::proto::VarType::Type GetPromoteType( NameVarBaseMap AutoCastInputs(const std::string& op_type, const NameVarBaseMap& ins) { - NameVarBaseMap new_ins = {}; - if (AmpOperators::Instance().GetAllowOps()->count(op_type)) { - for (const auto& pair : ins) { + NameVarBaseMap new_ins(ins); + if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) { + for (auto& pair : new_ins) { + // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. + if ((op_type == "batch_norm" || op_type == "layer_norm") && + pair.first != "X") { + continue; + } + VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to float16"; - for (const auto& var : pair.second) { - auto new_var = CastToFP16(var); - new_ins[pair.first].emplace_back(new_var); + for (auto& var : pair.second) { + var = CastToFP16(var); } } return new_ins; - } else if (AmpOperators::Instance().GetBlockOps()->count(op_type)) { - for (const auto& pair : ins) { + } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) { + for (auto& pair : new_ins) { VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to float"; - for (const auto& var : pair.second) { - auto new_var = CastToFP32(var); - new_ins[pair.first].emplace_back(new_var); + for (auto& var : pair.second) { + var = CastToFP32(var); } } return new_ins; } else { auto dst_type = GetPromoteType(ins); - - for (const auto& pair : ins) { + for (auto& pair : new_ins) { + // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. + if ((op_type == "batch_norm" || op_type == "layer_norm") && + pair.first == "X" && dst_type == framework::proto::VarType::FP32) { + continue; + } VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to " << framework::DataTypeToString(dst_type); - for (const auto& var : pair.second) { - // NOTE(zhiqiu): Conv + BN always occur together, we needn't - // cast X of batch_norm to FP32, which is produced by conv as FP16 type. - if (op_type == "batch_norm" && pair.first == "X" && - dst_type == framework::proto::VarType::FP32) { - new_ins[pair.first].emplace_back(var); - continue; - } - auto new_var = dst_type == framework::proto::VarType::FP32 - ? CastToFP32(var) - : CastToFP16(var); - new_ins[pair.first].emplace_back(new_var); + for (auto& var : pair.second) { + var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var) + : CastToFP16(var)); } } return new_ins; } - return ins; + return new_ins; } } // namespace imperative diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h index 7ab876c1ce..619c6b0baf 100644 --- a/paddle/fluid/imperative/amp_auto_cast.h +++ b/paddle/fluid/imperative/amp_auto_cast.h @@ -36,9 +36,9 @@ class AmpOperators { static AmpOperators& Instance(); - std::shared_ptr> GetAllowOps(); + std::shared_ptr> GetMutableAllowOps(); - std::shared_ptr> GetBlockOps(); + std::shared_ptr> GetMutableBlockOps(); private: AmpOperators(); // forbid calling default constructor @@ -52,6 +52,8 @@ class AmpOperators { std::shared_ptr> block_ops_; }; +std::ostream& operator<<(std::ostream& os, AmpOperators& ops); + // NOTE(zhiqiu): AutoCastGuard is used for RAII. class AutoCastGuard { public: diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 123cc0a875..87aa989c41 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -1257,27 +1257,30 @@ void BindImperative(py::module *m_ptr) { py::return_value_policy::reference) .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName, py::arg("key") = "dygraph_tmp") - .def( - "_set_amp_op_list", - [](imperative::Tracer &self, - std::unordered_set &allow_ops, - std::unordered_set &block_ops) { - // NOTE(zhiqiu): The automatic conversion in pybind11 between - // c++ - // STL and python set/list/dict involve a copy operation that - // prevents pass-by-reference semantics, so it is ok to swap. - // The reaseon why not directly pass - // std::shared_ptr> - // is that pybind11 forbid shared_ptr where T is not custom - // type. - imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops); - imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops); - }) + .def("_set_amp_op_list", + [](imperative::Tracer &self, + std::unordered_set &allow_ops, + std::unordered_set &block_ops) { + // NOTE(zhiqiu): The automatic conversion in pybind11 between + // c++ + // STL and python set/list/dict involve a copy operation that + // prevents pass-by-reference semantics, so it is ok to swap. + // The reaseon why not directly pass + // std::shared_ptr> + // is that pybind11 forbid shared_ptr where T is not custom + // type. + imperative::AmpOperators::Instance().GetMutableAllowOps()->swap( + allow_ops); + imperative::AmpOperators::Instance().GetMutableBlockOps()->swap( + block_ops); + VLOG(4) << "AMP operators changed, " + << imperative::AmpOperators::Instance(); + }) .def("_get_amp_op_list", [](imperative::Tracer &self) { return std::make_tuple( - *(imperative::AmpOperators::Instance().GetAllowOps()), - *(imperative::AmpOperators::Instance().GetBlockOps())); + *(imperative::AmpOperators::Instance().GetMutableAllowOps()), + *(imperative::AmpOperators::Instance().GetMutableBlockOps())); }) .def("trace", [](imperative::Tracer &self, const std::string &type, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 0118f3c800..ef2900be39 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -389,5 +389,21 @@ class TestResnet(unittest.TestCase): self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) +class TestLayerNormFp16(unittest.TestCase): + r''' layer_norm and batch_norm support mixed inputs, i.e., only input x is fp16 + and other params are fp32. + ''' + + def test_layer_norm_fp16(self): + if fluid.is_compiled_with_cuda(): + with fluid.dygraph.guard(fluid.CUDAPlace(0)): + x = paddle.rand([2, 2, 2, 3]) + layer_norm = paddle.nn.LayerNorm(x.shape[1:]) + with paddle.amp.auto_cast(custom_white_list=['layer_norm']): + out = layer_norm(x) + + self.assertTrue(out.dtype == fluid.core.VarDesc.VarType.FP16) + + if __name__ == '__main__': unittest.main() -- GitLab