未验证 提交 c2bcb141 编写于 作者: Z Zhang Ting 提交者: GitHub

Implement Amp Layout AutoTune (#41884)

上级 60bec700
......@@ -16,7 +16,9 @@
#include "gtest/gtest.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
......@@ -206,3 +208,28 @@ TEST(EagerVariable, Constructor) {
VLOG(6) << "Finish";
}
// Checks that imperative::GetDataLayout / imperative::SetDataLayout round-trip
// on an EagerVariable that wraps a DenseTensor.
TEST(EagerVariable, DataLayout) {
  paddle::experimental::Tensor tensor;
  phi::DenseTensorMeta meta =
      phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim({1, 1, 1, 1}),
                           paddle::experimental::DataLayout::UNDEFINED);
  std::shared_ptr<phi::DenseTensor> dt = std::make_shared<phi::DenseTensor>(
      std::make_unique<paddle::experimental::DefaultAllocator>(
          paddle::platform::CPUPlace())
          .get(),
      meta);
  auto* dt_ptr = dt->mutable_data<float>(paddle::platform::CPUPlace());
  // The meta above describes a {1, 1, 1, 1} tensor, i.e. a single element, so
  // only index 0 may be written; touching indices 1..3 would be out of bounds.
  dt_ptr[0] = 5.0f;
  tensor.set_impl(dt);

  auto eager_var = std::make_shared<egr::EagerVariable>(tensor);
  // The layout initially reflects the one stored in the tensor meta.
  auto layout = paddle::imperative::GetDataLayout(eager_var);
  CHECK_EQ(layout, paddle::experimental::DataLayout::UNDEFINED);
  // After SetDataLayout the new layout must be observable through the getter.
  paddle::imperative::SetDataLayout(eager_var,
                                    paddle::experimental::DataLayout::NCHW);
  layout = paddle::imperative::GetDataLayout(eager_var);
  CHECK_EQ(layout, paddle::experimental::DataLayout::NCHW);
}
......@@ -7,8 +7,13 @@ cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator
ENDIF()
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry var_helper phi_api)
add_subdirectory(jit)
# Build the layout_autotune library. GPU builds additionally depend on
# phi_gpu_info; all other builds depend on op_info only.
if (WITH_GPU)
cc_library(layout_autotune SRCS layout_autotune.cc DEPS op_info phi_gpu_info)
else()
cc_library(layout_autotune SRCS layout_autotune.cc DEPS op_info)
endif()
cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper layout_autotune)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
......
......@@ -211,6 +211,14 @@ class VarBase {
framework::proto::VarType::Type DataType() const { return var_->DataType(); }
// Records `data_layout` on the wrapped variable by forwarding to
// var_->SetDataLayout().
void SetDataLayout(paddle::experimental::DataLayout data_layout) {
var_->SetDataLayout(data_layout);
}
// Returns the layout currently recorded on the wrapped variable (var_).
paddle::experimental::DataLayout DataLayout() const {
return var_->DataLayout();
}
size_t ElementSize() const { return framework::SizeOfType(var_->DataType()); }
void SetForwardDataType(framework::proto::VarType::Type data_type) {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/imperative/layout_transformer.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace imperative {
// Returns true only when layout autotune has been enabled AND the build and
// device support it (CUDA build with TensorCore available). Non-CUDA builds
// always return false.
bool LayoutAutoTune::UseLayoutAutoTune() const {
#if defined(PADDLE_WITH_CUDA)
  // If autotune was never enabled there is nothing to report, and the device
  // query can be skipped entirely.
  if (!use_layout_autotune_) {
    return false;
  }
  if (!phi::backends::gpu::TensorCoreAvailable()) {
    // AutoTuneLayout() evaluates this predicate for each operator it is asked
    // about, so emit the notice once instead of on every call.
    LOG_FIRST_N(INFO, 1) << "Layout AutoTuning is not available.";
    return false;
  }
  return true;
#else
  return false;
#endif
}
// Walks every registered operator and sorts the forward ones into three
// buckets: layout agnostic, heavily layout sensitive (has a data_format or
// data_layout attribute) and lightly layout sensitive (has an axis-like
// attribute). Operators without an attribute checker are not recorded.
LayoutAutoTune::LayoutAutoTune() {
  // Attribute names are matched by substring, so e.g. "start" also matches
  // "start_axis".
  const auto is_axis_like = [](const std::string& attr_name) {
    return attr_name.find("axis") != std::string::npos ||
           attr_name.find("axes") != std::string::npos ||
           attr_name.find("dim") != std::string::npos ||
           attr_name.find("start") != std::string::npos ||
           attr_name.find("end") != std::string::npos;
  };

  for (const auto& entry : paddle::framework::OpInfoMap::Instance().map()) {
    const std::string& op_type = entry.first;

    // Only forward operators are recorded.
    if (op_type.find("_grad") != std::string::npos) {
      continue;
    }

    // Any operator whose name contains "norm" is registered as layout
    // agnostic without inspecting its attributes.
    // NOTE(review): the original comment says normalization ops such as
    // instance_norm and layer_norm lack a data_format attr "but are layout
    // sensitive"; confirm that classifying them as agnostic is intended.
    if (op_type.find("norm") != std::string::npos) {
      layout_agnostic_ops_.emplace(op_type);
      continue;
    }

    auto* attr_checker = entry.second.Checker();
    if (!attr_checker) {
      continue;
    }

    const auto& default_attrs = attr_checker->GetDefaultAttrMap();
    const bool has_layout_attr = default_attrs.count("data_format") != 0 ||
                                 default_attrs.count("data_layout") != 0;
    if (has_layout_attr) {
      VLOG(4) << "Heavily layout sensitive OP: " << op_type;
      heavily_layout_sensitive_ops_.emplace(op_type);
      continue;
    }

    // Scan attribute names until the first axis-like one is found.
    bool lightly_sensitive = false;
    for (const auto& attr : default_attrs) {
      const std::string& attr_name = attr.first;
      VLOG(6) << "OP: " << op_type << " Attr Name: " << attr_name;
      if (is_axis_like(attr_name)) {
        lightly_sensitive = true;
        break;
      }
    }

    if (lightly_sensitive) {
      VLOG(4) << "Lightly layout sensitive OP: " << op_type;
      lightly_layout_sensitive_ops_.emplace(op_type);
    } else {
      VLOG(4) << "Layout agnostic_ops: " << op_type;
      layout_agnostic_ops_.emplace(op_type);
    }
  }

  VLOG(3) << "The number of layout agnostic OPs: "
          << layout_agnostic_ops_.size() << ", heavily layout sensitive OPs: "
          << heavily_layout_sensitive_ops_.size()
          << ", lightly layout sensitive OPs: "
          << lightly_layout_sensitive_ops_.size();
}
// Entry point of layout autotune for a single traced operator.
//
// Returns the (possibly transposed) inputs the operator should run with and
// may rewrite layout-related entries in `attrs`. When autotune is not in use
// the inputs are returned untouched.
template <typename VarType>
paddle::imperative::NameVarMap<VarType> AutoTuneLayout(
    const std::string& op_type,
    const paddle::imperative::NameVarMap<VarType>& ins,
    const paddle::imperative::NameVarMap<VarType>& outs,
    paddle::framework::AttributeMap* attrs,
    const std::shared_ptr<imperative::Tracer>& tracer) {
  auto& tuner = LayoutAutoTune::Instance();
  if (!tuner.UseLayoutAutoTune()) {
    return ins;
  }

  // With autotune enabled the desired layout is resolved lazily:
  // (1) While it is still undefined, nothing happens until the first conv2d
  //     is seen. A conv2d running in NCHW switches the desired layout to
  //     NHWC; a conv2d in any other format turns autotune off.
  // (2) Once the desired layout is defined, the matching transformer runs.
  if (tuner.GetDesiredLayout() == DataLayout::UNDEFINED) {
    // Layout autotune only supports models with convolutional layers.
    if (op_type != "conv2d") {
      return ins;
    }
    if (BOOST_GET_CONST(std::string, (*attrs)["data_format"]) != "NCHW") {
      tuner.DisableLayoutAutoTune();
      return ins;
    }
    tuner.SetDesiredLayout(DataLayout::NHWC);
    VLOG(3) << "Tune the layout from "
            << BOOST_GET_CONST(std::string, (*attrs)["data_format"]) << " to "
            << paddle::framework::DataLayoutToString(tuner.GetDesiredLayout());
  }

  // Builds a heavily-layout-sensitive transformer restricted to one input
  // slot, one output slot and one layout attribute.
  const auto make_heavy = [&op_type](const std::string& in_slot,
                                     const std::string& out_slot,
                                     const std::string& layout_attr) {
    auto heavy =
        std::make_shared<HeavilyLayoutSensitiveOpTransformer<VarType>>(op_type);
    heavy->SetArguments({in_slot}, {out_slot}, {layout_attr});
    return heavy;
  };

  std::shared_ptr<LayoutTransformer<VarType>> transposer;
  if (op_type == "conv2d") {
    transposer = make_heavy("Input", "Output", "data_format");
  } else if (op_type == "batch_norm") {
    transposer = make_heavy("X", "Y", "data_layout");
  } else if (op_type == "pool2d") {
    transposer = make_heavy("X", "Out", "data_format");
  } else if (op_type == "transpose2") {
    transposer = std::make_shared<TransposeOpTransformer<VarType>>(op_type);
  } else if (op_type == "flatten_contiguous_range") {
    transposer = std::make_shared<FlattenOpTransformer<VarType>>(op_type);
  } else if (op_type.find("elementwise_") != std::string::npos) {
    transposer = std::make_shared<ElementwiseOpTransformer<VarType>>(op_type);
  } else if (tuner.IsLayoutAgnostic(op_type)) {
    transposer = std::make_shared<LayoutTransformer<VarType>>(op_type);
  } else if (tuner.IsLightlyLayoutSensitive(op_type)) {
    transposer =
        std::make_shared<LightlyLayoutSensitiveOpTransformer<VarType>>(op_type);
  } else {
    // No transformer matched: `transposer` is still null here, so this
    // check reports the operator as unimplemented.
    PADDLE_ENFORCE_NOT_NULL(
        transposer, phi::errors::Unimplemented(
                        "%s 's LayoutTransformer is unimplemented.", op_type));
  }

  return transposer->Apply(ins, outs, attrs, tracer);
}
// Explicit instantiations of AutoTuneLayout for the two variable types it is
// provided for: VarBase and egr::EagerVariable.
template paddle::imperative::NameVarMap<VarBase> AutoTuneLayout<VarBase>(
const std::string& op_type,
const paddle::imperative::NameVarMap<VarBase>& ins,
const paddle::imperative::NameVarMap<VarBase>& outs,
paddle::framework::AttributeMap* attrs,
const std::shared_ptr<imperative::Tracer>& tracer);
template paddle::imperative::NameVarMap<egr::EagerVariable>
AutoTuneLayout<egr::EagerVariable>(
const std::string& op_type,
const paddle::imperative::NameVarMap<egr::EagerVariable>& ins,
const paddle::imperative::NameVarMap<egr::EagerVariable>& outs,
paddle::framework::AttributeMap* attrs,
const std::shared_ptr<imperative::Tracer>& tracer);
} // namespace imperative
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <memory>
#include <unordered_set>
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/compat/type_defs.h"
namespace paddle {
namespace imperative {
class Tracer;
using DataLayout = paddle::experimental::DataLayout;
/*
 * Singleton holding the layout autotune state: the on/off switch, the
 * desired layout, and the per-category sets of operator names filled in by
 * the constructor.
 */
class LayoutAutoTune {
 public:
  // Returns the single shared instance.
  static LayoutAutoTune& Instance() {
    static LayoutAutoTune layout_autoTune;
    return layout_autoTune;
  }

  // The instance is a singleton; forbid copies so that callers cannot
  // accidentally mutate a detached copy of the state.
  LayoutAutoTune(const LayoutAutoTune&) = delete;
  LayoutAutoTune& operator=(const LayoutAutoTune&) = delete;

  // Whether layout autotune is currently in effect (defined out of line).
  bool UseLayoutAutoTune() const;

  void EnableLayoutAutoTune() { use_layout_autotune_ = true; }

  void DisableLayoutAutoTune() { use_layout_autotune_ = false; }

  // True if `op_type` was recorded as heavily layout sensitive.
  bool IsHeavilyLayoutSensitive(const std::string& op_type) const {
    return heavily_layout_sensitive_ops_.count(op_type) != 0;
  }

  // True if `op_type` was recorded as lightly layout sensitive.
  bool IsLightlyLayoutSensitive(const std::string& op_type) const {
    return lightly_layout_sensitive_ops_.count(op_type) != 0;
  }

  // True if `op_type` was recorded as layout agnostic.
  bool IsLayoutAgnostic(const std::string& op_type) const {
    return layout_agnostic_ops_.count(op_type) != 0;
  }

  // The layout the tuner wants tensors to use; UNDEFINED until it is set.
  DataLayout GetDesiredLayout() const { return layout_; }

  void SetDesiredLayout(const DataLayout& layout) { layout_ = layout; }

 private:
  // Private: instances are only created through Instance().
  LayoutAutoTune();

  bool use_layout_autotune_{false};

  std::unordered_set<std::string> layout_agnostic_ops_{};

  std::unordered_set<std::string> heavily_layout_sensitive_ops_{};

  std::unordered_set<std::string> lightly_layout_sensitive_ops_{};

  DataLayout layout_{DataLayout::UNDEFINED};
};
// Applies layout autotune to one operator of type `op_type`.
// Takes the operator's inputs, outputs and a mutable attribute map plus the
// tracer, and returns the input map the operator should be run with.
template <typename VarType>
paddle::imperative::NameVarMap<VarType> AutoTuneLayout(
const std::string& op_type,
const paddle::imperative::NameVarMap<VarType>& ins,
const paddle::imperative::NameVarMap<VarType>& outs,
paddle::framework::AttributeMap* attrs,
const std::shared_ptr<imperative::Tracer>& tracer);
} // namespace imperative
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace imperative {
template <typename VarType>
std::shared_ptr<VarType> TraceTransposeOp(
const std::shared_ptr<VarType>& var, const DataLayout layout,
const std::shared_ptr<paddle::imperative::Tracer>& tracer) {
std::vector<int> axis;
if (layout == DataLayout::NHWC) {
axis = {0, 2, 3, 1};
} else if (layout == DataLayout::NCHW) {
axis = {0, 3, 1, 2};
} else {
axis = {0, 1, 2, 3};
}
paddle::imperative::NameVarMap<VarType> ins = {{"X", {var}}};
auto out =
std::shared_ptr<VarType>(new VarType(tracer->GenerateUniqueName()));
auto x_shape =
std::shared_ptr<VarType>(new VarType(tracer->GenerateUniqueName()));
paddle::imperative::NameVarMap<VarType> outs = {{"Out", {out}},
{"XShape", {x_shape}}};
paddle::framework::AttributeMap attrs = {{"axis", axis}};
tracer->TraceOp("transpose2", ins, outs, std::move(attrs));
paddle::imperative::SetDataLayout(out, layout);
VLOG(4) << "Transpose " << paddle::imperative::GetNameFromVar(var) << "["
<< paddle::framework::DataLayoutToString(
paddle::imperative::GetDataLayout(var))
<< "]"
<< " to " << paddle::imperative::GetNameFromVar(out) << "["
<< paddle::framework::DataLayoutToString(
paddle::imperative::GetDataLayout(out))
<< "]";
return out;
}
/*
 * Base transformer, used directly for layout agnostic operators.
 * Derived transformers override Apply() to adjust inputs and attributes.
 */
template <typename VarType>
class LayoutTransformer {
 public:
  explicit LayoutTransformer(const std::string& type) : type_(type) {}

  virtual ~LayoutTransformer() = default;

  LayoutTransformer(const LayoutTransformer&) = delete;
  LayoutTransformer& operator=(const LayoutTransformer&) = delete;

  // Propagates the layout from inputs to outputs without touching the inputs:
  // if any input already carries the desired layout, the outputs are tagged
  // with it; otherwise they are tagged UNDEFINED. Returns `ins` unchanged.
  virtual paddle::imperative::NameVarMap<VarType> Apply(
      const paddle::imperative::NameVarMap<VarType>& ins,
      const paddle::imperative::NameVarMap<VarType>& outs,
      paddle::framework::AttributeMap* attrs,
      const std::shared_ptr<paddle::imperative::Tracer>& tracer) {
    VLOG(3) << "Optimze Layout agnostic op: " << type_;
    const auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout();
    // One input in the desired layout is enough, so stop at the first hit.
    bool any_input_desired = false;
    for (const auto& pair : ins) {
      for (const auto& var : pair.second) {
        if (paddle::imperative::GetDataLayout(var) == desired_layout) {
          any_input_desired = true;
          break;
        }
      }
      if (any_input_desired) {
        break;
      }
    }
    SetVarsLayout(outs,
                  any_input_desired ? desired_layout : DataLayout::UNDEFINED);
    return ins;
  }

  // Set inputs, outputs and attributes to be optimized for the transposer.
  // Those may respectively be a subset of the corresponding original argument
  // of the operator.
  void SetArguments(const std::vector<std::string>& ins,
                    const std::vector<std::string>& outs,
                    const std::vector<std::string>& attrs) {
    ins_ = ins;
    outs_ = outs;
    attrs_ = attrs;
  }

  // Set the variables' layout to the specified layout.
  // If outs_ is not specified, all outputs of the operator are tagged.
  // Otherwise only the output slots named in outs_ are tagged; a name that is
  // missing from `outs` makes outs.at() throw std::out_of_range.
  void SetVarsLayout(const paddle::imperative::NameVarMap<VarType>& outs,
                     DataLayout layout) const {
    if (outs_.empty()) {
      for (const auto& pair : outs) {
        for (const auto& var : pair.second) {
          paddle::imperative::SetDataLayout(var, layout);
        }
      }
      return;
    }
    for (const auto& name : outs_) {
      // Bind by reference: the slot's vector does not need to be copied.
      const auto& out_vars = outs.at(name);
      for (const auto& var : out_vars) {
        paddle::imperative::SetDataLayout(var, layout);
      }
    }
  }

  const std::vector<std::string>& Inputs() const { return ins_; }
  const std::vector<std::string>& Outputs() const { return outs_; }
  const std::vector<std::string>& Attributes() const { return attrs_; }

  // Operator type this transformer was created for.
  const std::string& Type() const { return type_; }

 protected:
  std::string type_{};
  std::vector<std::string> ins_{};
  std::vector<std::string> outs_{};
  std::vector<std::string> attrs_{};
};
/*
 * Transformer for elementwise_* operators, which carry an "axis" attribute
 * used for broadcasting.
 */
template <typename VarType>
class ElementwiseOpTransformer : public LayoutTransformer<VarType> {
 public:
  explicit ElementwiseOpTransformer(const std::string& type)
      : LayoutTransformer<VarType>(type) {}

  // Rewrites the "axis" attribute to the channel dimension of the desired
  // layout when the attribute is present and not -1; otherwise behaves like
  // LayoutTransformer::Apply. The inputs are always returned unchanged.
  paddle::imperative::NameVarMap<VarType> Apply(
      const paddle::imperative::NameVarMap<VarType>& ins,
      const paddle::imperative::NameVarMap<VarType>& outs,
      paddle::framework::AttributeMap* attrs,
      const std::shared_ptr<paddle::imperative::Tracer>& tracer) override {
    // [Why is this needed?]
    // The elementwise ops have an axis attr to support broadcast.
    // When bias_attr of Conv is not false, an elementwise_add is appended
    // and its axis is set to the channel dimension. In that case the attr
    // must follow the layout change. Otherwise this falls back to
    // LayoutTransformer::Apply.
    // NOTE(review): any axis other than -1 is treated as "the channel
    // dimension", without looking at its value or at the input layouts —
    // confirm this holds for all callers.
    auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout();
    if (attrs->find("axis") != attrs->end() &&
        BOOST_GET_CONST(int, (*attrs)["axis"]) != -1) {
      VLOG(3) << "Optimze layout agnostic op " << this->Type();
      if (desired_layout == DataLayout::NHWC) {
        (*attrs)["axis"] = 3;
      } else if (desired_layout == DataLayout::NCHW) {
        (*attrs)["axis"] = 1;
      } else {
        // Only NHWC, NCHW and UNDEFINED are accepted here.
        PADDLE_ENFORCE_EQ(
            desired_layout, DataLayout::UNDEFINED,
            phi::errors::PreconditionNotMet("DataLayout is unsupport."));
      }
      this->SetVarsLayout(outs, desired_layout);
      return ins;
    } else {
      return LayoutTransformer<VarType>::Apply(ins, outs, attrs, tracer);
    }
  }
};
/*
 * Both functionality and performance are affected by data layout.
 * Such as operators with data_format attribute.
 */
template <typename VarType>
class HeavilyLayoutSensitiveOpTransformer : public LayoutTransformer<VarType> {
 public:
  explicit HeavilyLayoutSensitiveOpTransformer(const std::string& type)
      : LayoutTransformer<VarType>(type) {}

  // Switches the op's layout attribute to the desired layout, transposes the
  // registered inputs that are not yet in that layout, tags the outputs, and
  // returns the adjusted copy of the input map.
  paddle::imperative::NameVarMap<VarType> Apply(
      const paddle::imperative::NameVarMap<VarType>& ins,
      const paddle::imperative::NameVarMap<VarType>& outs,
      paddle::framework::AttributeMap* attrs,
      const std::shared_ptr<paddle::imperative::Tracer>& tracer) override {
    VLOG(3) << "Optimze heavily layout sensitive op " << this->Type();
    paddle::imperative::NameVarMap<VarType> new_ins(ins);

    // Step 1: Adjust the data_layout attr to the desired layout.
    const auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout();
    const std::string desired_layout_str =
        paddle::framework::DataLayoutToString(desired_layout);
    if (attrs->find("data_format") != attrs->end() &&
        BOOST_GET_CONST(std::string, (*attrs)["data_format"]) !=
            desired_layout_str) {
      VLOG(4) << "Origin layout attr: "
              << BOOST_GET_CONST(std::string, (*attrs)["data_format"])
              << ", Desired layout attr: " << desired_layout_str;
      (*attrs)["data_format"] = desired_layout_str;
    } else if (attrs->find("data_layout") != attrs->end() &&
               BOOST_GET_CONST(std::string, (*attrs)["data_layout"]) !=
                   desired_layout_str) {
      VLOG(4) << "Origin layout attr: "
              << BOOST_GET_CONST(std::string, (*attrs)["data_layout"])
              << ", Desired layout attr: " << desired_layout_str;
      (*attrs)["data_layout"] = desired_layout_str;
    }

    // Step 2: Transpose the specified inputs of the op and tag the transposed
    // vars. The target is the desired layout, so it always agrees with the
    // attribute written in step 1 and the output tag written in step 3.
    for (auto& name : this->Inputs()) {
      auto& in_vars = new_ins[name];
      for (auto& var : in_vars) {
        auto var_layout = paddle::imperative::GetDataLayout(var);
        if (var_layout != desired_layout) {
          var = TraceTransposeOp(var, desired_layout, tracer);
        }
      }
    }

    // Step 3: Set the layout of the op's layout sensitive outputs.
    this->SetVarsLayout(outs, desired_layout);
    return new_ins;
  }
};
/*
 * The functionality may be affected by layout transformations applied before
 * these operators. Such as operators with an axis attribute.
 */
template <typename VarType>
class LightlyLayoutSensitiveOpTransformer : public LayoutTransformer<VarType> {
 public:
  explicit LightlyLayoutSensitiveOpTransformer(const std::string& type)
      : LayoutTransformer<VarType>(type) {}

  // Returns a copy of `ins` in which every input that carries the desired
  // layout has been replaced by the output of a traced transpose.
  paddle::imperative::NameVarMap<VarType> Apply(
      const paddle::imperative::NameVarMap<VarType>& ins,
      const paddle::imperative::NameVarMap<VarType>& outs,
      paddle::framework::AttributeMap* attrs,
      const std::shared_ptr<paddle::imperative::Tracer>& tracer) override {
    VLOG(3) << "Optimze lightly layout sensitive op " << this->Type();
    paddle::imperative::NameVarMap<VarType> new_ins(ins);
    // If input's layout is not tuned, transformation is unnecessary.
    // If input's layout is already tuned, it will be transformed back to NCHW.
    // TODO(zhangting): The op of this type should be adapted to the previous
    // operator output data layout. Currently only a few operators are
    // supported, and transposers need to be carefully designed to ensure that
    // they do not cause exceptions.
    const auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout();
    for (auto& pair : new_ins) {
      for (auto& var : pair.second) {
        auto var_layout = paddle::imperative::GetDataLayout(var);
        if (var_layout == desired_layout) {
          // Set layout to UNDEFINED so that TransposeOpTransformer does the
          // NHWC->NCHW transformation.
          var = TraceTransposeOp(var, DataLayout::UNDEFINED, tracer);
        }
      }
    }
    return new_ins;
  }
};
/*
 * Transformer for transpose2: instead of inserting an extra transpose it
 * folds the NHWC->NCHW permutation into the op's own "axis" attribute.
 */
template <typename VarType>
class TransposeOpTransformer
    : public LightlyLayoutSensitiveOpTransformer<VarType> {
 public:
  explicit TransposeOpTransformer(const std::string& type)
      : LightlyLayoutSensitiveOpTransformer<VarType>(type) {}

  // Returns `ins` unchanged; may rewrite the "axis" attribute in `attrs`.
  paddle::imperative::NameVarMap<VarType> Apply(
      const paddle::imperative::NameVarMap<VarType>& ins,
      const paddle::imperative::NameVarMap<VarType>& outs,
      paddle::framework::AttributeMap* attrs,
      const std::shared_ptr<paddle::imperative::Tracer>& tracer) override {
    VLOG(3) << "Optimze lightly layout sensitive op " << this->Type();
    // When the input layout is the desired format, it means that there
    // is a transpose layer in the network, it is better to transpose
    // the result to the original format.
    // Instead of actually inserting a transpose Op, we fuse the inserted
    // transpose Op with the current transpose Op by transforming 'axis' attr.
    const auto& in_var = ins.at("X")[0];
    const auto var_layout = paddle::imperative::GetDataLayout(in_var);
    if (var_layout == LayoutAutoTune::Instance().GetDesiredLayout()) {
      const auto axis = BOOST_GET_CONST(std::vector<int>, (*attrs)["axis"]);
      // NHWC->NCHW, permutation will be set as follows.
      const std::vector<int> perm = {0, 3, 1, 2};
      const int rank = static_cast<int>(perm.size());

      // Fusion indexes perm with each axis entry, so it is only well defined
      // when axis has exactly `rank` entries, each inside [0, rank).
      bool fusible = axis.size() == perm.size();
      for (size_t i = 0; fusible && i < axis.size(); ++i) {
        fusible = axis[i] >= 0 && axis[i] < rank;
      }

      if (fusible) {
        // Fuse the transpose ops by composing the two permutations.
        std::vector<int> fusion_axis;
        fusion_axis.reserve(axis.size());
        for (const int a : axis) {
          fusion_axis.push_back(perm[a]);
        }
        (*attrs)["axis"] = fusion_axis;
      } else {
        // Leave the attribute untouched rather than read out of bounds.
        VLOG(4) << "Skip fusing transpose for " << this->Type()
                << ": axis attr is not a permutation of 4 dimensions.";
      }
    }
    return ins;
  }
};
/*
 * Transformer for flatten_contiguous_range.
 */
template <typename VarType>
class FlattenOpTransformer
    : public LightlyLayoutSensitiveOpTransformer<VarType> {
 public:
  explicit FlattenOpTransformer(const std::string& type)
      : LightlyLayoutSensitiveOpTransformer<VarType>(type) {}

  // Returns `ins` unchanged when input "X" is in the desired layout and the
  // flatten covers exactly axes 1..3; otherwise defers to
  // LightlyLayoutSensitiveOpTransformer::Apply.
  paddle::imperative::NameVarMap<VarType> Apply(
      const paddle::imperative::NameVarMap<VarType>& ins,
      const paddle::imperative::NameVarMap<VarType>& outs,
      paddle::framework::AttributeMap* attrs,
      const std::shared_ptr<paddle::imperative::Tracer>& tracer) override {
    VLOG(3) << "Optimze lightly layout sensitive op " << this->Type();
    // Flatten the C, H, W dimensions will not affect functionality.
    // So transformation is unnecessary. But in other cases, it needs to
    // fall back to the LightlyLayoutSensitiveOpTransformer.
    const auto start_axis = BOOST_GET_CONST(int, (*attrs)["start_axis"]);
    const auto stop_axis = BOOST_GET_CONST(int, (*attrs)["stop_axis"]);
    const bool input_is_tuned =
        paddle::imperative::GetDataLayout(ins.at("X")[0]) ==
        LayoutAutoTune::Instance().GetDesiredLayout();
    if (input_is_tuned && start_axis == 1 && stop_axis == 3) {
      return ins;
    }
    return LightlyLayoutSensitiveOpTransformer<VarType>::Apply(ins, outs, attrs,
                                                               tracer);
  }
};
} // namespace imperative
} // namespace paddle
......@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/platform/denormal.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
......@@ -222,16 +223,22 @@ void Tracer::TraceOpImpl(const std::string& type,
NameVarMap<VarType> new_ins = ins;
if (amp_level_ == AmpLevel::O1) {
if (amp_dtype_ == phi::DataType::FLOAT16) {
const auto& tracer = imperative::GetCurrentTracer();
new_ins =
imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastInputs<VarType>(type, ins);
new_ins = AutoCastInputs<VarType>(type, new_ins);
} else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
new_ins = AutoCastBF16Inputs<VarType>(type, ins);
}
} else if (amp_level_ == AmpLevel::O2) {
if (amp_dtype_ == phi::DataType::FLOAT16) {
const auto& tracer = imperative::GetCurrentTracer();
new_ins =
imperative::AutoTuneLayout<VarType>(type, ins, outs, &attrs, tracer);
VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureFp16Inputs<VarType>(type, ins);
new_ins = CastPureFp16Inputs<VarType>(type, new_ins);
} else if (amp_dtype_ == phi::DataType::BFLOAT16) {
VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
new_ins = CastPureBf16Inputs<VarType>(type, ins);
......
......@@ -190,6 +190,59 @@ template framework::proto::VarType::Type GetDataType<VarBase>(
template framework::proto::VarType::Type GetDataType<VariableWrapper>(
std::shared_ptr<VariableWrapper> var);
/* GetDataLayout */
// Generic version: returns the layout reported by the variable's own
// DataLayout() accessor.
template <typename VarType>
paddle::experimental::DataLayout GetDataLayout(std::shared_ptr<VarType> var) {
return var->DataLayout();
}
// EagerVariable version: reads the layout from the LoDTensor held by the
// variable. Any other held type is rejected with a PermissionDenied error.
template <>
paddle::experimental::DataLayout GetDataLayout<egr::EagerVariable>(
    std::shared_ptr<egr::EagerVariable> var) {
  const auto& holder = var->Var();
  if (!holder.IsType<framework::LoDTensor>()) {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Only support framework::LoDTensor, but got %s here, please checkout "
        "var type of "
        "tensor: %s",
        paddle::framework::ToTypeName(framework::ToVarType(var->Var().Type())),
        var->name()));
  }
  return holder.Get<framework::LoDTensor>().layout();
}
// Explicit instantiations of the generic GetDataLayout for VarBase and
// VariableWrapper.
template paddle::experimental::DataLayout GetDataLayout<VarBase>(
std::shared_ptr<VarBase> var);
template paddle::experimental::DataLayout GetDataLayout<VariableWrapper>(
std::shared_ptr<VariableWrapper> var);
/* SetDataLayout */
// Generic version: stores `layout` through the variable's own
// SetDataLayout() member.
template <typename VarType>
void SetDataLayout(std::shared_ptr<VarType> var,
const paddle::experimental::DataLayout layout) {
var->SetDataLayout(layout);
}
// EagerVariable version: writes the layout into the LoDTensor held by the
// variable. Any other held type is rejected with a PermissionDenied error.
template <>
void SetDataLayout<egr::EagerVariable>(
    std::shared_ptr<egr::EagerVariable> var,
    const paddle::experimental::DataLayout layout) {
  if (!var->Var().IsType<framework::LoDTensor>()) {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Only support framework::LoDTensor, but got %s here, please checkout "
        "var type of "
        "tensor: %s",
        paddle::framework::ToTypeName(framework::ToVarType(var->Var().Type())),
        var->name()));
  }
  auto* lod_tensor =
      var->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
  lod_tensor->set_layout(layout);
}
// Explicit instantiations of the generic SetDataLayout for VarBase and
// VariableWrapper.
template void SetDataLayout<VarBase>(
std::shared_ptr<VarBase> var,
const paddle::experimental::DataLayout layout);
template void SetDataLayout<VariableWrapper>(
std::shared_ptr<VariableWrapper> var,
const paddle::experimental::DataLayout layout);
/* CheckCachedKey */
template <typename VarType>
bool CheckCachedKey(std::shared_ptr<VarType> var,
......
......@@ -63,6 +63,13 @@ framework::proto::VarType::Type GetType(std::shared_ptr<VarType> var);
template <typename VarType>
framework::proto::VarType::Type GetDataType(std::shared_ptr<VarType> var);
// Returns the data layout recorded for `var`.
template <typename VarType>
paddle::experimental::DataLayout GetDataLayout(std::shared_ptr<VarType> var);
// Records `layout` as the data layout of `var`.
template <typename VarType>
void SetDataLayout(std::shared_ptr<VarType> var,
const paddle::experimental::DataLayout layout);
template <typename VarType>
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<VarType>& var);
......
......@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/phi/common/layout.h"
namespace paddle {
namespace imperative {
......@@ -186,6 +187,12 @@ class VariableWrapper {
return fwd_data_type_;
}
// Returns the layout recorded for layout autotune. Declared const because it
// only reads layout_, so it is also callable on a const VariableWrapper.
paddle::experimental::DataLayout DataLayout() const { return layout_; }
// Overwrites the recorded layout (layout_) with `layout`.
void SetDataLayout(const paddle::experimental::DataLayout layout) {
layout_ = layout;
}
const platform::Place Place() const {
const framework::Tensor* tensor = nullptr;
auto place =
......@@ -357,6 +364,10 @@ class VariableWrapper {
// training
// NOTE: Now no need to support remove void hook
std::vector<std::shared_ptr<std::function<void()>>> void_hooks_;
// Data layout tag used by layout autotune; starts out as UNDEFINED and is
// read/written through DataLayout() / SetDataLayout().
paddle::experimental::DataLayout layout_{
paddle::experimental::DataLayout::UNDEFINED};
};
} // namespace imperative
......
......@@ -167,6 +167,7 @@ limitations under the License. */
#endif
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/kernels/autotune/cache.h"
......@@ -4490,6 +4491,20 @@ All parameter, weight, gradient are variables in Paddle.
return res;
});
// Python binding: turns the layout autotune switch on.
m.def("enable_layout_autotune", [] {
return paddle::imperative::LayoutAutoTune::Instance()
.EnableLayoutAutoTune();
});
// Python binding: turns the layout autotune switch off.
m.def("disable_layout_autotune", [] {
return paddle::imperative::LayoutAutoTune::Instance()
.DisableLayoutAutoTune();
});
// Python binding: reports the result of LayoutAutoTune::UseLayoutAutoTune().
m.def("use_layout_autotune", [] {
return paddle::imperative::LayoutAutoTune::Instance().UseLayoutAutoTune();
});
BindFleetWrapper(&m);
BindIO(&m);
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy
import paddle.nn.functional as F
class SimpleNet(paddle.nn.Layer):
    """Tiny conv -> batch-norm -> ReLU -> avg-pool -> flatten -> linear net.

    ``data_format`` is accepted by the constructor but is not forwarded to
    any of the layers; every layer is built with its default arguments.
    """

    def __init__(self, data_format="NCHW", class_num=2):
        super().__init__()
        self.conv = paddle.nn.Conv2D(3, 8, (3, 3))
        self.bn = paddle.nn.BatchNorm(num_channels=8)
        self.relu = paddle.nn.ReLU()
        self.pool = paddle.nn.AvgPool2D(kernel_size=2, stride=2)
        self.flatten = paddle.nn.Flatten()
        self.fc = paddle.nn.Linear(392, class_num)

    def forward(self, image):
        """Return ``(conv_out, logits)`` for the given image batch."""
        conv_out = self.conv(image)
        activated = self.relu(self.bn(conv_out))
        flattened = self.flatten(self.pool(activated))
        logits = self.fc(flattened)
        return conv_out, logits
class LayoutAutoTune(unittest.TestCase):
    """End-to-end checks of layout autotune under AMP level O2."""

    def use_autoune(self):
        """Enable autotune on CUDA builds, disable it otherwise.

        Returns whatever ``use_layout_autotune()`` reports afterwards.
        """
        core = paddle.fluid.core
        if paddle.is_compiled_with_cuda():
            core.enable_layout_autotune()
        else:
            core.disable_layout_autotune()
        return core.use_layout_autotune()

    def train(self, data_format):
        """Run two AMP training steps of SimpleNet and return its last outputs."""
        model = SimpleNet(data_format="NCHW", class_num=2)
        data = paddle.rand([1, 3, 16, 16])
        if data_format == "NHWC":
            data = paddle.rand([1, 16, 16, 3])
        label_data = paddle.randint(0, 1, shape=[1, 1], dtype="int64")
        optimizer = paddle.optimizer.SGD(
            learning_rate=0.0001, parameters=model.parameters())
        scaler = paddle.amp.GradScaler()
        for _ in range(2):
            with paddle.amp.auto_cast(level="O2"):
                conv_out, predict = model(data)
                loss = F.cross_entropy(predict, label=label_data)
                loss = loss.mean()

            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.minimize(optimizer, scaled)
        return conv_out, predict

    def test_enable_autotune(self):
        # With autotune in effect the conv output is expected in NHWC order,
        # otherwise in the original NCHW order.
        autotune_on = self.use_autoune()
        conv_out, predict = self.train(data_format="NCHW")
        expected_conv_shape = [1, 14, 14, 8] if autotune_on else [1, 8, 14, 14]
        self.assertEqual(conv_out.shape, expected_conv_shape)
        self.assertEqual(predict.shape, [1, 2])

    def test_transpose_op_transposer(self):
        if not self.use_autoune():
            return
        conv = paddle.nn.Conv2D(3, 8, (3, 3))
        data = paddle.rand([1, 3, 16, 14])
        label_data = paddle.randint(0, 1, shape=[1, 1], dtype="int64")
        optimizer = paddle.optimizer.SGD(
            learning_rate=0.0001, parameters=conv.parameters())
        scaler = paddle.amp.GradScaler()
        with paddle.amp.auto_cast(level="O2"):
            conv_out = conv(data)
            # conv_out.shape = [1, 14, 12, 8] with NHWC
            # layout tuner will transpose conv_out to
            # [1, 8, 14, 12] with NCHW before the following transpose op.
            out = paddle.transpose(conv_out, perm=[0, 3, 1, 2])
            loss = out.mean()
            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.minimize(optimizer, scaled)

        self.assertEqual(conv_out.shape, [1, 14, 12, 8])
        self.assertEqual(out.shape, [1, 12, 8, 14])

    def test_flatten_op_transposer(self):
        if not self.use_autoune():
            return
        paddle.fluid.core.enable_layout_autotune()
        conv = paddle.nn.Conv2D(3, 8, (3, 3))
        flatten = paddle.nn.Flatten(start_axis=1, stop_axis=2)
        data = paddle.rand([1, 3, 16, 14])
        with paddle.amp.auto_cast(level="O2"):
            conv_out = conv(data)
            # conv_out.shape = [1, 14, 12, 8] with NHWC
            # layout tuner will transpose conv_out to
            # [1, 8, 14, 12] with NCHW before the following flatten op
            # because it flatten the C and H dimensions.
            out = flatten(conv_out)

        self.assertEqual(conv_out.shape, [1, 14, 12, 8])
        self.assertEqual(out.shape, [1, 112, 12])
# Run the test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册