diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc
index 3f9ad5b2c5203e49916e8b7283e7ead1adb9d994..02c9d8e1c0c24aa5632f4ffd7aaeec42962e50dd 100644
--- a/paddle/fluid/framework/ir/generate_pass.cc
+++ b/paddle/fluid/framework/ir/generate_pass.cc
@@ -19,6 +19,32 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+class element_visitor : public boost::static_visitor<Attribute> {
+ public:
+  explicit element_visitor(int index) : index_(index) {}
+
+  template <typename T>
+  Attribute operator()(const T& attr) const {
+    PADDLE_THROW(platform::errors::Unimplemented("Unimplemented operand."));
+  }
+
+  template <typename T>
+  Attribute operator()(const std::vector<T>& attr) const {
+    using ET = std::conditional_t<std::is_same<T, double>::value, float, T>;
+    int index = index_;
+    if (index < 0) {
+      index += attr.size();
+    }
+    if (index >= 0 && static_cast<size_t>(index) < attr.size()) {
+      return static_cast<ET>(attr[index]);
+    }
+    return boost::blank();
+  }
+
+ private:
+  int index_;
+};
+
 class operation_visitor : public boost::static_visitor<Attribute> {
  public:
   explicit operation_visitor(const proto::PassDesc::OperationType& type)
@@ -29,15 +55,17 @@ class operation_visitor : public boost::static_visitor<Attribute> {
     PADDLE_THROW(platform::errors::Unimplemented("Unimplemented operand."));
   }
 
-  template <typename T, std::enable_if_t<std::is_integral<T>::value ||
-                            std::is_floating_point<T>::value>* = nullptr>
+  template <typename T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
   Attribute operator()(const T& attr, const T& operation) const {
     switch (type_) {
       case proto::PassDesc_OperationType_kSub: {
         return attr - operation;
       }
+      case proto::PassDesc_OperationType_kMod: {
+        return attr % operation;
+      }
       default:
         PADDLE_THROW(
             platform::errors::Unimplemented("Unimplemented operation type."));
     }
   }
@@ -72,6 +100,15 @@ Attribute GetVarAttrValue(const VarDesc* desc,
   return boost::blank();
 }
 
+Attribute GetOpAttrValue(const OpDesc* desc,
+                         const proto::PassDesc::Attr& attr) {
+  Attribute value = desc->GetAttr(attr.name());
+  if (attr.has_element_index()) {
+    value = boost::apply_visitor(element_visitor(attr.element_index()), value);
+  }
+  return value;
+}
+
 void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
   // Traverse all operators to create subgraph.
   for (int index = 0; index < pass_desc.pattern_size(); ++index) {
@@ -163,6 +200,11 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
     PDNode* pdnode = pattern->RetrieveNode(condition.attr().var_name());
     pdnode->assert_more([&](Node* x) {
       Attribute attr = GetVarAttrValue(x->Var(), condition.attr());
+      if (condition.has_operation()) {
+        Attribute operation = GetAttrValue(condition.operation().value());
+        attr = boost::apply_visitor(
+            operation_visitor(condition.operation().type()), attr, operation);
+      }
       switch (condition.type()) {
         case proto::PassDesc_ConditionType_kEQ: {
           return attr == GetAttrValue(condition.condition_value());
@@ -305,7 +347,12 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
           node = graph->CreateVarNode(&var_desc);
           var_node_maps.insert({argument, node});
         } else {
-          node = iter->second;
+          if (in_nodes.end() ==
+              std::find(in_nodes.begin(), in_nodes.end(), iter->second)) {
+            node = iter->second;
+          } else {
+            node = graph->CreateVarNode(iter->second->Var());
+          }
         }
         out_nodes.push_back(node);
         arguments.push_back(node->Name());
@@ -329,7 +376,7 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
           Node* condition_node = subgraph.at(pattern.RetrieveNode(
               std::to_string(attr_map.pattern_attr().op_index())));
           attr =
-              condition_node->Op()->GetAttr(attr_map.pattern_attr().name());
+              GetOpAttrValue(condition_node->Op(), attr_map.pattern_attr());
         }
         if (attr_map.has_operation()) {
           Attribute operation = GetAttrValue(attr_map.operation().value());
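
Taken together, `element_visitor`, the new `kMod` arm in `operation_visitor`, and `GetOpAttrValue` let a generated pass condition read a single element of a vector attribute (with Python-style negative indexing), optionally transform it with an operation, and then compare it. A minimal Python model of the chain `InitGeneratePattern` evaluates — illustrative names only, not Paddle API:

```python
# Models element_visitor -> operation_visitor(kMod) -> kEQ, the chain built
# for a condition such as filter.Attr("shape")[0].Mod(32).EQ(0).

def element(attr, index):
    """Mirror element_visitor: index into a vector attribute, Python-style."""
    if not isinstance(attr, (list, tuple)):
        raise NotImplementedError("Unimplemented operand.")
    if index < 0:
        index += len(attr)
    if 0 <= index < len(attr):
        return attr[index]
    return None  # stands in for boost::blank()

def holds_mod_eq(attr, index, modulus, expected):
    """element_index, then the kMod operation, then the kEQ comparison."""
    value = element(attr, index)
    return value is not None and value % modulus == expected

# A conv2d filter of shape [32, 8, 1, 1]: output channels divisible by 32.
assert holds_mod_eq([32, 8, 1, 1], 0, 32, 0)
```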
diff --git a/paddle/fluid/framework/pass_desc.proto b/paddle/fluid/framework/pass_desc.proto
index 86a1effb2896ef5ab4f552f5512444f886f7fac9..53b785490202248feab9ec1d63bc59c8349430a2 100644
--- a/paddle/fluid/framework/pass_desc.proto
+++ b/paddle/fluid/framework/pass_desc.proto
@@ -26,6 +26,7 @@ message PassDesc {
     kMul = 2;
     kDiv = 3;
     kSize = 4;
+    kMod = 5;
   }
   enum ConditionType {
     kEQ = 0;
diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cc b/paddle/fluid/operators/fused/resnet_unit_op.cc
index d2ac089d4d1d21cb7767292d41035d86b79cc4f6..79e813c52181c6bc9659dde49c4f4b49ddb246a8 100644
--- a/paddle/fluid/operators/fused/resnet_unit_op.cc
+++ b/paddle/fluid/operators/fused/resnet_unit_op.cc
@@ -109,7 +109,12 @@ class ResNetUnitOp : public framework::OperatorWithKernel {
     // Check dims of inputs
     const auto x_dims = ctx->GetInputDim("X");
     const auto w_dims = ctx->GetInputDim("FilterX");
-    const auto bn_param_dims = ctx->GetInputDim("ScaleX");
+    std::vector<int64_t> bn_param_shape =
+        framework::vectorize(ctx->GetInputDim("ScaleX"));
+    if (1 == bn_param_shape.size()) {
+      bn_param_shape = {1, 1, 1, bn_param_shape[0]};
+    }
+    framework::DDim bn_param_dims = framework::make_ddim(bn_param_shape);
     PADDLE_ENFORCE_EQ(x_dims.size(), 4, platform::errors::InvalidArgument(
                                             "The dimensions of input "
                                             "must equal to 4."
diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cu b/paddle/fluid/operators/fused/resnet_unit_op.cu
index b121864f80e4d970911d8bd68afa1966b9c08dc7..6084ee36f2ce26ad74fe1e7d87a14fb56d5b63b6 100644
--- a/paddle/fluid/operators/fused/resnet_unit_op.cu
+++ b/paddle/fluid/operators/fused/resnet_unit_op.cu
@@ -68,8 +68,16 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
     auto input_x_shape = framework::vectorize(input_x->dims());
     auto filter_x_shape = framework::vectorize(filter_x->dims());
+    // std::swap is used to convert the filter shape from conv2d's layout
+    // when the kernel size is 1.
+    if (filter_x_shape[1] != filter_x_shape[2] && 1 == filter_x_shape[2]) {
+      std::swap(filter_x_shape[1], filter_x_shape[3]);
+    }
     auto param_dims = scale_x->dims();
     auto param_shape = framework::vectorize(scale_x->dims());
+    if (1 == param_shape.size()) {
+      param_shape = {1, 1, 1, param_shape[0]};
+    }
     auto output_shape = framework::vectorize(output->dims());
     auto bitmask_shape = framework::vectorize(bitmask->dims());
     int output_channel = filter_x_shape[0];
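
The two operator changes normalize shapes so that a plain `conv2d` + `batch_norm` subgraph can feed `resnet_unit` directly: 1-D batch-norm parameters are broadcast to a 4-D NHWC-style shape, and 1x1 conv filters get their channel axes swapped. A hedged Python sketch of the same logic, assuming filters arrive in conv2d's `[out, in, kh, kw]` layout:

```python
# Illustrative re-statement of the C++ shape fixes above; not Paddle API.

def normalize_bn_param_shape(shape):
    # resnet_unit expects 4-D NHWC-style batch-norm params, e.g. [1, 1, 1, C].
    return [1, 1, 1, shape[0]] if len(shape) == 1 else list(shape)

def normalize_filter_shape(shape):
    # conv2d filters are [out, in, kh, kw]; for a 1x1 kernel, swap axes 1 and
    # 3 so the input-channel count lands where resnet_unit expects it.
    shape = list(shape)
    if shape[1] != shape[2] and shape[2] == 1:
        shape[1], shape[3] = shape[3], shape[1]
    return shape

assert normalize_bn_param_shape([32]) == [1, 1, 1, 32]
assert normalize_filter_shape([32, 8, 1, 1]) == [32, 1, 1, 8]
```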
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 3a20fa1d3aeb845b0e77baab50548d3cd6289ff0..bf943b7327e6d815bc90fd9dbf54a0657eb633f1 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -2508,13 +2508,13 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("disable_profiler", platform::DisableProfiler);
   m.def("is_profiler_enabled", platform::IsProfileEnabled);
   m.def("reset_profiler", platform::ResetProfiler);
-  m.def("register_pass", [](const std::string &pass_type,
-                            const py::object &callable) {
+  m.def("register_pass", [](const std::string &pass_type, py::object callable) {
     PADDLE_ENFORCE_EQ(
         framework::ir::PassRegistry::Instance().Has(pass_type), false,
         platform::errors::AlreadyExists(
             "Pass '%s' is registered more than once. Please use another name.",
             pass_type));
+    callable.inc_ref();
     framework::ir::PassRegistry::Instance().Insert(
         pass_type, [pass_type, callable]() {
           py::gil_scoped_acquire guard;
diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py
index adeab721fc2dd50535fbb238515ea8494a8e99b1..55297ed516ffb4f2e64abb44030b642785f03cbd 100644
--- a/python/paddle/fluid/ir.py
+++ b/python/paddle/fluid/ir.py
@@ -201,6 +201,12 @@ class RegisterPassHelper(object):
         return vars, program.current_block().ops
 
     def _convert_vars_to_pass_desc(self, patterns, replaces, desc):
+        def _add_element_conditions(conditions, elements):
+            for element in elements:
+                if element._condition:
+                    conditions.append(element._condition)
+                _add_element_conditions(conditions, element._elements)
+
         for (pattern, replace) in zip(patterns, replaces):
             # Convert maps of inputs and outputs.
             var_map = desc.var_maps.add()
@@ -218,10 +224,7 @@ class RegisterPassHelper(object):
             # Convert attr conditions.
             if PassDesc.VarHelper == pattern.__class__:
                 for attr in pattern._attrs.values():
-                    if attr._condition is not None:
-                        conditions.append(attr._condition)
-                    conditions.extend(
-                        [e._condition for e in attr._elements if e._condition])
+                    _add_element_conditions(conditions, [attr])
 
     def _convert_ops_to_pass_desc(self, patterns, replaces, desc):
         for replace in replaces:
@@ -324,6 +327,10 @@ class PassDesc(object):
             return self._clone_with_operation(
                 pass_desc_pb2.PassDesc.OperationType.kAdd, value)
 
+        def Mod(self, value):
+            return self._clone_with_operation(
+                pass_desc_pb2.PassDesc.OperationType.kMod, value)
+
         def Size(self):
             return self._clone_with_operation(
                 pass_desc_pb2.PassDesc.OperationType.kSize)
@@ -336,13 +343,20 @@ class PassDesc(object):
                 value._to_pass_desc_attr(condition.condition_attr)
             else:
                 self._to_op_desc_attr(value, condition.condition_value)
+            if self._operation:
+                condition.operation.CopyFrom(self._operation)
             self._condition = condition
 
         def EQ(self, value):
             self._set_with_condition(pass_desc_pb2.PassDesc.ConditionType.kEQ,
                                      value)
 
-        def MappedPattern(self, var=None, op=None, index=0, name=None):
+        def MappedPattern(self,
+                          var=None,
+                          op=None,
+                          index=0,
+                          name=None,
+                          element_index=None):
             if all([var, op]):
                 raise ValueError("Only mapped one of which var or op.")
 
@@ -356,7 +370,8 @@ class PassDesc(object):
                     raise ValueError(
                         "Index '{}' of operator '{}' is incorrect.".format(
                             index, op))
-                return PassDesc.AttrHelper(ops[index], name)
+                return PassDesc.AttrHelper(
+                    ops[index], name, element_index=element_index)
 
             self._mapped = mapped_op if var is None else mapped_var
 
@@ -460,6 +475,13 @@ class PassDesc(object):
         def Outputs(self):
             return self._outputs
 
+        def SetOutputs(self, **kwargs):
+            for param, arg in kwargs.items():
+                if arg is None:
+                    self._desc.remove_output(param)
+                else:
+                    self._desc.set_output(param, [arg.name])
+
     OP = OpHelper()
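
On the pybind side, `callable.inc_ref()` pins the Python pass factory so it cannot be garbage-collected before the C++ `PassRegistry` closure finally invokes it. In `ir.py`, the old two-step append/extend over attribute conditions becomes a recursion, which is what lets a condition attached to an element of an element (e.g. `Attr("shape")[0].Mod(32).EQ(0)`) survive serialization. A self-contained model of that recursion, with a stand-in class rather than Paddle's `AttrHelper`:

```python
# Stand-in for PassDesc.AttrHelper: just a condition plus child elements.
class FakeAttr:
    def __init__(self, condition=None, elements=()):
        self._condition = condition
        self._elements = list(elements)

def add_element_conditions(conditions, elements):
    # Depth-first walk, mirroring _add_element_conditions in ir.py: collect
    # each attribute's own condition, then recurse into its elements.
    for element in elements:
        if element._condition:
            conditions.append(element._condition)
        add_element_conditions(conditions, element._elements)

shape = FakeAttr(elements=[FakeAttr(condition="shape[0] % 32 == 0")])
conditions = []
add_element_conditions(conditions, [shape])
assert conditions == ["shape[0] % 32 == 0"]
```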
diff --git a/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt
index 5fc05a3a7cfab0666237b3037a9b8ca70d34e44f..3d80d92595b1739bc6c6638d03cf4df16213daac 100644
--- a/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt
@@ -10,3 +10,5 @@ foreach(target ${TEST_IR_PASSES})
 endforeach()
 
 add_subdirectory(inference)
+
+set_tests_properties(test_fuse_resnet_unit PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py b/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..711891216b68a1c50a4a6469b84d0367925de83b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle
+import paddle.incubate
+from paddle.fluid import core
+from paddle.vision.models import ResNet
+from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
+
+paddle.enable_static()
+np.random.seed(0)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_cuda() or
+                 paddle.get_cudnn_version() < 8000,
+                 "requires CUDA and a cuDNN version of at least 8.0.")
+class TestFuseResNetUnit(unittest.TestCase):
+    def test_fuse_resnet_unit(self):
+        place = paddle.CUDAPlace(0)
+        program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.static.amp.fp16_guard():
+            with paddle.static.program_guard(program, startup_program):
+                x = paddle.static.data("x", [1, 64, 64, 8])
+                conv2d = paddle.nn.Conv2D(
+                    8, 32, 1, bias_attr=False, data_format='NHWC')
+                batch_norm = paddle.nn.BatchNorm(
+                    32, act='relu', data_layout='NHWC')
+                out = batch_norm(conv2d(x))
+        graph = core.Graph(program.desc)
+        core.get_pass("fuse_resnet_unit").apply(graph)
+        after_program = paddle.fluid.framework.IrGraph(graph).to_program()
+        params = paddle.static.amp.cast_model_to_fp16(program)
+        after_params = paddle.static.amp.cast_model_to_fp16(after_program)
+        exe = paddle.static.Executor(place)
+        exe.run(startup_program)
+        paddle.static.amp.cast_parameters_to_fp16(
+            place, program, to_fp16_var_names=params)
+        paddle.static.amp.cast_parameters_to_fp16(
+            place, after_program, to_fp16_var_names=after_params)
+        feed = {"x": np.random.randn(1, 64, 64, 8).astype("float16")}
+        before_out = exe.run(program, feed=feed, fetch_list=[out.name])
+        after_out = exe.run(after_program, feed=feed, fetch_list=[out.name])
+        self.assertTrue(np.allclose(before_out[0], after_out[0], atol=5e-3))
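
The test drives the pass through a `core.Graph` round trip and then checks numerical agreement in FP16 (the fused `resnet_unit` computes through cuDNN, hence the loose `atol=5e-3`). The apply-and-convert idiom generalizes to a small helper; a hypothetical utility built only from the fluid calls the test already uses:

```python
import paddle
from paddle.fluid import core

def apply_ir_pass(program, pass_name):
    """Apply a registered IR pass to a Program; return the rewritten one."""
    graph = core.Graph(program.desc)
    core.get_pass(pass_name).apply(graph)
    return paddle.fluid.framework.IrGraph(graph).to_program()

# e.g. after_program = apply_ir_pass(program, "fuse_resnet_unit")
```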
diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py
index e5215cf506413ca559593f1b3e0ae85b16a07baa..7c7206d6e89c429ed98cba10d6662fba4079b633 100644
--- a/python/paddle/incubate/__init__.py
+++ b/python/paddle/incubate/__init__.py
@@ -23,6 +23,7 @@ from .tensor import segment_sum
 from .tensor import segment_mean
 from .tensor import segment_max
 from .tensor import segment_min
+from .passes import fuse_resnet_unit_pass
 from . import nn  #noqa: F401
 
diff --git a/python/paddle/incubate/passes/__init__.py b/python/paddle/incubate/passes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba160fcc405544a74aa4a387e2fd0bffa4055831
--- /dev/null
+++ b/python/paddle/incubate/passes/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/paddle/incubate/passes/fuse_resnet_unit_pass.py b/python/paddle/incubate/passes/fuse_resnet_unit_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b5dca6141879db9d8f2346fea3b422d62d65dd0
--- /dev/null
+++ b/python/paddle/incubate/passes/fuse_resnet_unit_pass.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid.ir as ir
+
+
+def set_resnet_unit_attrs(resnet_unit, has_shortcut):
+    resnet_unit.SetAttr("fuse_add", False)
+    resnet_unit.SetAttr("act_type", "relu")
+    resnet_unit.SetAttr("has_shortcut", has_shortcut)
+    resnet_unit.SetAttr("data_format", 'NHWC')
+    resnet_unit.SetAttr("dilation", 1)
+    resnet_unit.Attr("stride").MappedPattern(
+        op="conv2d", name="strides", element_index=0)
+    resnet_unit.Attr("padding").MappedPattern(
+        op="conv2d", name="paddings", element_index=0)
+    resnet_unit.Attr("group").MappedPattern(op="conv2d", name="groups")
+    resnet_unit.Attr("op_device").MappedPattern(op="conv2d", name="op_device")
+    resnet_unit.Attr("op_namescope").MappedPattern(
+        op="conv2d", name="op_namescope")
+    resnet_unit.Attr("momentum").MappedPattern(op="batch_norm", name="momentum")
+    resnet_unit.Attr("epsilon").MappedPattern(op="batch_norm", name="epsilon")
+    resnet_unit.Attr("use_global_stats").MappedPattern(
+        op="batch_norm", name="use_global_stats")
+
+
+def set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ=None, varZ=None):
+    resnet_unit.SetOutputs(
+        RunningMeanX=meanX,
+        RunningVarX=varX,
+        RunningMeanZ=meanZ,
+        RunningVarZ=varZ)
+
+
+@ir.RegisterPass
+def fuse_resnet_unit():
+    def pattern_conv_bn(x, filter, scale, bias, mean, var):
+        filter.Attr("shape")[0].Mod(32).EQ(0)
+        filter.Attr("shape")[1].Mod(8).EQ(0)
+        filter.Attr("shape")[2].EQ(1)
+        filter.Attr("shape")[3].EQ(1)
+        conv2d = ir.PassDesc.OP.conv2d(Input=x, Filter=filter)
+        conv2d.SetAttr("data_format", 'NHWC')
+        bn = ir.PassDesc.OP.batch_norm(
+            X=conv2d, Bias=bias, Mean=mean, Scale=scale, Variance=var)
+        return bn
+
+    def pattern_one_input(x, filter, scale, bias, mean, var):
+        bn = pattern_conv_bn(x, filter, scale, bias, mean, var)
+        relu = ir.PassDesc.OP.relu(X=bn.Output("Y"))
+        return relu
+
+    def replace_one_input(x, filter, scale, bias, mean, var):
+        resnet_unit = ir.PassDesc.OP.resnet_unit(
+            X=x, FilterX=filter, ScaleX=scale, BiasX=bias, MeanX=mean, VarX=var)
+        set_resnet_unit_attrs(resnet_unit, False)
+        set_resnet_unit_outputs(resnet_unit, mean, var)
+        return resnet_unit.Output("Y")
+
+    def pattern_two_input(x, filterX, scaleX, biasX, meanX, varX, z, filterZ,
+                          scaleZ, biasZ, meanZ, varZ):
+        bnX = pattern_conv_bn(x, filterX, scaleX, biasX, meanX, varX)
+        bnZ = pattern_conv_bn(z, filterZ, scaleZ, biasZ, meanZ, varZ)
+        ewadd = ir.PassDesc.OP.elementwise_add(
+            X=bnX.Output("Y"), Y=bnZ.Output("Y"))
+        relu = ir.PassDesc.OP.relu(X=ewadd)
+        return relu
+
+    def replace_two_input(x, filterX, scaleX, biasX, meanX, varX, z, filterZ,
+                          scaleZ, biasZ, meanZ, varZ):
+        resnet_unit = ir.PassDesc.OP.resnet_unit(
+            X=x,
+            FilterX=filterX,
+            ScaleX=scaleX,
+            BiasX=biasX,
+            MeanX=meanX,
+            VarX=varX,
+            Z=z,
+            FilterZ=filterZ,
+            ScaleZ=scaleZ,
+            BiasZ=biasZ,
+            MeanZ=meanZ,
+            VarZ=varZ)
+        set_resnet_unit_attrs(resnet_unit, True)
+        set_resnet_unit_outputs(resnet_unit, meanX, varX, meanZ, varZ)
+        return resnet_unit.Output("Y")
+
+    return (pattern_one_input, replace_one_input), (pattern_two_input,
+                                                    replace_two_input)
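
`fuse_resnet_unit` illustrates the general contract of `@ir.RegisterPass`: the decorated function returns one or more `(pattern, replace)` function pairs whose shared argument names identify the subgraph's variables, and `ir.PassDesc.OP.<op_type>` describes operators on both sides. A minimal, hypothetical pass in the same shape (the op types exist in Paddle, but this particular fusion is only for illustration):

```python
import paddle.fluid.ir as ir

@ir.RegisterPass
def example_add_n_fuse():  # hypothetical pass, for illustration only
    # Match two chained elementwise_add ops over x, y, z ...
    def pattern(x, y, z):
        ewadd = ir.PassDesc.OP.elementwise_add(X=x, Y=y)
        return ir.PassDesc.OP.elementwise_add(X=ewadd, Y=z)

    # ... and replace them with a single sum over the three inputs.
    def replace(x, y, z):
        return ir.PassDesc.OP.sum(X=[x, y, z])

    return pattern, replace
```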
diff --git a/python/setup.py.in b/python/setup.py.in
index 05418e1f46814a423dbbc77106ef2f08543376bd..60d9434e8566345c8a7357daef6c8b6746d67e62 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -275,6 +275,7 @@ packages=['paddle',
           'paddle.incubate.operators',
           'paddle.incubate.tensor',
           'paddle.incubate.nn',
+          'paddle.incubate.passes',
           'paddle.distributed.fleet',
           'paddle.distributed.fleet.base',
           'paddle.distributed.fleet.elastic',
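
Finally, the new package must both ship in the wheel (hence the `setup.py.in` entry) and be imported for `@ir.RegisterPass` to run; the `from .passes import fuse_resnet_unit_pass` line in `paddle.incubate` handles the latter. A quick consumer-side check, assuming a CUDA build:

```python
import paddle.incubate  # noqa: F401 -- importing registers fuse_resnet_unit
from paddle.fluid import core

# Fails if the pass never registered (module not packaged or not imported).
fuse_pass = core.get_pass("fuse_resnet_unit")
```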