Unverified commit 3cd3bf29 authored by wuhuanzhou, committed by GitHub

Add fuse_resnet_unit pass (#36818)

* GeneratePass support attr condition and mapping, test=develop

* fix coverage, test=develop

* Add fuse_resnet_unit pass, test=develop

* fix CI errors, test=develop

* fix CI errors, test=develop

* fix unittest error when compiling without CUDA, test=develop

* fix static ci error, test=develop

* limit kernel size must equal 1, test=develop
Parent d8191d06
......@@ -19,6 +19,32 @@ namespace paddle {
namespace framework {
namespace ir {
class element_visitor : public boost::static_visitor<Attribute> {
public:
explicit element_visitor(int index) : index_(index) {}
template <typename T>
Attribute operator()(const T& attr) const {
PADDLE_THROW(platform::errors::Unimplemented("Unimplemented operand."));
}
template <typename T>
Attribute operator()(const std::vector<T>& attr) const {
using ET = std::conditional_t<std::is_same<T, double>::value, float, T>;
int index = index_;
if (index < 0) {
index += attr.size();
}
if (index >= 0 && static_cast<size_t>(index) < attr.size()) {
return static_cast<ET>(attr[index]);
}
return boost::blank();
}
private:
int index_;
};
class operation_visitor : public boost::static_visitor<Attribute> {
public:
explicit operation_visitor(const proto::PassDesc::OperationType& type)
......@@ -29,15 +55,17 @@ class operation_visitor : public boost::static_visitor<Attribute> {
PADDLE_THROW(platform::errors::Unimplemented("Unimplemented operand."));
}
template <typename T,
std::enable_if_t<std::is_integral<T>::value ||
std::is_floating_point<T>::value>* = nullptr>
template <typename T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
Attribute operator()(const T& attr, const T& operation) const {
switch (type_) {
case proto::PassDesc_OperationType_kSub: {
return attr - operation;
}
case proto::PassDesc_OperationType_kMod: {
return attr % operation;
}
default:
PADDLE_THROW(
platform::errors::Unimplemented("Unimplemented operation type."));
......@@ -72,6 +100,15 @@ Attribute GetVarAttrValue(const VarDesc* desc,
return boost::blank();
}
Attribute GetOpAttrValue(const OpDesc* desc,
const proto::PassDesc::Attr& attr) {
Attribute value = desc->GetAttr(attr.name());
if (attr.has_element_index()) {
value = boost::apply_visitor(element_visitor(attr.element_index()), value);
}
return value;
}
void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
// Traverse all operators to create subgraph.
for (int index = 0; index < pass_desc.pattern_size(); ++index) {
......@@ -163,6 +200,11 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
PDNode* pdnode = pattern->RetrieveNode(condition.attr().var_name());
pdnode->assert_more([&](Node* x) {
Attribute attr = GetVarAttrValue(x->Var(), condition.attr());
if (condition.has_operation()) {
Attribute operation = GetAttrValue(condition.operation().value());
attr = boost::apply_visitor(
operation_visitor(condition.operation().type()), attr, operation);
}
switch (condition.type()) {
case proto::PassDesc_ConditionType_kEQ: {
return attr == GetAttrValue(condition.condition_value());
......@@ -305,7 +347,12 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
node = graph->CreateVarNode(&var_desc);
var_node_maps.insert({argument, node});
} else {
node = iter->second;
if (in_nodes.end() ==
std::find(in_nodes.begin(), in_nodes.end(), iter->second)) {
node = iter->second;
} else {
node = graph->CreateVarNode(iter->second->Var());
}
}
out_nodes.push_back(node);
arguments.push_back(node->Name());
......@@ -329,7 +376,7 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
Node* condition_node = subgraph.at(pattern.RetrieveNode(
std::to_string(attr_map.pattern_attr().op_index())));
attr =
condition_node->Op()->GetAttr(attr_map.pattern_attr().name());
GetOpAttrValue(condition_node->Op(), attr_map.pattern_attr());
}
if (attr_map.has_operation()) {
Attribute operation = GetAttrValue(attr_map.operation().value());
......
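
The element_visitor and operation_visitor extensions above let a pattern condition or attribute mapping read one entry of a list attribute (with Python-style negative indexing) and post-process it, for example with the new kMod operation. A minimal Python sketch of that lookup, with hypothetical names; out-of-range indices resolve to "no value" (boost::blank() in the C++ code):

def get_element(attr_list, index):
    # mirrors element_visitor: negative indices count from the end of the list
    if index < 0:
        index += len(attr_list)
    if 0 <= index < len(attr_list):
        return attr_list[index]
    return None  # boost::blank() in the C++ implementation

# conv2d "strides" is a list; element 0 can feed a scalar attribute of the fused op
assert get_element([2, 2], 0) == 2
assert get_element([2, 2], -1) == 2
# kMod followed by kEQ, as in filter.Attr("shape")[0].Mod(32).EQ(0)
assert get_element([64, 8, 1, 1], 0) % 32 == 0
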
......@@ -26,6 +26,7 @@ message PassDesc {
kMul = 2;
kDiv = 3;
kSize = 4;
kMod = 5;
}
enum ConditionType {
kEQ = 0;
......
......@@ -109,7 +109,12 @@ class ResNetUnitOp : public framework::OperatorWithKernel {
// Check dims of inputs
const auto x_dims = ctx->GetInputDim("X");
const auto w_dims = ctx->GetInputDim("FilterX");
const auto bn_param_dims = ctx->GetInputDim("ScaleX");
std::vector<int64_t> bn_param_shape =
framework::vectorize(ctx->GetInputDim("ScaleX"));
if (1 == bn_param_shape.size()) {
bn_param_shape = {1, 1, 1, bn_param_shape[0]};
}
framework::DDim bn_param_dims = framework::make_ddim(bn_param_shape);
PADDLE_ENFORCE_EQ(x_dims.size(), 4, platform::errors::InvalidArgument(
"The dimensions of input "
"must equal to 4."
......
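
The InferShape change above accepts a 1-D batch-norm scale of length C and reinterprets it as the broadcastable NHWC shape [1, 1, 1, C]. A small numpy illustration (sizes are hypothetical):

import numpy as np

scale = np.ones([32], dtype=np.float32)   # ScaleX as saved by batch_norm: shape [C]
bn_param = scale.reshape([1, 1, 1, -1])   # shape the fused resnet_unit op checks against
assert bn_param.shape == (1, 1, 1, 32)
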
......@@ -68,8 +68,16 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
auto input_x_shape = framework::vectorize<int>(input_x->dims());
auto filter_x_shape = framework::vectorize<int>(filter_x->dims());
// std::swap is used to convert the filter shape from the conv2d layout when the kernel size is 1.
if (filter_x_shape[1] != filter_x_shape[2] && 1 == filter_x_shape[2]) {
std::swap(filter_x_shape[1], filter_x_shape[3]);
}
auto param_dims = scale_x->dims();
auto param_shape = framework::vectorize<int>(scale_x->dims());
if (1 == param_shape.size()) {
param_shape = {1, 1, 1, param_shape[0]};
}
auto output_shape = framework::vectorize<int>(output->dims());
auto bitmask_shape = framework::vectorize<int>(bitmask->dims());
int output_channel = filter_x_shape[0];
......
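
The std::swap above changes only the reported shape: with a 1x1 kernel, a conv2d filter stored as [O, C, 1, 1] occupies memory identically to [O, 1, 1, C], so no data movement is needed, which is also why the pass restricts the kernel size to 1. A numpy check of that equivalence (hypothetical sizes):

import numpy as np

o, c = 32, 8
w = np.random.randn(o, c, 1, 1).astype(np.float32)  # conv2d filter: [O, C, kH, kW]
# Reinterpreting the same buffer as [O, 1, 1, C] matches a real transpose only
# because kH == kW == 1.
assert np.array_equal(w.reshape(o, 1, 1, c), w.transpose(0, 2, 3, 1))
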
......@@ -2508,13 +2508,13 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
m.def("register_pass", [](const std::string &pass_type,
const py::object &callable) {
m.def("register_pass", [](const std::string &pass_type, py::object callable) {
PADDLE_ENFORCE_EQ(
framework::ir::PassRegistry::Instance().Has(pass_type), false,
platform::errors::AlreadyExists(
"Pass '%s' is registered more than once. Please use another name.",
pass_type));
callable.inc_ref();
framework::ir::PassRegistry::Instance().Insert(pass_type, [pass_type,
callable]() {
py::gil_scoped_acquire guard;
......
......@@ -201,6 +201,12 @@ class RegisterPassHelper(object):
return vars, program.current_block().ops
def _convert_vars_to_pass_desc(self, patterns, replaces, desc):
def _add_element_conditions(conditions, elements):
for element in elements:
if element._condition:
conditions.append(element._condition)
_add_element_conditions(conditions, element._elements)
for (pattern, replace) in zip(patterns, replaces):
# Convert maps of inputs and outputs.
var_map = desc.var_maps.add()
......@@ -218,10 +224,7 @@ class RegisterPassHelper(object):
# Convert attr conditions.
if PassDesc.VarHelper == pattern.__class__:
for attr in pattern._attrs.values():
if attr._condition is not None:
conditions.append(attr._condition)
conditions.extend(
[e._condition for e in attr._elements if e._condition])
_add_element_conditions(conditions, [attr])
def _convert_ops_to_pass_desc(self, patterns, replaces, desc):
for replace in replaces:
......@@ -324,6 +327,10 @@ class PassDesc(object):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kAdd, value)
def Mod(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kMod, value)
def Size(self):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kSize)
......@@ -336,13 +343,20 @@ class PassDesc(object):
value._to_pass_desc_attr(condition.condition_attr)
else:
self._to_op_desc_attr(value, condition.condition_value)
if self._operation:
condition.operation.CopyFrom(self._operation)
self._condition = condition
def EQ(self, value):
self._set_with_condition(pass_desc_pb2.PassDesc.ConditionType.kEQ,
value)
def MappedPattern(self, var=None, op=None, index=0, name=None):
def MappedPattern(self,
var=None,
op=None,
index=0,
name=None,
element_index=None):
if all([var, op]):
raise ValueError("Only mapped one of which var or op.")
......@@ -356,7 +370,8 @@ class PassDesc(object):
raise ValueError(
"Index '{}' of operator '{}' is incorrect.".format(
index, op))
return PassDesc.AttrHelper(ops[index], name)
return PassDesc.AttrHelper(
ops[index], name, element_index=element_index)
self._mapped = mapped_op if var is None else mapped_var
......@@ -460,6 +475,13 @@ class PassDesc(object):
def Outputs(self):
return self._outputs
def SetOutputs(self, **kwargs):
for param, arg in kwargs.items():
if arg is None:
self._desc.remove_output(param)
else:
self._desc.set_output(param, [arg.name])
OP = OpHelper()
......
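
Taken together, element_index on MappedPattern, the recursive _add_element_conditions, and SetOutputs let a registered pass read single entries of list attributes, attach conditions to those entries, and drop outputs it does not produce. The lines below are not standalone code; they are the idioms used by fuse_resnet_unit_pass.py later in this commit, where filter and resnet_unit stand for pattern/replace variables inside a registered pass:

filter.Attr("shape")[0].Mod(32).EQ(0)              # condition on one element of a list attribute
resnet_unit.Attr("stride").MappedPattern(
    op="conv2d", name="strides", element_index=0)  # scalar attr taken from element 0 of "strides"
resnet_unit.SetOutputs(RunningMeanZ=None, RunningVarZ=None)  # passing None removes the output
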
......@@ -10,3 +10,5 @@ foreach(target ${TEST_IR_PASSES})
endforeach()
add_subdirectory(inference)
set_tests_properties(test_fuse_resnet_unit PROPERTIES TIMEOUT 120)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.incubate
from paddle.fluid import core
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
paddle.enable_static()
np.random.seed(0)
@unittest.skipIf(not paddle.is_compiled_with_cuda() or
paddle.get_cudnn_version() < 8000,
"only support with cuda and cudnn version is at least 8.0.")
class TestFuseResNetUnit(unittest.TestCase):
def test_fuse_resnet_unit(self):
place = paddle.CUDAPlace(0)
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.amp.fp16_guard():
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [1, 64, 64, 8])
conv2d = paddle.nn.Conv2D(
8, 32, 1, bias_attr=False, data_format='NHWC')
batch_norm = paddle.nn.BatchNorm(
32, act='relu', data_layout='NHWC')
out = batch_norm(conv2d(x))
graph = core.Graph(program.desc)
core.get_pass("fuse_resnet_unit").apply(graph)
after_program = paddle.fluid.framework.IrGraph(graph).to_program()
params = paddle.static.amp.cast_model_to_fp16(program)
after_params = paddle.static.amp.cast_model_to_fp16(after_program)
exe = paddle.static.Executor(place)
exe.run(startup_program)
paddle.static.amp.cast_parameters_to_fp16(
place, program, to_fp16_var_names=params)
paddle.static.amp.cast_parameters_to_fp16(
place, after_program, to_fp16_var_names=after_params)
feed = {"x": np.random.randn(1, 64, 64, 8).astype("float16")}
before_out = exe.run(program, feed=feed, fetch_list=[out.name])
after_out = exe.run(after_program, feed=feed, fetch_list=[out.name])
self.assertTrue(np.allclose(before_out[0], after_out[0], atol=5e-3))
......@@ -23,6 +23,7 @@ from .tensor import segment_sum
from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min
from .passes import fuse_resnet_unit_pass
from . import nn #noqa: F401
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid.ir as ir
def set_resnet_unit_attrs(resnet_unit, has_shortcut):
resnet_unit.SetAttr("fuse_add", False)
resnet_unit.SetAttr("act_type", "relu")
resnet_unit.SetAttr("has_shortcut", has_shortcut)
resnet_unit.SetAttr("data_format", 'NHWC')
resnet_unit.SetAttr("dilation", 1)
resnet_unit.Attr("stride").MappedPattern(
op="conv2d", name="strides", element_index=0)
resnet_unit.Attr("padding").MappedPattern(
op="conv2d", name="paddings", element_index=0)
resnet_unit.Attr("group").MappedPattern(op="conv2d", name="groups")
resnet_unit.Attr("op_device").MappedPattern(op="conv2d", name="op_device")
resnet_unit.Attr("op_namescope").MappedPattern(
op="conv2d", name="op_namescope")
resnet_unit.Attr("momentum").MappedPattern(op="batch_norm", name="momentum")
resnet_unit.Attr("epsilon").MappedPattern(op="batch_norm", name="epsilon")
resnet_unit.Attr("use_global_stats").MappedPattern(
op="batch_norm", name="use_global_stats")
def set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ=None, varZ=None):
resnet_unit.SetOutputs(
RunningMeanX=meanX,
RunningVarX=varX,
RunningMeanZ=meanZ,
RunningVarZ=varZ)
@ir.RegisterPass
def fuse_resnet_unit():
def pattern_conv_bn(x, filter, scale, bias, mean, var):
filter.Attr("shape")[0].Mod(32).EQ(0)
filter.Attr("shape")[1].Mod(8).EQ(0)
filter.Attr("shape")[2].EQ(1)
filter.Attr("shape")[3].EQ(1)
conv2d = ir.PassDesc.OP.conv2d(Input=x, Filter=filter)
conv2d.SetAttr("data_format", 'NHWC')
bn = ir.PassDesc.OP.batch_norm(
X=conv2d, Bias=bias, Mean=mean, Scale=scale, Variance=var)
return bn
def pattern_one_input(x, filter, scale, bias, mean, var):
bn = pattern_conv_bn(x, filter, scale, bias, mean, var)
relu = ir.PassDesc.OP.relu(X=bn.Output("Y"))
return relu
def replace_one_input(x, filter, scale, bias, mean, var):
resnet_unit = ir.PassDesc.OP.resnet_unit(
X=x, FilterX=filter, ScaleX=scale, BiasX=bias, MeanX=mean, VarX=var)
set_resnet_unit_attrs(resnet_unit, False)
set_resnet_unit_outputs(resnet_unit, mean, var)
return resnet_unit.Output("Y")
def pattern_two_input(x, filterX, scaleX, biasX, meanX, varX, z, filterZ,
scaleZ, biasZ, meanZ, varZ):
bnX = pattern_conv_bn(x, filterX, scaleX, biasX, meanX, varX)
bnZ = pattern_conv_bn(x, filterZ, scaleZ, biasZ, meanZ, varZ)
ewadd = ir.PassDesc.OP.elementwise_add(
X=bnX.Output("Y"), Y=bnZ.Output("Y"))
relu = ir.PassDesc.OP.relu(X=ewadd)
return relu
def replace_two_input(x, filterX, scaleX, biasX, meanX, varX, z, filterZ,
scaleZ, biasZ, meanZ, varZ):
resnet_unit = ir.PassDesc.OP.resnet_unit(
X=x,
FilterX=filterX,
ScaleX=scaleX,
BiasX=biasX,
MeanX=meanX,
VarX=varX,
Z=z,
FilterZ=filterZ,
ScaleZ=scaleZ,
BiasZ=biasZ,
MeanZ=meanZ,
VarZ=varZ)
set_resnet_unit_attrs(resnet_unit, True)
set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ, varZ)
return resnet_unit.Output("Y")
return (pattern_one_input, replace_one_input), (pattern_two_input,
replace_two_input)
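
A sketch of how the registered pass might be applied to a static program, mirroring the unit test above (assumes a CUDA build with cuDNN >= 8.0):

import paddle
import paddle.incubate  # importing this registers fuse_resnet_unit
from paddle.fluid import core

paddle.enable_static()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
    x = paddle.static.data("x", [1, 64, 64, 8])
    conv2d = paddle.nn.Conv2D(8, 32, 1, bias_attr=False, data_format='NHWC')
    batch_norm = paddle.nn.BatchNorm(32, act='relu', data_layout='NHWC')
    out = batch_norm(conv2d(x))
graph = core.Graph(program.desc)
core.get_pass("fuse_resnet_unit").apply(graph)  # conv2d + batch_norm + relu -> resnet_unit
fused_program = paddle.fluid.framework.IrGraph(graph).to_program()
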
......@@ -275,6 +275,7 @@ packages=['paddle',
'paddle.incubate.operators',
'paddle.incubate.tensor',
'paddle.incubate.nn',
'paddle.incubate.passes',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic',
......