未验证 提交 90805e2d 编写于 作者: Leo Chen 提交者: GitHub

Register op_version for new attribute use_addto (#28463)

* register op_version for addto

* upgrade pass capability

* change eq to le

* change eq to le

* fix merge
上级 1c3eef4c
......@@ -238,11 +238,11 @@ REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("affine_channel", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0)
.EQ("affine_channel", 0));
......@@ -383,11 +383,11 @@ REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0)
.EQ("batch_norm", 0));
......@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -119,7 +121,7 @@ REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0)
.EQ("relu", 0)
.EQ("identity", 0));
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
......@@ -107,7 +108,7 @@ REGISTER_PASS(conv_elementwise_add_act_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0)
.EQ("relu", 0)
.EQ("identity", 0));
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
......@@ -93,5 +94,5 @@ REGISTER_PASS(conv_elementwise_add_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_elementwise_add_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0));
......@@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -107,7 +109,7 @@ REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("relu", 0));
REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
......@@ -115,7 +117,7 @@ REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.LE("leaky_relu", 1));
REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
......@@ -123,7 +125,7 @@ REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("relu6", 0));
REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
......@@ -131,5 +133,5 @@ REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("swish", 0));
......@@ -13,8 +13,10 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h"
#include <functional>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -150,7 +152,7 @@ REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0));
REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -128,6 +130,6 @@ REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("concat", 0)
.EQ("relu", 0));
......@@ -13,11 +13,13 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
......@@ -226,19 +228,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
conv_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_y,
elementwise_add_out);
};
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_y,
elementwise_add_out);
};
return ExecuteHandleOnGraph<IdentityFuseHandle>(
&gpd, graph_with_stats,
......@@ -263,19 +266,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
conv_output);
conv_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_x,
elementwise_add_out);
};
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_x,
elementwise_add_out);
};
return ExecuteHandleOnGraph<IdentityFuseHandle>(
&gpd, graph_with_stats,
......@@ -302,16 +306,17 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
conv_x_output->AsIntermediate();
conv_y_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_out);
};
return std::make_tuple(elementwise_add_op, elementwise_add_out);
};
return ExecuteHandleOnGraph<ProjectionFuseHandle>(
&gpd, graph_with_stats,
......@@ -345,5 +350,5 @@ REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0));
......@@ -63,5 +63,5 @@ REGISTER_PASS(depthwise_conv_mkldnn_pass,
paddle::framework::ir::DepthwiseConvMKLDNNPass);
REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"depthwise_conv2d", 0));
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"depthwise_conv2d", 1));
......@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -331,7 +332,7 @@ REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass);
REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("fc", 0)
.LE("conv2d_transpose", 1)
.EQ("fake_quantize_abs_max", 0)
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include <algorithm>
#include <map>
#include <set>
......@@ -20,7 +22,6 @@
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
......@@ -309,6 +310,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
min_input_shape, max_input_shape, opt_input_shape,
disable_trt_plugin_fp16);
trt_engine->SetUseOSS(Get<bool>("use_oss"));
trt_engine->SetWithErnie(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass));
......@@ -367,13 +369,13 @@ REGISTER_PASS(tensorrt_subgraph_pass,
REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("pool2d", 0)
.EQ("relu", 0)
.EQ("softmax", 0)
.EQ("sigmoid", 0)
.EQ("hard_swish", 0)
.EQ("depthwise_conv2d", 0)
.LE("depthwise_conv2d", 1)
.EQ("batch_norm", 0)
.EQ("concat", 0)
.EQ("tanh", 0)
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_helper.h"
......@@ -817,3 +819,36 @@ REGISTER_OP_CPU_KERNEL(
conv3d_grad_grad,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
// Op-version checkpoint for `conv2d`: records that the attribute `use_addto`
// was added to the op. The raw-string note and the remark passed to NewAttr
// are registered text, so keep them byte-for-byte when editing.
// NOTE(review): presumably this is the first checkpoint for conv2d and moves
// its version from 0 to 1, which is what the `.LE("conv2d", 1)` pass
// capability constraints elsewhere in this change rely on -- confirm against
// the op_version registry.
REGISTER_OP_VERSION(conv2d)
.AddCheckpoint(
R"ROC(
Upgrade conv2d, add a new attribute [use_addto].
)ROC",
// NewAttr arguments: attribute name, human-readable remark, and a trailing
// `false` (presumably the default assumed for programs saved before the
// attribute existed -- TODO confirm against OpVersionDesc::NewAttr).
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_addto",
"In order to support new feature (inplace addto strategy) for "
"gradient accumulation.",
false));
// Op-version checkpoint for `depthwise_conv2d`: records that the attribute
// `use_addto` was added to the op. The raw-string note and the remark passed
// to NewAttr are registered text, so keep them byte-for-byte when editing.
// NOTE(review): presumably this moves depthwise_conv2d from version 0 to 1,
// matching the `.LE("depthwise_conv2d", 1)` pass capability constraints
// elsewhere in this change -- confirm against the op_version registry.
REGISTER_OP_VERSION(depthwise_conv2d)
.AddCheckpoint(
R"ROC(
Upgrade depthwise_conv2d, add a new attribute [use_addto].
)ROC",
// NewAttr arguments: attribute name, human-readable remark, and a trailing
// `false` (presumably the default assumed for programs saved before the
// attribute existed -- TODO confirm against OpVersionDesc::NewAttr).
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_addto",
"In order to support new feature (inplace addto strategy) for "
"gradient accumulation.",
false));
// Op-version checkpoint for `conv3d`: records that the attribute `use_addto`
// was added to the op. The raw-string note and the remark passed to NewAttr
// are registered text, so keep them byte-for-byte when editing.
// NOTE(review): no pass capability visible in this change constrains conv3d
// by version; any pass that fuses conv3d and declares a capability would
// presumably need to accept the version this checkpoint introduces -- verify.
REGISTER_OP_VERSION(conv3d)
.AddCheckpoint(
R"ROC(
Upgrade conv3d, add a new attribute [use_addto].
)ROC",
// NewAttr arguments: attribute name, human-readable remark, and a trailing
// `false` (presumably the default assumed for programs saved before the
// attribute existed -- TODO confirm against OpVersionDesc::NewAttr).
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_addto",
"In order to support new feature (inplace addto strategy) for "
"gradient accumulation.",
false));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册