未验证 提交 69e99cc7 编写于 作者: N niuliling123 提交者: GitHub

improve LayoutAutoTune for NCHW and NHWC (#43158)

上级 bd7f28bb
......@@ -26,6 +26,7 @@ namespace imperative {
bool LayoutAutoTune::UseLayoutAutoTune() const {
#if defined(PADDLE_WITH_CUDA)
if (!phi::backends::gpu::TensorCoreAvailable()) {
LayoutAutoTune::Instance().DisableLayoutAutoTune();
return false;
} else {
return use_layout_autotune_;
......@@ -38,30 +39,23 @@ bool LayoutAutoTune::UseLayoutAutoTune() const {
LayoutAutoTune::LayoutAutoTune() {
const auto& op_info = paddle::framework::OpInfoMap::Instance().map();
for (auto it = op_info.begin(); it != op_info.end(); it++) {
// only record forwrd operators
if (it->first.find("_grad") != std::string::npos) {
// only when op was not in Lightly、Heavily or Agnostic Set
if (IsLightlyLayoutSensitive(it->first) ||
IsHeavilyLayoutSensitive(it->first) || IsLayoutAgnostic(it->first)) {
VLOG(4) << "Already exists in Layout OP: " << it->first;
continue;
}
// some normalization operators such as instance_norm and layer_norm
// do not have data_format attr, but are layout sensitive.
if (it->first.find("norm") != std::string::npos) {
layout_agnostic_ops_.emplace(it->first);
// only record forwrd operators
if (it->first.find("_grad") != std::string::npos) {
continue;
}
auto* attr_checker = it->second.Checker();
bool layout_agnostic = true;
if (attr_checker) {
auto attrs = attr_checker->GetDefaultAttrMap();
if (attrs.find("data_format") != attrs.end() ||
attrs.find("data_layout") != attrs.end()) {
VLOG(4) << "Heavily layout sensitive OP: " << it->first;
heavily_layout_sensitive_ops_.emplace(it->first);
continue;
}
// Attribute name is fuzzy matched, such as start and start_axis.
bool layout_agnostic = true;
for (auto& attr : attrs) {
auto attr_name = attr.first;
VLOG(6) << "OP: " << it->first << " Attr Name: " << attr_name;
......@@ -77,11 +71,27 @@ LayoutAutoTune::LayoutAutoTune() {
}
}
if (layout_agnostic) {
VLOG(4) << "Layout agnostic_ops: " << it->first;
layout_agnostic_ops_.emplace(it->first);
if ((attrs.find("data_format") != attrs.end() ||
attrs.find("data_layout") != attrs.end()) &&
layout_agnostic == true) {
VLOG(4) << "Heavily layout sensitive OP: " << it->first;
heavily_layout_sensitive_ops_.emplace(it->first);
layout_agnostic = false;
continue;
}
}
// some normalization operators such as instance_norm and layer_norm
// do not have data_format attr, but are layout sensitive.
if (it->first.find("norm") != std::string::npos && layout_agnostic) {
lightly_layout_sensitive_ops_.emplace(it->first);
continue;
}
if (layout_agnostic) {
VLOG(4) << "Layout agnostic_ops: " << it->first;
layout_agnostic_ops_.emplace(it->first);
}
}
VLOG(3) << "The number of layout agnostic OPs: "
......@@ -91,6 +101,48 @@ LayoutAutoTune::LayoutAutoTune() {
<< lightly_layout_sensitive_ops_.size();
}
template <typename VarType>
paddle::imperative::NameVarMap<VarType> DealHeavilyLayoutSensitive(
const std::string& op_type,
const paddle::imperative::NameVarMap<VarType>& ins,
const paddle::imperative::NameVarMap<VarType>& outs,
paddle::framework::AttributeMap* attrs,
const std::shared_ptr<imperative::Tracer>& tracer) {
std::shared_ptr<LayoutTransformer<VarType>> transposer = nullptr;
transposer =
std::make_shared<HeavilyLayoutSensitiveOpTransformer<VarType>>(op_type);
transposer->SetArguments(
{"Input", "X"}, {"Output", "Out", "Y"}, {"data_format", "data_layout"});
return transposer->Apply(ins, outs, attrs, tracer);
}
template <typename VarType>
paddle::imperative::NameVarMap<VarType> DealLightlyLayoutSensitive(
const std::string& op_type,
const paddle::imperative::NameVarMap<VarType>& ins,
const paddle::imperative::NameVarMap<VarType>& outs,
paddle::framework::AttributeMap* attrs,
const std::shared_ptr<imperative::Tracer>& tracer) {
std::shared_ptr<LayoutTransformer<VarType>> transposer = nullptr;
if (op_type == "transpose2") {
transposer = std::make_shared<TransposeOpTransformer<VarType>>(op_type);
} else if (op_type == "flatten_contiguous_range") {
transposer = std::make_shared<FlattenOpTransformer<VarType>>(op_type);
} else if (op_type == "arg_max") {
transposer = std::make_shared<ArgmaxOpTransformer<VarType>>(op_type);
} else if (op_type.find("elementwise_") != std::string::npos) {
transposer = std::make_shared<ElementwiseOpTransformer<VarType>>(op_type);
} else {
VLOG(4) << op_type
<< "'s LayoutTransformer is unimplemented. Use default "
"LightlyLayoutTransformer instead.";
transposer =
std::make_shared<LightlyLayoutSensitiveOpTransformer<VarType>>(op_type);
}
return transposer->Apply(ins, outs, attrs, tracer);
}
template <typename VarType>
paddle::imperative::NameVarMap<VarType> AutoTuneLayout(
const std::string& op_type,
......@@ -101,7 +153,6 @@ paddle::imperative::NameVarMap<VarType> AutoTuneLayout(
if (!LayoutAutoTune::Instance().UseLayoutAutoTune()) {
return ins;
}
// When layout autotuning is enabled, the tuner will check the desired layout.
// (1) If the desired layout is undefined, and there is no convolutional
// layers, layout optimization is unnecessary. Otherwise, the desired layout
......@@ -115,51 +166,49 @@ paddle::imperative::NameVarMap<VarType> AutoTuneLayout(
if (op_type != "conv2d") {
return ins;
} else {
if (BOOST_GET_CONST(std::string, (*attrs)["data_format"]) == "NCHW") {
auto conv_in_type = framework::proto::VarType::FP32;
auto& in_vars = ins.at("Input")[0];
if (GetDataType<VarType>(in_vars) == framework::proto::VarType::FP16) {
conv_in_type = framework::proto::VarType::FP16;
}
bool is_tune_fp32 =
(BOOST_GET_CONST(std::string, (*attrs)["data_format"]) == "NHWC") &&
(conv_in_type == framework::proto::VarType::FP32);
bool is_tune_fp16 =
(BOOST_GET_CONST(std::string, (*attrs)["data_format"]) == "NCHW") &&
(conv_in_type == framework::proto::VarType::FP16);
if (is_tune_fp32) {
LayoutAutoTune::Instance().SetDesiredLayout(DataLayout::NCHW);
} else if (is_tune_fp16) {
LayoutAutoTune::Instance().SetDesiredLayout(DataLayout::NHWC);
VLOG(3) << "Tune the layout from "
<< BOOST_GET_CONST(std::string, (*attrs)["data_format"])
<< " to "
<< paddle::framework::DataLayoutToString(
LayoutAutoTune::Instance().GetDesiredLayout());
} else {
LayoutAutoTune::Instance().DisableLayoutAutoTune();
return ins;
}
VLOG(3) << "Tune the layout from "
<< BOOST_GET_CONST(std::string, (*attrs)["data_format"]) << " to "
<< paddle::framework::DataLayoutToString(
LayoutAutoTune::Instance().GetDesiredLayout());
}
}
std::shared_ptr<LayoutTransformer<VarType>> transposer = nullptr;
if (op_type == "conv2d") {
transposer =
std::make_shared<HeavilyLayoutSensitiveOpTransformer<VarType>>(op_type);
transposer->SetArguments({"Input"}, {"Output"}, {"data_format"});
} else if (op_type == "batch_norm") {
transposer =
std::make_shared<HeavilyLayoutSensitiveOpTransformer<VarType>>(op_type);
transposer->SetArguments({"X"}, {"Y"}, {"data_layout"});
} else if (op_type == "pool2d") {
transposer =
std::make_shared<HeavilyLayoutSensitiveOpTransformer<VarType>>(op_type);
transposer->SetArguments({"X"}, {"Out"}, {"data_format"});
} else if (op_type == "transpose2") {
transposer = std::make_shared<TransposeOpTransformer<VarType>>(op_type);
} else if (op_type == "flatten_contiguous_range") {
transposer = std::make_shared<FlattenOpTransformer<VarType>>(op_type);
} else if (op_type.find("elementwise_") != std::string::npos) {
transposer = std::make_shared<ElementwiseOpTransformer<VarType>>(op_type);
} else if (LayoutAutoTune::Instance().IsLayoutAgnostic(op_type)) {
transposer = std::make_shared<LayoutTransformer<VarType>>(op_type);
if (LayoutAutoTune::Instance().IsHeavilyLayoutSensitive(op_type)) {
return DealHeavilyLayoutSensitive<VarType>(
op_type, ins, outs, attrs, tracer);
} else if (LayoutAutoTune::Instance().IsLightlyLayoutSensitive(op_type)) {
transposer =
std::make_shared<LightlyLayoutSensitiveOpTransformer<VarType>>(op_type);
return DealLightlyLayoutSensitive<VarType>(
op_type, ins, outs, attrs, tracer);
} else {
std::shared_ptr<LayoutTransformer<VarType>> transposer = nullptr;
if (LayoutAutoTune::Instance().IsLayoutAgnostic(op_type)) {
transposer = std::make_shared<LayoutTransformer<VarType>>(op_type);
}
PADDLE_ENFORCE_NOT_NULL(
transposer, phi::errors::Unimplemented(
"%s 's LayoutTransformer is unimplemented.", op_type));
transposer,
phi::errors::Unimplemented("%s 's LayoutTransformer is unimplemented.",
op_type));
return transposer->Apply(ins, outs, attrs, tracer);
}
return transposer->Apply(ins, outs, attrs, tracer);
}
template paddle::imperative::NameVarMap<VarBase> AutoTuneLayout<VarBase>(
const std::string& op_type,
......
......@@ -41,6 +41,10 @@ class LayoutAutoTune {
void DisableLayoutAutoTune() { use_layout_autotune_ = false; }
bool IsHeavilyLayoutSensitive(const std::string& op_type) const {
return heavily_layout_sensitive_ops_.count(op_type) != 0;
}
bool IsLightlyLayoutSensitive(const std::string& op_type) const {
return lightly_layout_sensitive_ops_.count(op_type) != 0;
}
......@@ -60,9 +64,10 @@ class LayoutAutoTune {
std::unordered_set<std::string> layout_agnostic_ops_{};
std::unordered_set<std::string> heavily_layout_sensitive_ops_{};
std::unordered_set<std::string> heavily_layout_sensitive_ops_{"batch_norm"};
std::unordered_set<std::string> lightly_layout_sensitive_ops_{};
std::unordered_set<std::string> lightly_layout_sensitive_ops_{
"instance_norm", "softmax", "transpose", "transpose2", "reshape2"};
DataLayout layout_{DataLayout::UNDEFINED};
};
......
......@@ -13,18 +13,19 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace imperative {
template <typename VarType>
std::shared_ptr<VarType> TraceTransposeOp(
const std::shared_ptr<VarType>& var, const DataLayout layout,
const std::shared_ptr<VarType>& var,
const DataLayout layout,
const std::shared_ptr<paddle::imperative::Tracer>& tracer) {
std::vector<int> axis;
if (layout == DataLayout::NHWC) {
......@@ -76,8 +77,8 @@ class LayoutTransformer {
for (auto& var : pair.second) {
// Once the any input is desired layout, we set in_layout is desired
// layout.
if (paddle::imperative::GetDataLayout(var) ==
LayoutAutoTune::Instance().GetDesiredLayout()) {
if (var != nullptr && (paddle::imperative::GetDataLayout(var) ==
LayoutAutoTune::Instance().GetDesiredLayout())) {
in_layout = LayoutAutoTune::Instance().GetDesiredLayout();
break;
}
......@@ -103,17 +104,27 @@ class LayoutTransformer {
// will be considered. Otherwise, it only set layout for the specified output.
void SetVarsLayout(const paddle::imperative::NameVarMap<VarType>& outs,
DataLayout layout) const {
if (outs_.empty()) {
for (auto& pair : outs) {
for (auto& var : pair.second) {
paddle::imperative::SetDataLayout(var, layout);
bool not_in_out = true;
if (!outs_.empty()) {
for (auto& name : outs_) {
if (outs.find(name) != outs.end()) {
auto out_vars = outs.at(name);
for (auto& var : out_vars) {
if (var != nullptr) {
paddle::imperative::SetDataLayout(var, layout);
}
}
not_in_out = false;
}
}
} else {
for (auto& name : outs_) {
auto out_vars = outs.at(name);
for (auto& var : out_vars) {
paddle::imperative::SetDataLayout(var, layout);
}
if (not_in_out) {
for (auto& pair : outs) {
for (auto& var : pair.second) {
if (var != nullptr) {
paddle::imperative::SetDataLayout(var, layout);
}
}
}
}
......@@ -132,46 +143,6 @@ class LayoutTransformer {
std::vector<std::string> attrs_{};
};
template <typename VarType>
class ElementwiseOpTransformer : public LayoutTransformer<VarType> {
public:
explicit ElementwiseOpTransformer(const std::string& type)
: LayoutTransformer<VarType>(type) {}
paddle::imperative::NameVarMap<VarType> Apply(
const paddle::imperative::NameVarMap<VarType>& ins,
const paddle::imperative::NameVarMap<VarType>& outs,
paddle::framework::AttributeMap* attrs,
const std::shared_ptr<paddle::imperative::Tracer>& tracer) {
// [Why we need the this?]
// The Elementwise Ops has a axis attr, it is to support broadcast.
// When bias_attr of Conv is not false, the elementwise_add will be
// appended, and the axis will be set to the channel dimension.
// If the axis is set to the channel dimension, the attr transformation
// is necessary. Otherwise, it will fall back to the
// LayoutTransformer::Apply.
auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout();
if (attrs->find("axis") != attrs->end() &&
BOOST_GET_CONST(int, (*attrs)["axis"]) != -1) {
VLOG(3) << "Optimze layout agnostic op " << this->Type();
if (desired_layout == DataLayout::NHWC) {
(*attrs)["axis"] = 3;
} else if (desired_layout == DataLayout::NCHW) {
(*attrs)["axis"] = 1;
} else {
PADDLE_ENFORCE_EQ(
desired_layout, DataLayout::UNDEFINED,
phi::errors::PreconditionNotMet("DataLayout is unsupport."));
}
this->SetVarsLayout(outs, desired_layout);
return ins;
} else {
return LayoutTransformer<VarType>::Apply(ins, outs, attrs, tracer);
}
}
};
/*
* Both functionality and performance are affected by data layout.
* Such as operators with data_format attribute.
......@@ -213,11 +184,13 @@ class HeavilyLayoutSensitiveOpTransformer : public LayoutTransformer<VarType> {
// Step 2: Transpose the specified input for Op and set the transposed var's
// layout.
for (auto& name : this->Inputs()) {
auto& in_vars = new_ins[name];
for (auto& var : in_vars) {
auto var_layout = paddle::imperative::GetDataLayout(var);
if (var_layout != desired_layout) {
var = TraceTransposeOp(var, DataLayout::NHWC, tracer);
if (new_ins.find(name) != new_ins.end()) {
auto& in_vars = new_ins[name];
for (auto& var : in_vars) {
if (var != nullptr &&
paddle::imperative::GetDataLayout(var) != desired_layout) {
var = TraceTransposeOp(var, desired_layout, tracer);
}
}
}
}
......@@ -252,10 +225,20 @@ class LightlyLayoutSensitiveOpTransformer : public LayoutTransformer<VarType> {
// operator output data layout. Currently only a few operators are
// supported, and transposers need to be carefully designed to ensure that
// they do not cause exceptions.
auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout();
for (auto& pair : new_ins) {
for (auto& var : pair.second) {
auto var_layout = paddle::imperative::GetDataLayout(var);
if (var_layout == LayoutAutoTune::Instance().GetDesiredLayout()) {
if (var != nullptr) {
VLOG(3) << "Tune the layout from "
<< paddle::framework::DataLayoutToString(
paddle::imperative::GetDataLayout(var))
<< " to "
<< paddle::framework::DataLayoutToString(
LayoutAutoTune::Instance().GetDesiredLayout());
}
if (var != nullptr &&
paddle::imperative::GetDataLayout(var) == desired_layout &&
desired_layout == DataLayout::NHWC) {
// Set layout to UNDEFINED so that TransposeOpTransformer do
// NHWC->NCHW transformation.
var = TraceTransposeOp(var, DataLayout::UNDEFINED, tracer);
......@@ -266,6 +249,50 @@ class LightlyLayoutSensitiveOpTransformer : public LayoutTransformer<VarType> {
}
};
template <typename VarType>
class ElementwiseOpTransformer
: public LightlyLayoutSensitiveOpTransformer<VarType> {
public:
explicit ElementwiseOpTransformer(const std::string& type)
: LightlyLayoutSensitiveOpTransformer<VarType>(type) {}
paddle::imperative::NameVarMap<VarType> Apply(
const paddle::imperative::NameVarMap<VarType>& ins,
const paddle::imperative::NameVarMap<VarType>& outs,
paddle::framework::AttributeMap* attrs,
const std::shared_ptr<paddle::imperative::Tracer>& tracer) {
// [Why we need the this?]
// The Elementwise Ops has a axis attr, it is to support broadcast.
// When bias_attr of Conv is not false, the elementwise_add will be
// appended, and the axis will be set to the channel dimension.
// If the axis is set to the channel dimension, the attr transformation
// is necessary. Otherwise, it will fall back to the
// LayoutTransformer::Apply.
auto& in1_vars = ins.at("X")[0];
auto& in2_vars = ins.at("Y")[0];
auto in_layout = paddle::imperative::GetDataLayout(in1_vars);
// for conv's bias
if (attrs->find("axis") != attrs->end() &&
BOOST_GET_CONST(int, (*attrs)["axis"]) != -1) {
if (in_layout == DataLayout::NHWC) {
(*attrs)["axis"] = 3;
} else if (in_layout == DataLayout::NCHW) {
(*attrs)["axis"] = 1;
}
this->SetVarsLayout(outs, in_layout);
return ins;
} else {
auto in2_layout = paddle::imperative::GetDataLayout(in2_vars);
if (in_layout == in2_layout) {
this->SetVarsLayout(outs, in_layout);
return ins;
}
return LightlyLayoutSensitiveOpTransformer<VarType>::Apply(
ins, outs, attrs, tracer);
}
}
};
template <typename VarType>
class TransposeOpTransformer
: public LightlyLayoutSensitiveOpTransformer<VarType> {
......@@ -286,13 +313,14 @@ class TransposeOpTransformer
// transpose Op with the current transpose Op by transforming 'axis' attr.
auto& in_var = ins.at("X")[0];
auto var_layout = paddle::imperative::GetDataLayout(in_var);
if (var_layout == LayoutAutoTune::Instance().GetDesiredLayout()) {
auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout();
if (var_layout == desired_layout && desired_layout == DataLayout::NHWC) {
auto axis = BOOST_GET_CONST(std::vector<int>, (*attrs)["axis"]);
// NHWC->NCHW, permutaion will be set as follows.
std::vector<int> perm = {0, 3, 1, 2};
// fuse the transpose Ops by transforming axis.
std::vector<int> fusion_axis = {perm[axis[0]], perm[axis[1]],
perm[axis[2]], perm[axis[3]]};
std::vector<int> fusion_axis = {
perm[axis[0]], perm[axis[1]], perm[axis[2]], perm[axis[3]]};
(*attrs)["axis"] = fusion_axis;
}
return ins;
......@@ -322,9 +350,53 @@ class FlattenOpTransformer
start_axis == 1 && stop_axis == 3) {
return ins;
} else {
return LightlyLayoutSensitiveOpTransformer<VarType>::Apply(ins, outs,
attrs, tracer);
return LightlyLayoutSensitiveOpTransformer<VarType>::Apply(
ins, outs, attrs, tracer);
}
}
};
template <typename VarType>
class ArgmaxOpTransformer
: public LightlyLayoutSensitiveOpTransformer<VarType> {
public:
explicit ArgmaxOpTransformer(const std::string& type)
: LightlyLayoutSensitiveOpTransformer<VarType>(type) {}
paddle::imperative::NameVarMap<VarType> Apply(
const paddle::imperative::NameVarMap<VarType>& ins,
const paddle::imperative::NameVarMap<VarType>& outs,
paddle::framework::AttributeMap* attrs,
const std::shared_ptr<paddle::imperative::Tracer>& tracer) {
VLOG(3) << "Optimze lightly layout sensitive op " << this->Type();
auto& in_var = ins.at("X")[0];
auto var_layout = paddle::imperative::GetDataLayout(in_var);
bool keep_dims = BOOST_GET_CONST(bool, (*attrs)["keepdims"]);
if (keep_dims) {
if (var_layout != DataLayout::UNDEFINED) {
std::vector<int> perm_nhwc = {0, 2, 3, 1};
std::vector<int> perm_nchw = {0, 3, 1, 2};
auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw;
switch (AttrTypeID((*attrs)["axis"])) {
case paddle::framework::proto::AttrType::INT: {
auto axis = BOOST_GET_CONST(int, (*attrs)["axis"]);
(*attrs)["axis"] = static_cast<int>(perm[axis]);
}
case paddle::framework::proto::AttrType::LONG: {
auto axis = BOOST_GET_CONST(int64_t, (*attrs)["axis"]);
(*attrs)["axis"] = static_cast<int64_t>(perm[axis]);
}
default:
VLOG(4) << "The data_type of axis is Error, axis must be int or "
"int64, bug got "
<< (AttrTypeID((*attrs)["axis"]));
}
}
this->SetVarsLayout(outs, var_layout);
return ins;
}
return LightlyLayoutSensitiveOpTransformer<VarType>::Apply(
ins, outs, attrs, tracer);
}
};
......
......@@ -225,11 +225,9 @@ void Tracer::TraceOpImpl(const std::string& type,
std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
if (amp_level_ == AmpLevel::O1) {
if (amp_dtype_ == phi::DataType::FLOAT16) {
const auto& tracer = imperative::GetCurrentTracer();
VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
ins_amp = std::make_unique<NameVarMap<VarType>>(
AutoCastInputs<VarType>(type, imperative::AutoTuneLayout<VarType>(
type, ins, outs, &attrs, tracer)));
AutoCastInputs<VarType>(type, ins));
} else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
ins_amp = std::make_unique<NameVarMap<VarType>>(
......@@ -237,18 +235,24 @@ void Tracer::TraceOpImpl(const std::string& type,
}
} else if (amp_level_ == AmpLevel::O2) {
if (amp_dtype_ == phi::DataType::FLOAT16) {
const auto& tracer = imperative::GetCurrentTracer();
VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
ins_amp =
std::make_unique<NameVarMap<VarType>>(CastPureFp16Inputs<VarType>(
type, imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs,
tracer)));
ins_amp = std::make_unique<NameVarMap<VarType>>(
CastPureFp16Inputs<VarType>(type, ins));
} else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
ins_amp = std::make_unique<NameVarMap<VarType>>(
CastPureBf16Inputs<VarType>(type, ins));
}
}
if (platform::is_gpu_place(place)) {
const auto& new_tmp = ins_amp == nullptr ? ins : *ins_amp;
const auto& tracer = imperative::GetCurrentTracer();
ins_amp = std::make_unique<NameVarMap<VarType>>(
imperative::AutoTuneLayout<VarType>(type, new_tmp, outs, &attrs,
tracer));
}
const auto& new_ins = ins_amp == nullptr ? ins : *ins_amp;
try {
......
......@@ -307,7 +307,8 @@ void SetCachedValue<egr::EagerVariable>(
// is equal to self: " << key == key << " and res name is:" << res->Name().
}
template void SetCachedValue<VarBase>(
std::shared_ptr<VarBase> var, const paddle::framework::OpKernelType &key,
std::shared_ptr<VarBase> var,
const paddle::framework::OpKernelType &key,
std::shared_ptr<VarBase> res);
template void SetCachedValue<VariableWrapper>(
std::shared_ptr<VariableWrapper> var,
......
......@@ -135,6 +135,32 @@ class LayoutAutoTune(unittest.TestCase):
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1, 112, 12])
def test_argmax_op_transposer_keep_dims(self):
if not self.use_autoune():
return
conv = paddle.nn.Conv2D(3, 8, (3, 3))
data = paddle.rand([1, 3, 16, 14])
with paddle.amp.auto_cast(level="O2"):
conv_out = conv(data)
# conv_out.shape = [1, 14, 12, 8] with NHWC
out = paddle.argmax(conv_out, axis=1, keepdim=True)
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1, 14, 1, 8])
def test_argmax_op_transposer(self):
if not self.use_autoune():
return
conv = paddle.nn.Conv2D(3, 8, (3, 3))
data = paddle.rand([1, 3, 16, 14])
with paddle.amp.auto_cast(level="O2"):
conv_out = conv(data)
# conv_out.shape = [1, 14, 12, 8] with NHWC
out = paddle.argmax(conv_out)
self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1])
class TestAutoTuneAPI(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册