未验证 提交 7043b8cf 编写于 作者: L Leo Chen 提交者: GitHub

support layer_norm fp16 in dygraph amp (#30430)

* support layer_norm fp16 in dygraph amp

* add ut

* refine code
上级 28eb7b65
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/amp_auto_cast.h"
#include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -35,14 +36,29 @@ AmpOperators& AmpOperators::Instance() { ...@@ -35,14 +36,29 @@ AmpOperators& AmpOperators::Instance() {
return instance; return instance;
} }
std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetAllowOps() { std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableAllowOps() {
return allow_ops_; return allow_ops_;
} }
std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetBlockOps() { std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableBlockOps() {
return block_ops_; return block_ops_;
} }
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps();
std::copy((*allow_ops).begin(), (*allow_ops).end(),
std::ostream_iterator<std::string>(os, " "));
os << "; ";
os << "block ops: ";
auto block_ops = ops.GetMutableBlockOps();
std::copy((*block_ops).begin(), (*block_ops).end(),
std::ostream_iterator<std::string>(os, " "));
return os;
}
inline std::string GetDtypeStr( inline std::string GetDtypeStr(
const std::shared_ptr<imperative::VarBase>& var) { const std::shared_ptr<imperative::VarBase>& var) {
return framework::DataTypeToString(var->DataType()); return framework::DataTypeToString(var->DataType());
...@@ -115,51 +131,50 @@ static inline framework::proto::VarType::Type GetPromoteType( ...@@ -115,51 +131,50 @@ static inline framework::proto::VarType::Type GetPromoteType(
NameVarBaseMap AutoCastInputs(const std::string& op_type, NameVarBaseMap AutoCastInputs(const std::string& op_type,
const NameVarBaseMap& ins) { const NameVarBaseMap& ins) {
NameVarBaseMap new_ins = {}; NameVarBaseMap new_ins(ins);
if (AmpOperators::Instance().GetAllowOps()->count(op_type)) { if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
for (const auto& pair : ins) { for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
pair.first != "X") {
continue;
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float16"; << GetDtypeStr(*pair.second.cbegin()) << " to float16";
for (const auto& var : pair.second) { for (auto& var : pair.second) {
auto new_var = CastToFP16(var); var = CastToFP16(var);
new_ins[pair.first].emplace_back(new_var);
} }
} }
return new_ins; return new_ins;
} else if (AmpOperators::Instance().GetBlockOps()->count(op_type)) { } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
for (const auto& pair : ins) { for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float"; << GetDtypeStr(*pair.second.cbegin()) << " to float";
for (const auto& var : pair.second) { for (auto& var : pair.second) {
auto new_var = CastToFP32(var); var = CastToFP32(var);
new_ins[pair.first].emplace_back(new_var);
} }
} }
return new_ins; return new_ins;
} else { } else {
auto dst_type = GetPromoteType(ins); auto dst_type = GetPromoteType(ins);
for (auto& pair : new_ins) {
for (const auto& pair : ins) { // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue;
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to " << GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type); << framework::DataTypeToString(dst_type);
for (const auto& var : pair.second) { for (auto& var : pair.second) {
// NOTE(zhiqiu): Conv + BN always occur together, we needn't var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var)
// cast X of batch_norm to FP32, which is produced by conv as FP16 type. : CastToFP16(var));
if (op_type == "batch_norm" && pair.first == "X" &&
dst_type == framework::proto::VarType::FP32) {
new_ins[pair.first].emplace_back(var);
continue;
}
auto new_var = dst_type == framework::proto::VarType::FP32
? CastToFP32(var)
: CastToFP16(var);
new_ins[pair.first].emplace_back(new_var);
} }
} }
return new_ins; return new_ins;
} }
return ins; return new_ins;
} }
} // namespace imperative } // namespace imperative
......
...@@ -36,9 +36,9 @@ class AmpOperators { ...@@ -36,9 +36,9 @@ class AmpOperators {
static AmpOperators& Instance(); static AmpOperators& Instance();
std::shared_ptr<std::unordered_set<std::string>> GetAllowOps(); std::shared_ptr<std::unordered_set<std::string>> GetMutableAllowOps();
std::shared_ptr<std::unordered_set<std::string>> GetBlockOps(); std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps();
private: private:
AmpOperators(); // forbid calling default constructor AmpOperators(); // forbid calling default constructor
...@@ -52,6 +52,8 @@ class AmpOperators { ...@@ -52,6 +52,8 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>> block_ops_; std::shared_ptr<std::unordered_set<std::string>> block_ops_;
}; };
std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
// NOTE(zhiqiu): AutoCastGuard is used for RAII. // NOTE(zhiqiu): AutoCastGuard is used for RAII.
class AutoCastGuard { class AutoCastGuard {
public: public:
......
...@@ -1257,27 +1257,30 @@ void BindImperative(py::module *m_ptr) { ...@@ -1257,27 +1257,30 @@ void BindImperative(py::module *m_ptr) {
py::return_value_policy::reference) py::return_value_policy::reference)
.def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName, .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
py::arg("key") = "dygraph_tmp") py::arg("key") = "dygraph_tmp")
.def( .def("_set_amp_op_list",
"_set_amp_op_list", [](imperative::Tracer &self,
[](imperative::Tracer &self, std::unordered_set<std::string> &allow_ops,
std::unordered_set<std::string> &allow_ops, std::unordered_set<std::string> &block_ops) {
std::unordered_set<std::string> &block_ops) { // NOTE(zhiqiu): The automatic conversion in pybind11 between
// NOTE(zhiqiu): The automatic conversion in pybind11 between // c++
// c++ // STL and python set/list/dict involve a copy operation that
// STL and python set/list/dict involve a copy operation that // prevents pass-by-reference semantics, so it is ok to swap.
// prevents pass-by-reference semantics, so it is ok to swap. // The reaseon why not directly pass
// The reaseon why not directly pass // std::shared_ptr<std::unordered_set<std::string>>
// std::shared_ptr<std::unordered_set<std::string>> // is that pybind11 forbid shared_ptr<T> where T is not custom
// is that pybind11 forbid shared_ptr<T> where T is not custom // type.
// type. imperative::AmpOperators::Instance().GetMutableAllowOps()->swap(
imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops); allow_ops);
imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops); imperative::AmpOperators::Instance().GetMutableBlockOps()->swap(
}) block_ops);
VLOG(4) << "AMP operators changed, "
<< imperative::AmpOperators::Instance();
})
.def("_get_amp_op_list", .def("_get_amp_op_list",
[](imperative::Tracer &self) { [](imperative::Tracer &self) {
return std::make_tuple( return std::make_tuple(
*(imperative::AmpOperators::Instance().GetAllowOps()), *(imperative::AmpOperators::Instance().GetMutableAllowOps()),
*(imperative::AmpOperators::Instance().GetBlockOps())); *(imperative::AmpOperators::Instance().GetMutableBlockOps()));
}) })
.def("trace", .def("trace",
[](imperative::Tracer &self, const std::string &type, [](imperative::Tracer &self, const std::string &type,
......
...@@ -389,5 +389,21 @@ class TestResnet(unittest.TestCase): ...@@ -389,5 +389,21 @@ class TestResnet(unittest.TestCase):
self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))
class TestLayerNormFp16(unittest.TestCase):
r''' layer_norm and batch_norm support mixed inputs, i.e., only input x is fp16
and other params are fp32.
'''
def test_layer_norm_fp16(self):
if fluid.is_compiled_with_cuda():
with fluid.dygraph.guard(fluid.CUDAPlace(0)):
x = paddle.rand([2, 2, 2, 3])
layer_norm = paddle.nn.LayerNorm(x.shape[1:])
with paddle.amp.auto_cast(custom_white_list=['layer_norm']):
out = layer_norm(x)
self.assertTrue(out.dtype == fluid.core.VarDesc.VarType.FP16)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册