未验证 提交 b33aaea8 编写于 作者: W wawltor 提交者: GitHub

add the op version check for the elementwise ops, test=op_version (#30010)

* add the op version check for the elementwise ops, test=op_version

* add the support check for elementwise_ops, test=op_version
上级 ed856d25
......@@ -153,7 +153,7 @@ REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("elementwise_add", 0));
.LE("elementwise_add", 1));
REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DTransposeBiasFusePass);
......@@ -161,7 +161,7 @@ REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d_transpose", 1)
.EQ("elementwise_add", 0));
.LE("elementwise_add", 1));
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv3DBiasFusePass);
......@@ -228,20 +228,19 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
conv_output->AsIntermediate();
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_y,
elementwise_add_out);
};
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_y,
elementwise_add_out);
};
return ExecuteHandleOnGraph<IdentityFuseHandle>(
&gpd, graph_with_stats,
......@@ -266,20 +265,19 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
conv_output);
conv_output->AsIntermediate();
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_x,
elementwise_add_out);
};
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_x,
elementwise_add_out);
};
return ExecuteHandleOnGraph<IdentityFuseHandle>(
&gpd, graph_with_stats,
......@@ -306,17 +304,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
conv_x_output->AsIntermediate();
conv_y_output->AsIntermediate();
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_out);
};
return std::make_tuple(elementwise_add_op, elementwise_add_out);
};
return ExecuteHandleOnGraph<ProjectionFuseHandle>(
&gpd, graph_with_stats,
......@@ -351,4 +348,4 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("elementwise_add", 0));
.LE("elementwise_add", 1));
......@@ -221,5 +221,5 @@ REGISTER_PASS_CAPABILITY(mkldnn_inplace_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("softmax", 0)
.EQ("elementwise_add", 0)
.LE("elementwise_add", 1)
.EQ("tanh", 0));
......@@ -383,8 +383,8 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.EQ("concat", 0)
.EQ("tanh", 0)
.EQ("pad", 0)
.EQ("elementwise_add", 0)
.EQ("elementwise_mul", 0)
.LE("elementwise_add", 1)
.LE("elementwise_mul", 1)
.EQ("prelu", 0)
.LE("conv2d_transpose", 1)
.LE("leaky_relu", 1)
......
......@@ -44,8 +44,8 @@ REGISTER_OP_VERSION(arg_max)
false)
.ModifyAttr(
"dtype",
"change the default value of dtype, the older version "
"is -1, means return the int64 indices."
"The new version is 3, return the int64 indices directly."
"And supporting the dtype of -1 in new version.",
"Change the default value of dtype from -1 to 3"
", means return the int64 indices directly. The rearse why "
"changing the default value is that the int64 value in "
"VarType is 3 in the frameworke.proto.",
3));
......@@ -44,8 +44,8 @@ REGISTER_OP_VERSION(arg_min)
false)
.ModifyAttr(
"dtype",
"change the default value of dtype, the older version "
"is -1, means return the int64 indices."
"The new version is 3, return the int64 indices directly."
"And supporting the dtype of -1 in new version.",
"Change the default value of dtype from -1 to 3"
", means return the int64 indices directly. The rearse why "
"changing the default value is that the int64 value in "
"VarType is 3 in the frameworke.proto.",
3));
......@@ -3,7 +3,7 @@ if(WITH_UNITY_BUILD)
# Load Unity Build rules for operators in paddle/fluid/operators/elementwise.
include(unity_build_rule.cmake)
endif()
register_operators()
register_operators(DEPS op_version_registry)
cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
......@@ -178,3 +177,13 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex64>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_VERSION(elementwise_add)
.AddCheckpoint(
R"ROC(Register elementwise_add for adding the attribute of
Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_add.",
1.0f));
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
......@@ -162,3 +163,12 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex64>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_VERSION(elementwise_div)
.AddCheckpoint(
R"ROC(Register elementwise_div for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_div.",
1.0f));
......@@ -69,3 +69,12 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext,
int64_t>);
REGISTER_OP_VERSION(elementwise_floordiv)
.AddCheckpoint(
R"ROC(Register elementwise_floordiv for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_floordiv.",
1.0f));
......@@ -94,3 +94,12 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_VERSION(elementwise_max)
.AddCheckpoint(
R"ROC(Register elementwise_max for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_max.",
1.0f));
......@@ -94,3 +94,12 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_VERSION(elementwise_min)
.AddCheckpoint(
R"ROC(Register elementwise_min for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_min.",
1.0f));
......@@ -69,3 +69,12 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(elementwise_mod)
.AddCheckpoint(
R"ROC(Register elementwise_mod for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_mod.",
1.0f));
......@@ -161,3 +161,12 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex64>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_VERSION(elementwise_mul)
.AddCheckpoint(
R"ROC(Register elementwise_mul for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_mul.",
1.0f));
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
......
......@@ -83,3 +83,12 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwisePowGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwisePowGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwisePowGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_VERSION(elementwise_pow)
.AddCheckpoint(
R"ROC(Register elementwise_pow for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_pow.",
1.0f));
......@@ -156,3 +156,12 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex64>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_VERSION(elementwise_sub)
.AddCheckpoint(
R"ROC(Register elementwise_sub for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_sub.",
1.0f));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册