Unverified commit 02e4f1f8, authored by niuliling123, committed by GitHub

Add Concat transformer for layout autotune (#42003)

* Add Concat transformer for layout autotune
Parent d4372a1e
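For context, a minimal sketch of the behavior this change targets, distilled from the tests added below. It assumes layout autotune is already enabled for the session (the enabling switch is not part of this diff; the tests gate it behind their use_autoune() helper):

import paddle

# Sketch only: assumes layout autotune is enabled; shapes follow the tests below.
conv = paddle.nn.Conv2D(3, 8, (3, 3))
data1 = paddle.rand([1, 3, 16, 14])
data2 = paddle.rand([1, 3, 16, 14])

with paddle.amp.auto_cast(level="O2"):
    conv_out1 = conv(data1)  # held as NHWC under autotune: [1, 14, 12, 8]
    conv_out2 = conv(data2)  # held as NHWC under autotune: [1, 14, 12, 8]
    # Both inputs share the NHWC layout, so the new ConcatOpTransformer only
    # remaps the axis attribute; the result stays NHWC: [2, 14, 12, 8].
    out = paddle.concat(x=[conv_out1, conv_out2], axis=0)

If the inputs disagree on layout (e.g. one NHWC conv output and one plain NCHW tensor), the transformer falls back to the generic lightly-layout-sensitive path, which transposes the inputs back before concatenating.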
@@ -131,6 +131,8 @@ paddle::imperative::NameVarMap<VarType> DealLightlyLayoutSensitive(
    transposer = std::make_shared<FlattenOpTransformer<VarType>>(op_type);
  } else if (op_type == "arg_max") {
    transposer = std::make_shared<ArgmaxOpTransformer<VarType>>(op_type);
  } else if (op_type == "concat") {
    transposer = std::make_shared<ConcatOpTransformer<VarType>>(op_type);
  } else if (op_type.find("elementwise_") != std::string::npos) {
    transposer = std::make_shared<ElementwiseOpTransformer<VarType>>(op_type);
  } else {
@@ -401,5 +401,51 @@ class ArgmaxOpTransformer
  }
};

template <typename VarType>
class ConcatOpTransformer
    : public LightlyLayoutSensitiveOpTransformer<VarType> {
 public:
  explicit ConcatOpTransformer(const std::string& type)
      : LightlyLayoutSensitiveOpTransformer<VarType>(type) {}

  paddle::imperative::NameVarMap<VarType> Apply(
      const paddle::imperative::NameVarMap<VarType>& ins,
      const paddle::imperative::NameVarMap<VarType>& outs,
      paddle::framework::AttributeMap* attrs,
      const std::shared_ptr<paddle::imperative::Tracer>& tracer) {
    VLOG(3) << "Optimize lightly layout sensitive op " << this->Type();
    auto& in_var = ins.at("X")[0];
    auto var_layout = paddle::imperative::GetDataLayout(in_var);
    // A transpose is only needed when the inputs do not all share the layout
    // of the first input.
    bool need_transpose = false;
    for (auto& pair : ins) {
      for (auto& var : pair.second) {
        if (var != nullptr &&
            (paddle::imperative::GetDataLayout(var) != var_layout)) {
          need_transpose = true;
          break;
        }
      }
    }
    if (need_transpose) {
      return LightlyLayoutSensitiveOpTransformer<VarType>::Apply(
          ins, outs, attrs, tracer);
    }
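    // All inputs share one layout, so the data can be concatenated as-is and
    // only the "axis" attribute has to be remapped to that layout (e.g. an
    // NCHW-style axis 1 becomes axis 3 once the tensors are held as NHWC).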
    if (var_layout != DataLayout::UNDEFINED) {
      std::vector<int> perm_nhwc = {0, 3, 1, 2};
      std::vector<int> perm_nchw = {0, 2, 3, 1};
      auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw;
      auto axis = BOOST_GET_CONST(int, (*attrs)["axis"]);
      (*attrs)["axis"] = static_cast<int>(perm[axis]);
    }
    auto axis = BOOST_GET_CONST(int, (*attrs)["axis"]);
    VLOG(3) << "Optimize lightly layout sensitive op " << this->Type()
            << ", concat axis after layout transform: " << axis;
    this->SetVarsLayout(outs, var_layout);
    return ins;
  }
};
} // namespace imperative
} // namespace paddle
@@ -161,6 +161,35 @@ class LayoutAutoTune(unittest.TestCase):
        self.assertEqual(conv_out.shape, [1, 14, 12, 8])
        self.assertEqual(out.shape, [1])

    def test_concat_op_transposer(self):
        if not self.use_autoune():
            return
        in1 = paddle.rand([1, 8, 14, 12])
        conv = paddle.nn.Conv2D(3, 8, (3, 3))
        data = paddle.rand([1, 3, 16, 14])
        with paddle.amp.auto_cast(level="O2"):
            conv_out = conv(data)
            # conv_out.shape = [1, 14, 12, 8] with NHWC
            out = paddle.concat(x=[conv_out, in1], axis=0)

        self.assertEqual(conv_out.shape, [1, 14, 12, 8])
        self.assertEqual(out.shape, [2, 8, 14, 12])

    def test_concat_op_no_transposer(self):
        if not self.use_autoune():
            return
        conv = paddle.nn.Conv2D(3, 8, (3, 3))
        data1 = paddle.rand([1, 3, 16, 14])
        data2 = paddle.rand([1, 3, 16, 14])
        with paddle.amp.auto_cast(level="O2"):
            conv_out1 = conv(data1)
            conv_out2 = conv(data2)
            # conv_out1.shape = conv_out2.shape = [1, 14, 12, 8] with NHWC
            out = paddle.concat(x=[conv_out1, conv_out2], axis=0)

        self.assertEqual(conv_out1.shape, [1, 14, 12, 8])
        self.assertEqual(out.shape, [2, 14, 12, 8])

class TestAutoTuneAPI(unittest.TestCase):