Unverified commit b33aaea8, authored by wawltor, committed by GitHub

add the op version check for the elementwise ops, test=op_version (#30010)

* add the op version check for the elementwise ops, test=op_version

* add the support check for elementwise_ops, test=op_version
Parent ed856d25
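
In brief, the change has two parts: each elementwise op registers a version checkpoint that moves it from version 0 to version 1, recording the new Scale_y attribute and its default, and each pass capability that previously pinned elementwise ops to version 0 with .EQ(op, 0) is relaxed to .LE(op, 1). A minimal sketch of the pattern, using the op-version API exactly as it appears in this diff (my_elementwise_op and my_fuse_pass are placeholder names):

// Sketch only: bump an op to version 1 when a new attribute is introduced.
// A program saved at version 0 is expected to load with the attribute set
// to its default (here 1.0f, i.e. no scaling of input Y).
REGISTER_OP_VERSION(my_elementwise_op)
    .AddCheckpoint(
        R"ROC(Register my_elementwise_op for adding the attribute of Scale_y)ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "Scale_y", "Scale factor applied to the input Y.", 1.0f));

// Sketch only: a pass declares it can handle the op at version 0 or 1.
REGISTER_PASS_CAPABILITY(my_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("my_elementwise_op", 1));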
@@ -153,7 +153,7 @@ REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("conv2d", 1)
-            .EQ("elementwise_add", 0));
+            .LE("elementwise_add", 1));
 REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
               paddle::framework::ir::Conv2DTransposeBiasFusePass);
@@ -161,7 +161,7 @@ REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("conv2d_transpose", 1)
-            .EQ("elementwise_add", 0));
+            .LE("elementwise_add", 1));
 REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
               paddle::framework::ir::Conv3DBiasFusePass);
@@ -228,20 +228,19 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
       pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
   conv_output->AsIntermediate();
-  auto get_node_from_elementwise_add =
-      [&elementwise_add_pattern](
-          const GraphPatternDetector::subgraph_t& subgraph)
+  auto get_node_from_elementwise_add = [&elementwise_add_pattern](
+      const GraphPatternDetector::subgraph_t& subgraph)
       -> std::tuple<Node*, Node*, Node*> {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                               elementwise_add_pattern);
     return std::make_tuple(elementwise_add_op, elementwise_add_y,
                            elementwise_add_out);
   };
   return ExecuteHandleOnGraph<IdentityFuseHandle>(
       &gpd, graph_with_stats,
@@ -266,20 +265,19 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
       conv_output);
   conv_output->AsIntermediate();
-  auto get_node_from_elementwise_add =
-      [&elementwise_add_pattern](
-          const GraphPatternDetector::subgraph_t& subgraph)
+  auto get_node_from_elementwise_add = [&elementwise_add_pattern](
+      const GraphPatternDetector::subgraph_t& subgraph)
       -> std::tuple<Node*, Node*, Node*> {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                               elementwise_add_pattern);
     return std::make_tuple(elementwise_add_op, elementwise_add_x,
                            elementwise_add_out);
   };
   return ExecuteHandleOnGraph<IdentityFuseHandle>(
       &gpd, graph_with_stats,
@@ -306,17 +304,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
   conv_x_output->AsIntermediate();
   conv_y_output->AsIntermediate();
-  auto get_node_from_elementwise_add =
-      [&elementwise_add_pattern](
-          const GraphPatternDetector::subgraph_t& subgraph)
+  auto get_node_from_elementwise_add = [&elementwise_add_pattern](
+      const GraphPatternDetector::subgraph_t& subgraph)
       -> std::tuple<Node*, Node*> {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                               elementwise_add_pattern);
     return std::make_tuple(elementwise_add_op, elementwise_add_out);
   };
   return ExecuteHandleOnGraph<ProjectionFuseHandle>(
       &gpd, graph_with_stats,
@@ -351,4 +348,4 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("conv2d", 1)
-            .EQ("elementwise_add", 0));
+            .LE("elementwise_add", 1));
@@ -221,5 +221,5 @@ REGISTER_PASS_CAPABILITY(mkldnn_inplace_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .EQ("softmax", 0)
-            .EQ("elementwise_add", 0)
+            .LE("elementwise_add", 1)
             .EQ("tanh", 0));
@@ -383,8 +383,8 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
             .EQ("concat", 0)
             .EQ("tanh", 0)
             .EQ("pad", 0)
-            .EQ("elementwise_add", 0)
-            .EQ("elementwise_mul", 0)
+            .LE("elementwise_add", 1)
+            .LE("elementwise_mul", 1)
             .EQ("prelu", 0)
             .LE("conv2d_transpose", 1)
             .LE("leaky_relu", 1)
...
@@ -44,8 +44,8 @@ REGISTER_OP_VERSION(arg_max)
             false)
         .ModifyAttr(
             "dtype",
-            "change the default value of dtype, the older version "
-            "is -1, means return the int64 indices."
-            "The new version is 3, return the int64 indices directly."
-            "And supporting the dtype of -1 in new version.",
+            "Change the default value of dtype from -1 to 3, which "
+            "means returning the int64 indices directly. The reason for "
+            "changing the default value is that the int64 value in "
+            "VarType is 3 in framework.proto.",
             3));
@@ -44,8 +44,8 @@ REGISTER_OP_VERSION(arg_min)
             false)
         .ModifyAttr(
             "dtype",
-            "change the default value of dtype, the older version "
-            "is -1, means return the int64 indices."
-            "The new version is 3, return the int64 indices directly."
-            "And supporting the dtype of -1 in new version.",
+            "Change the default value of dtype from -1 to 3, which "
+            "means returning the int64 indices directly. The reason for "
+            "changing the default value is that the int64 value in "
+            "VarType is 3 in framework.proto.",
             3));
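
For reference: the new default of 3 selects int64 indices because INT64 has the value 3 in the VarType::Type enum of framework.proto (the enum begins BOOL = 0, INT16 = 1, INT32 = 2, INT64 = 3), while -1 remains supported, per the previous wording of this description.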
@@ -3,7 +3,7 @@ if(WITH_UNITY_BUILD)
   # Load Unity Build rules for operators in paddle/fluid/operators/elementwise.
   include(unity_build_rule.cmake)
 endif()
-register_operators()
+register_operators(DEPS op_version_registry)
 cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
 cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
...
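
Note: the DEPS op_version_registry added to register_operators() above is what the new REGISTER_OP_VERSION registrations in the elementwise_*_op.cc sources below link against.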
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <memory>
 #include <string>
-#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -178,3 +177,13 @@ REGISTER_OP_CPU_KERNEL(
                           paddle::platform::complex64>,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
                               paddle::platform::complex128>);
+REGISTER_OP_VERSION(elementwise_add)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_add for adding the attribute of
+               Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_add.",
+            1.0f));
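
With this checkpoint registered, elementwise_add moves from version 0 to version 1. NewAttr records the attribute's default so that programs saved before the checkpoint can load with Scale_y at 1.0f (no scaling), which is why the fuse-pass capabilities above were relaxed from .EQ("elementwise_add", 0) to .LE("elementwise_add", 1). The remaining elementwise ops below register the same Scale_y checkpoint.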
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
 #include <memory>
 #include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -162,3 +163,12 @@ REGISTER_OP_CPU_KERNEL(
                           paddle::platform::complex64>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         paddle::platform::complex128>);
+REGISTER_OP_VERSION(elementwise_div)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_div for adding the attribute of Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_div.",
+            1.0f));
@@ -69,3 +69,12 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext,
                                    int64_t>);
+REGISTER_OP_VERSION(elementwise_floordiv)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_floordiv for adding the attribute of Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_floordiv.",
+            1.0f));
@@ -94,3 +94,12 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_VERSION(elementwise_max)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_max for adding the attribute of Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_max.",
+            1.0f));
@@ -94,3 +94,12 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_VERSION(elementwise_min)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_min for adding the attribute of Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_min.",
+            1.0f));
@@ -69,3 +69,12 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_VERSION(elementwise_mod)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_mod for adding the attribute of Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_mod.",
+            1.0f));
@@ -161,3 +161,12 @@ REGISTER_OP_CPU_KERNEL(
                           paddle::platform::complex64>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         paddle::platform::complex128>);
+REGISTER_OP_VERSION(elementwise_mul)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_mul for adding the attribute of Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_mul.",
+            1.0f));
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
...
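
Because elementwise_op.h now includes op_version_registry.h, any elementwise_*_op.cc that includes elementwise_op.h can call REGISTER_OP_VERSION without a direct include of its own; that is also why the direct include was dropped from elementwise_add_op.cc earlier in this diff.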
@@ -83,3 +83,12 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwisePowGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwisePowGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::ElementwisePowGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_VERSION(elementwise_pow)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_pow for adding the attribute of Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_pow.",
+            1.0f));
@@ -156,3 +156,12 @@ REGISTER_OP_CPU_KERNEL(
                           paddle::platform::complex64>,
     ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         paddle::platform::complex128>);
+REGISTER_OP_VERSION(elementwise_sub)
+    .AddCheckpoint(
+        R"ROC(Register elementwise_sub for adding the attribute of Scale_y)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "Scale_y",
+            "In order to support the function of scaling the input Y when "
+            "using the operator of elementwise_sub.",
+            1.0f));