Unverified commit 0ea41e62 authored by Leo Chen, committed by GitHub

[cherry-pick] support layer_norm fp16 in dygraph amp (#30430) #30566

Parent 96058384
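
For context, a minimal dygraph usage sketch of the feature being cherry-picked, adapted from the TestLayerNormFp16 test added at the bottom of this diff (assumes a CUDA build; without layer_norm in the custom white list, the op would stay in float32):

import paddle
import paddle.fluid as fluid

if fluid.is_compiled_with_cuda():
    with fluid.dygraph.guard(fluid.CUDAPlace(0)):
        x = paddle.rand([2, 2, 2, 3])
        layer_norm = paddle.nn.LayerNorm(x.shape[1:])
        # With layer_norm white-listed, AMP casts only the input X to
        # float16; the scale and bias parameters stay float32.
        with paddle.amp.auto_cast(custom_white_list=['layer_norm']):
            out = layer_norm(x)
        # Expected: out.dtype is VarType.FP16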
......@@ -14,6 +14,7 @@
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
......@@ -35,14 +36,29 @@ AmpOperators& AmpOperators::Instance() {
return instance;
}
std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetAllowOps() {
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableAllowOps() {
return allow_ops_;
}
std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetBlockOps() {
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableBlockOps() {
return block_ops_;
}
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps();
std::copy((*allow_ops).begin(), (*allow_ops).end(),
std::ostream_iterator<std::string>(os, " "));
os << "; ";
os << "block ops: ";
auto block_ops = ops.GetMutableBlockOps();
std::copy((*block_ops).begin(), (*block_ops).end(),
std::ostream_iterator<std::string>(os, " "));
return os;
}
inline std::string GetDtypeStr(
const std::shared_ptr<imperative::VarBase>& var) {
return framework::DataTypeToString(var->DataType());
......@@ -115,51 +131,50 @@ static inline framework::proto::VarType::Type GetPromoteType(
NameVarBaseMap AutoCastInputs(const std::string& op_type,
const NameVarBaseMap& ins) {
NameVarBaseMap new_ins = {};
if (AmpOperators::Instance().GetAllowOps()->count(op_type)) {
for (const auto& pair : ins) {
NameVarBaseMap new_ins(ins);
if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
pair.first != "X") {
continue;
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float16";
for (const auto& var : pair.second) {
auto new_var = CastToFP16(var);
new_ins[pair.first].emplace_back(new_var);
for (auto& var : pair.second) {
var = CastToFP16(var);
}
}
return new_ins;
} else if (AmpOperators::Instance().GetBlockOps()->count(op_type)) {
for (const auto& pair : ins) {
} else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float";
for (const auto& var : pair.second) {
auto new_var = CastToFP32(var);
new_ins[pair.first].emplace_back(new_var);
for (auto& var : pair.second) {
var = CastToFP32(var);
}
}
return new_ins;
} else {
auto dst_type = GetPromoteType(ins);
for (const auto& pair : ins) {
for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
if ((op_type == "batch_norm" || op_type == "layer_norm") &&
pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue;
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
for (const auto& var : pair.second) {
// NOTE(zhiqiu): Conv + BN always occur together, we needn't
// cast X of batch_norm to FP32, which is produced by conv as FP16 type.
if (op_type == "batch_norm" && pair.first == "X" &&
dst_type == framework::proto::VarType::FP32) {
new_ins[pair.first].emplace_back(var);
continue;
}
auto new_var = dst_type == framework::proto::VarType::FP32
? CastToFP32(var)
: CastToFP16(var);
new_ins[pair.first].emplace_back(new_var);
for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var)
: CastToFP16(var));
}
}
return new_ins;
}
return ins;
return new_ins;
}
} // namespace imperative
......
......@@ -36,9 +36,9 @@ class AmpOperators {
static AmpOperators& Instance();
std::shared_ptr<std::unordered_set<std::string>> GetAllowOps();
std::shared_ptr<std::unordered_set<std::string>> GetMutableAllowOps();
std::shared_ptr<std::unordered_set<std::string>> GetBlockOps();
std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps();
private:
AmpOperators(); // forbid calling default constructor
......@@ -52,6 +52,8 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>> block_ops_;
};
std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
// NOTE(zhiqiu): AutoCastGuard is used for RAII.
class AutoCastGuard {
public:
......
......@@ -1257,8 +1257,7 @@ void BindImperative(py::module *m_ptr) {
py::return_value_policy::reference)
.def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
py::arg("key") = "dygraph_tmp")
.def(
"_set_amp_op_list",
.def("_set_amp_op_list",
[](imperative::Tracer &self,
std::unordered_set<std::string> &allow_ops,
std::unordered_set<std::string> &block_ops) {
......@@ -1270,14 +1269,18 @@ void BindImperative(py::module *m_ptr) {
// std::shared_ptr<std::unordered_set<std::string>>
// is that pybind11 forbid shared_ptr<T> where T is not custom
// type.
imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops);
imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops);
imperative::AmpOperators::Instance().GetMutableAllowOps()->swap(
allow_ops);
imperative::AmpOperators::Instance().GetMutableBlockOps()->swap(
block_ops);
VLOG(4) << "AMP operators changed, "
<< imperative::AmpOperators::Instance();
})
.def("_get_amp_op_list",
[](imperative::Tracer &self) {
return std::make_tuple(
*(imperative::AmpOperators::Instance().GetAllowOps()),
*(imperative::AmpOperators::Instance().GetBlockOps()));
*(imperative::AmpOperators::Instance().GetMutableAllowOps()),
*(imperative::AmpOperators::Instance().GetMutableBlockOps()));
})
.def("trace",
[](imperative::Tracer &self, const std::string &type,
......
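
For reference, a hedged sketch of how the _set_amp_op_list / _get_amp_op_list bindings above can be exercised directly from Python; the _dygraph_tracer() helper and its import path are assumptions based on the Paddle 2.0-era codebase and are not part of this diff (normally paddle.amp.auto_cast manages these lists):

import paddle.fluid as fluid
from paddle.fluid.framework import _dygraph_tracer  # assumed private helper

with fluid.dygraph.guard():
    tracer = _dygraph_tracer()
    # Overwrite the AMP allow/block lists held by AmpOperators
    # (the C++ side swaps the contents of the passed-in sets).
    tracer._set_amp_op_list({'layer_norm', 'conv2d'}, {'reduce_sum'})
    allow_ops, block_ops = tracer._get_amp_op_list()
    print(sorted(allow_ops), sorted(block_ops))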
......@@ -389,5 +389,21 @@ class TestResnet(unittest.TestCase):
self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))
class TestLayerNormFp16(unittest.TestCase):
r''' layer_norm and batch_norm support mixed inputs, i.e., only input x is fp16
and other params are fp32.
'''
def test_layer_norm_fp16(self):
if fluid.is_compiled_with_cuda():
with fluid.dygraph.guard(fluid.CUDAPlace(0)):
x = paddle.rand([2, 2, 2, 3])
layer_norm = paddle.nn.LayerNorm(x.shape[1:])
with paddle.amp.auto_cast(custom_white_list=['layer_norm']):
out = layer_norm(x)
self.assertTrue(out.dtype == fluid.core.VarDesc.VarType.FP16)
if __name__ == '__main__':
unittest.main()