未验证 提交 3da3462f 编写于 作者: N niuliling123 提交者: GitHub

Update layout autotune for module with no modified (#46541)

上级 20eb6e00
......@@ -1093,7 +1093,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
tensors_vector_list_str = "{ " + ",".join(
amp_tensors_vector_list) + " }"
if len(amp_tensors_vector_list) == 0:
if len(amp_tensors_vector_list) == 0: # or forward_api_name == "shape":
layout_logic_str = ""
else:
after_call_str = f"{returns_type_str} {result_name} = {forward_function_name}({layout_inputs_call_args_str});\n"
......
......@@ -32,70 +32,50 @@ inline bool NeedTransLayout(
}
return false;
}
inline std::shared_ptr<EagerLayoutTransformer> BaseTransformer(
const std::string& op_name,
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>& tensors_vector) {
std::shared_ptr<EagerLayoutTransformer> transposer = nullptr;
bool unstart =
(paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout() ==
paddle::experimental::DataLayout::UNDEFINED);
auto first_layout = tensors_vector[0][0].layout();
VLOG(3) << "Layout autotune was is start ? " << (!unstart) << op_name
<< "'s layout is " << first_layout;
transposer = std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, first_layout);
return transposer;
}
// For agnostic op like add, relu, exp
inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune(
const std::string& op_name,
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>& tensors_vector) {
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout();
auto default_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDefaultLayout();
// For agnostic op like add, relu, exp
auto first_layout = tensors_vector[0][0].layout();
if (NeedTransLayout(tensors_vector, first_layout)) {
auto desired_layout = DesiredLayout();
bool is_started =
!(desired_layout == paddle::experimental::DataLayout::UNDEFINED);
if (is_started && NeedTransLayout(tensors_vector, first_layout)) {
bool need_trans_back = false;
for (size_t i = 0; i < tensors_vector.size(); i++) {
for (size_t idx = 0; idx < tensors_vector[0].size(); idx++) {
if (4 != tensors_vector[i][idx].shape().size()) {
need_trans_back = true;
VLOG(3) << "Agnostic op " << op_name << " shape is "
<< tensors_vector[i][idx].shape().size() << " and layout is "
<< tensors_vector[i][idx].layout();
}
}
}
auto final_layout = need_trans_back ? default_layout : desired_layout;
auto final_layout = need_trans_back ? DefaultLayout() : desired_layout;
VLOG(4) << op_name << "'s has different layout, need trans to "
<< final_layout;
return std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, final_layout);
}
return BaseTransformer(op_name, tensors_vector);
return std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, first_layout);
}
// For lightly op like reduce
template <typename T>
inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune(
const std::string& op_name,
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>& tensors_vector,
T* attr) {
VLOG(3) << "Lightly op " << op_name << "'s shape is "
<< tensors_vector[0][0].shape().size() << " and layout is "
<< tensors_vector[0][0].layout();
std::shared_ptr<EagerLayoutTransformer> transposer = nullptr;
transposer =
std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
return transposer;
// For lightly op like reduce
if (!(DesiredLayout() == paddle::experimental::DataLayout::UNDEFINED)) {
VLOG(4) << "LayoutAutotune was unstarted. Current op :" << op_name;
return std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, tensors_vector[0][0].layout());
}
return std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
}
// For lightly op like argmax
template <typename T1, typename T2>
inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune(
const std::string& op_name,
......@@ -103,28 +83,23 @@ inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune(
kSlotSmallVectorSize>& tensors_vector,
T1* axis,
T2* keep_dim) {
VLOG(3) << "Lightly op " << op_name << "'s shape is "
<< tensors_vector[0][0].shape().size() << " and layout is "
<< tensors_vector[0][0].layout();
// For lightly op like argmax
return EagerLayoutAutotune<T1>(op_name, tensors_vector, axis);
}
// heavily string data_format, data_layout
template <>
inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune(
const std::string& op_name,
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>& tensors_vector,
std::string* attr) {
auto first_layout = tensors_vector[0][0].layout();
// Heavily op with (string) data_format, data_layout
auto transposer = std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, first_layout);
if (paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout() ==
paddle::experimental::DataLayout::UNDEFINED) {
op_name, tensors_vector, tensors_vector[0][0].layout());
if (DesiredLayout() == paddle::experimental::DataLayout::UNDEFINED) {
// Layout autotune only supports model with convolutional layers
VLOG(3) << "Optimze Layout was not started " << op_name;
if (op_name != "conv2d") {
VLOG(4) << "LayoutAutotune was unstarted. Current op :" << op_name;
return transposer;
} else {
auto data_type = tensors_vector[0][0].dtype();
......@@ -134,7 +109,8 @@ inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune(
bool is_tune_fp16 =
(data_type == paddle::experimental::DataType::FLOAT16) &&
(*attr == "NCHW");
VLOG(3) << "Conv2d_dy's dtype " << data_type << " format" << (*attr);
VLOG(4) << "LayoutAutoTune assert with dtype and layout, Current op : "
<< op_name;
if (is_tune_fp32) {
paddle::imperative::LayoutAutoTune::Instance().SetDesiredLayout(
paddle::experimental::DataLayout::NCHW);
......@@ -147,58 +123,45 @@ inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune(
paddle::imperative::LayoutAutoTune::Instance().SetDefaultLayout(
paddle::experimental::DataLayout::NCHW);
} else {
VLOG(4) << "DisableLayoutAutoTune accoding to Conv op"
<< " dtype : " << data_type << " format : " << (*attr);
egr::Controller::Instance().DisableLayoutAutoTune();
return transposer;
}
VLOG(3)
<< "Tune the layout from " << *attr << " to "
<< paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout();
VLOG(4) << "LayoutAutoTune from " << *attr << " to " << DesiredLayout();
}
}
if (paddle::imperative::LayoutAutoTune::Instance().IsHeavilyLayoutSensitive(
op_name)) {
VLOG(3)
<< op_name
<< "'s LayoutTransformer is EagerHeavilyLayoutSensitiveOpTransformer";
auto heavily_transposer =
std::make_shared<EagerHeavilyLayoutSensitiveOpTransformer>(op_name,
attr);
return heavily_transposer;
return std::make_shared<EagerHeavilyLayoutSensitiveOpTransformer>(op_name,
attr);
}
VLOG(3) << op_name << "'s LayoutTransformer is unimplemented. Use default.";
return transposer;
return std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
}
// lightly transpose
template <>
inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune(
const std::string& op_name,
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>& tensors_vector,
std::vector<int>* attr) {
auto first_layout = tensors_vector[0][0].layout();
std::shared_ptr<EagerLayoutTransformer> transposer = nullptr;
if (paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout() ==
paddle::experimental::DataLayout::UNDEFINED) {
VLOG(3) << "Optimze Layout was not started" << op_name;
transposer = std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, first_layout);
return transposer;
// lightly transpose
if (DesiredLayout() == paddle::experimental::DataLayout::UNDEFINED) {
VLOG(4) << "LayoutAutotune was unstarted. Current op :" << op_name;
return std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, tensors_vector[0][0].layout());
}
if (op_name == "transpose2" &&
(tensors_vector[0][0].layout() ==
paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout())) {
(tensors_vector[0][0].layout() == DesiredLayout())) {
auto trans = std::make_shared<EagerTransposeOpTransformer>(op_name);
trans->SetAttr(attr,
tensors_vector[0][0].layout() ==
paddle::experimental::DataLayout::NHWC);
return trans;
}
transposer =
std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
return transposer;
return std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
}
// lightly int argmax
......@@ -210,19 +173,14 @@ EagerLayoutAutotune<paddle::experimental::Scalar, bool>(
kSlotSmallVectorSize>& tensors_vector,
paddle::experimental::Scalar* axis,
bool* keep_dim) {
auto first_layout = tensors_vector[0][0].layout();
std::shared_ptr<EagerLayoutTransformer> transposer = nullptr;
if (paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout() ==
paddle::experimental::DataLayout::UNDEFINED) {
VLOG(3) << "Optimze Layout was not started" << op_name;
transposer = std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, first_layout);
return transposer;
if (DesiredLayout() == paddle::experimental::DataLayout::UNDEFINED) {
VLOG(4) << "LayoutAutotune was unstarted. Current op :" << op_name;
return std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, tensors_vector[0][0].layout());
}
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout();
if (op_name == "argmax" &&
(tensors_vector[0][0].layout() == desired_layout) && (*keep_dim)) {
(tensors_vector[0][0].layout() == DesiredLayout()) && (*keep_dim)) {
std::shared_ptr<EagerArgmaxOpTransformer> argmax_transform = nullptr;
argmax_transform = std::make_shared<EagerArgmaxOpTransformer>(op_name);
argmax_transform->SetAttr(axis,
......@@ -230,12 +188,9 @@ EagerLayoutAutotune<paddle::experimental::Scalar, bool>(
paddle::experimental::DataLayout::NHWC);
return argmax_transform;
}
transposer =
std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
return transposer;
return std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
}
// lightly for flatten
template <>
inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune<int, int>(
const std::string& op_name,
......@@ -243,32 +198,22 @@ inline std::shared_ptr<EagerLayoutTransformer> EagerLayoutAutotune<int, int>(
kSlotSmallVectorSize>& tensors_vector,
int* start_axis,
int* stop_axis) {
auto first_layout = tensors_vector[0][0].layout();
std::shared_ptr<EagerLayoutTransformer> transposer = nullptr;
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout();
if (desired_layout == paddle::experimental::DataLayout::UNDEFINED) {
VLOG(3) << "Optimze Layout was not started" << op_name;
transposer = std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, first_layout);
return transposer;
if (DesiredLayout() == paddle::experimental::DataLayout::UNDEFINED) {
VLOG(4) << "Optimze Layout was not started" << op_name;
return std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, tensors_vector[0][0].layout());
}
bool no_tranpose = tensors_vector[0][0].layout() == desired_layout;
bool no_tranpose = tensors_vector[0][0].layout() == DesiredLayout();
bool is_valid = ((*start_axis) == 1 && (*stop_axis) == 3);
if (op_name == "flatten" || op_name == "flatten_contiguous_range") {
if (no_tranpose && is_valid) {
std::shared_ptr<EagerFlattenOpTransformer> flatten_transform = nullptr;
flatten_transform = std::make_shared<EagerFlattenOpTransformer>(op_name);
return flatten_transform;
return std::make_shared<EagerFlattenOpTransformer>(op_name);
}
}
transposer =
std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
return transposer;
return std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
}
// lightly int Concat
template <>
inline std::shared_ptr<EagerLayoutTransformer>
EagerLayoutAutotune<paddle::experimental::Scalar>(
......@@ -276,27 +221,26 @@ EagerLayoutAutotune<paddle::experimental::Scalar>(
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>& tensors_vector,
paddle::experimental::Scalar* axis) {
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout();
auto first_layout = tensors_vector[0][0].layout();
std::shared_ptr<EagerLayoutTransformer> transposer = nullptr;
if (desired_layout == paddle::experimental::DataLayout::UNDEFINED) {
VLOG(3) << "Optimze Layout was not started" << op_name;
transposer = std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, first_layout);
return transposer;
if (DesiredLayout() == paddle::experimental::DataLayout::UNDEFINED) {
VLOG(4) << "Optimze Layout was not started" << op_name;
return std::make_shared<EagerLayoutTransformer>(
op_name, tensors_vector, tensors_vector[0][0].layout());
}
auto desired_layout = DesiredLayout();
if (NeedTransLayout(tensors_vector, desired_layout)) {
VLOG(3) << op_name << " need transpose to default layout";
transposer =
std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
return transposer;
} else {
auto trans = std::make_shared<EagerConcatOpTransformer>(op_name);
trans->SetAttr(axis, desired_layout);
return trans;
VLOG(4) << op_name << "'s has different layout";
return std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
}
if (op_name == "Concat") {
if (desired_layout == tensors_vector[0][0].layout() &&
tensors_vector[0][0].shape().size() == 4) {
auto trans = std::make_shared<EagerConcatOpTransformer>(op_name);
trans->SetAttr(axis, desired_layout);
return trans;
}
}
return std::make_shared<EagerLightlyLayoutSensitiveOpTransformer>(op_name);
}
} // namespace egr
......@@ -194,8 +194,10 @@ paddle::imperative::NameVarMap<VarType> AutoTuneLayout(
(conv_in_type == framework::proto::VarType::FP16);
if (is_tune_fp32) {
LayoutAutoTune::Instance().SetDesiredLayout(DataLayout::NCHW);
LayoutAutoTune::Instance().SetDefaultLayout(DataLayout::NHWC);
} else if (is_tune_fp16) {
LayoutAutoTune::Instance().SetDesiredLayout(DataLayout::NHWC);
LayoutAutoTune::Instance().SetDefaultLayout(DataLayout::NCHW);
} else {
tracer->DisableLayoutAutoTune();
return ins;
......
......@@ -184,6 +184,42 @@ PyObject* tensor_properties_get_shape(TensorObject* self, void* closure) {
}
}
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout();
auto default_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDefaultLayout();
bool change_dim =
(desired_layout != default_layout &&
self->tensor.layout() == desired_layout && value.size() == 4);
VLOG(6) << "eager_properties 'Shape' method, layout autotune "
<< " desired_layout: " << desired_layout
<< " default_layout: " << default_layout
<< " tensor layout: " << self->tensor.layout()
<< " tensor's shape size is : " << value.size();
std::vector<int64_t> dims = value;
if (change_dim &&
paddle::framework::DataLayoutToString(desired_layout) == "NCHW") {
// NCHW -> NHWC
VLOG(6) << "layout autotune get Shape from NCHW -> NHWC " << value[0] << " "
<< value[1] << " " << value[2] << " " << value[3] << " to "
<< dims[0] << " " << dims[2] << " " << dims[3] << " " << dims[1];
value[0] = dims[0];
value[1] = dims[2];
value[2] = dims[3];
value[3] = dims[1];
} else if (change_dim &&
paddle::framework::DataLayoutToString(desired_layout) == "NHWC") {
// NHWC -> NCHW
VLOG(6) << "layout autotune get Shape from NHWC -> NCHW " << value[0] << " "
<< value[1] << " " << value[2] << " " << value[3] << " to "
<< dims[0] << " " << dims[3] << " " << dims[1] << " " << dims[2]
<< " " << dims[1];
value[0] = dims[0];
value[1] = dims[3];
value[2] = dims[1];
value[3] = dims[2];
}
return ToPyObject(value);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
......
......@@ -2044,8 +2044,49 @@ void BindImperative(py::module *m_ptr) {
"shape",
[](imperative::VarBase &self) {
if (self.Var().IsType<framework::LoDTensor>()) {
return phi::vectorize<int>(
auto value = phi::vectorize<int>(
self.Var().Get<framework::LoDTensor>().dims());
auto tensor = self.Var().Get<framework::LoDTensor>();
auto tmp_value = value;
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance()
.GetDesiredLayout();
auto default_layout =
paddle::imperative::LayoutAutoTune::Instance()
.GetDefaultLayout();
bool change_dim =
(desired_layout != default_layout &&
tensor.layout() == desired_layout && value.size() == 4);
VLOG(6) << "'Shape' method, layout autotune,"
<< " desired_layout: " << desired_layout
<< " default_layout: " << default_layout
<< " tensor layout: " << tensor.layout()
<< " tensor's shape size is : " << value.size();
if (change_dim && paddle::framework::DataLayoutToString(
desired_layout) == "NCHW") {
VLOG(6) << "layout autotune get Shape from NHWC -> NCHW "
<< value[0] << " " << value[1] << " " << value[2] << " "
<< value[3] << " to " << tmp_value[3] << " "
<< tmp_value[1] << " " << tmp_value[2] << " "
<< tmp_value[1];
// NCHW -> NHWC
value[1] = tmp_value[2];
value[2] = tmp_value[3];
value[3] = tmp_value[1];
} else if (change_dim && paddle::framework::DataLayoutToString(
desired_layout) == "NHWC") {
VLOG(6) << "layout autotune get Shape from NHWC -> NCHW "
<< value[0] << " " << value[1] << " " << value[2] << " "
<< value[3] << " to " << tmp_value[0] << " "
<< tmp_value[3] << " " << tmp_value[1] << " "
<< tmp_value[2];
// NHWC -> NCHW
value[1] = tmp_value[3];
value[2] = tmp_value[1];
value[3] = tmp_value[2];
}
return value;
} else if (self.Var().IsType<phi::SelectedRows>()) {
return phi::vectorize<int>(
self.Var().Get<phi::SelectedRows>().value().dims());
......
......@@ -205,7 +205,8 @@ phi::DenseTensor TransformData(phi::DenseTensor* tensor,
if (NeedTransformLayout(tensor->layout(),
target_args_def.layout,
tensor->place(),
transform_flag)) {
transform_flag) &&
tensor->dims().size() != 1) {
out = TransDataLayout(out, target_args_def.layout);
trans_layout = true;
}
......
......@@ -93,18 +93,9 @@ class LayoutAutoTune(unittest.TestCase):
return conv_out, predict
def test_enable_autotune(self):
if self.use_autoune():
conv_out, predict = self.train(data_format="NCHW")
if paddle.fluid.core.use_layout_autotune():
self.assertEqual(conv_out.shape, [1, 14, 14, 8])
self.assertEqual(predict.shape, [1, 2])
else:
self.assertEqual(conv_out.shape, [1, 8, 14, 14])
self.assertEqual(predict.shape, [1, 2])
else:
conv_out, predict = self.train(data_format="NCHW")
self.assertEqual(conv_out.shape, [1, 8, 14, 14])
self.assertEqual(predict.shape, [1, 2])
conv_out, predict = self.train(data_format="NCHW")
self.assertEqual(conv_out.shape, [1, 8, 14, 14])
self.assertEqual(predict.shape, [1, 2])
def test_transpose_op_transposer(self):
conv = paddle.nn.Conv2D(3, 8, (3, 3))
......@@ -124,12 +115,8 @@ class LayoutAutoTune(unittest.TestCase):
scaled.backward()
scaler.minimize(optimizer, scaled)
if paddle.fluid.core.use_layout_autotune():
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1, 12, 8, 14])
else:
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [1, 12, 8, 14])
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [1, 12, 8, 14])
def test_flatten_op_transposer(self):
conv = paddle.nn.Conv2D(3, 8, (3, 3))
......@@ -143,12 +130,8 @@ class LayoutAutoTune(unittest.TestCase):
# because it flatten the C and H dimensions.
out = flatten(conv_out)
if paddle.fluid.core.use_layout_autotune():
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1, 112, 12])
else:
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [1, 112, 12])
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [1, 112, 12])
def test_argmax_op_transposer_keep_dims(self):
conv = paddle.nn.Conv2D(3, 8, (3, 3))
......@@ -157,41 +140,8 @@ class LayoutAutoTune(unittest.TestCase):
conv_out = conv(data)
# conv_out.shape = [1, 14, 12, 8] with NHWC
out = paddle.argmax(conv_out, axis=1, keepdim=True)
if paddle.fluid.core.use_layout_autotune():
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1, 14, 12, 1])
else:
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [1, 1, 14, 12])
def test_argmax_op_transposer_ff(self):
conv = paddle.nn.Conv2D(3, 8, (3, 3))
data = paddle.rand([1, 3, 16, 14])
with paddle.amp.auto_cast(level="O2"):
conv_out = conv(data)
# conv_out.shape = [1, 14, 12, 8] with NHWC
out = paddle.argmax(conv_out)
if paddle.fluid.core.use_layout_autotune():
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1])
else:
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [1])
def test_argmax_op_transposer_t(self):
conv = paddle.nn.Conv2D(3, 8, (3, 3))
data = paddle.rand([1, 3, 16, 14])
with paddle.amp.auto_cast(level="O2"):
conv_out = conv(data)
# conv_out.shape = [1, 14, 12, 8] with NHWC
out = paddle.argmax(conv_out)
if paddle.fluid.core.use_layout_autotune():
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1])
else:
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [1])
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [1, 1, 14, 12])
def test_concat_op_transposer(self):
in1 = paddle.rand([1, 8, 14, 12])
......@@ -202,12 +152,8 @@ class LayoutAutoTune(unittest.TestCase):
# conv_out.shape = [1, 14, 12, 8] with NHWC
out = paddle.concat(x=[conv_out, in1], axis=0)
if paddle.fluid.core.use_layout_autotune():
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [2, 8, 14, 12])
else:
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [2, 8, 14, 12])
self.assertEqual(conv_out.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [2, 8, 14, 12])
def test_concat_op_no_transposer(self):
conv = paddle.nn.Conv2D(3, 8, (3, 3))
......@@ -219,12 +165,8 @@ class LayoutAutoTune(unittest.TestCase):
# conv_out.shape = [1, 14, 12, 8] with NHWC
out = paddle.concat(x=[conv_out1, conv_out2], axis=0)
if paddle.fluid.core.use_layout_autotune():
self.assertEqual(conv_out1.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [2, 14, 12, 8])
else:
self.assertEqual(conv_out1.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [2, 8, 14, 12])
self.assertEqual(conv_out1.shape, [1, 8, 14, 12])
self.assertEqual(out.shape, [2, 8, 14, 12])
class TestAutoTuneAPI(unittest.TestCase):
......
......@@ -152,8 +152,8 @@ def _conv_nd(x,
channel_dim = channel_dim + len(
x.shape) if channel_dim < 0 else channel_dim
tmp_bias = _C_ops.reshape(
bias,
bias.shape + [1 for i in range(len(x.shape) - channel_dim - 1)])
bias, [1 for i in range(channel_dim)] + bias.shape +
[1 for i in range(len(x.shape) - channel_dim - 1)])
return _C_ops.add(pre_bias, tmp_bias)
else:
return pre_bias
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册