未验证 提交 f3022dfa 编写于 作者: W wangxinxin08 提交者: GitHub

add elementwise sub and elementwise div in tensorrt op teller (#40806)

* add elementwise sub and elementwise div in tensorrt op teller

* add unittest of elementwise mul, sub and div
上级 c544a181
......@@ -74,7 +74,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"tanh",
"pad",
"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"dropout",
"prelu",
"conv2d_transpose",
......@@ -133,7 +135,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"tanh",
"pad",
"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"dropout",
"prelu",
"conv2d_transpose",
......@@ -958,7 +962,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if (op_type == "elementwise_add" || op_type == "elementwise_mul") {
if (op_type == "elementwise_add" || op_type == "elementwise_mul" ||
op_type == "elementwise_sub" || op_type == "elementwise_div") {
if (desc.Input("X").size() != 1) {
VLOG(3) << "The input op's Input(\"X\").size() "
"should equal to 1, but received Input(\"X\").size() = "
......
......@@ -150,7 +150,10 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
return np.random.random(shape).astype(np.float32)
for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]:
for op_type in ["elementwise_add", "elementwise_mul"]:
for op_type in [
"elementwise_add", "elementwise_mul", "elementwise_sub",
"elementwise_div"
]:
for axis in [0, -1]:
self.dims = len(shape)
dics = [{"axis": axis}]
......@@ -306,7 +309,10 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest):
input1_shape = input1_shape_list[i]
for j in range(6):
input2_shape = input2_shape_list[j][i]
for op_type in ["elementwise_add", "elementwise_mul"]:
for op_type in [
"elementwise_add", "elementwise_mul", "elementwise_sub",
"elementwise_div"
]:
for axis in axis_list[j][i]:
self.shape1 = input1_shape
self.shape2 = input2_shape
......
......@@ -56,5 +56,23 @@ class TensorRTSubgraphPassElementwiseBroadcastTest(InferencePassTest):
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassElementwiseBroadcastTest1(
TensorRTSubgraphPassElementwiseBroadcastTest):
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_sub(x=data1, y=data2, axis=0)
class TensorRTSubgraphPassElementwiseBroadcastTest2(
TensorRTSubgraphPassElementwiseBroadcastTest):
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_mul(x=data1, y=data2, axis=0)
class TensorRTSubgraphPassElementwiseBroadcastTest3(
TensorRTSubgraphPassElementwiseBroadcastTest):
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_div(x=data1, y=data2, axis=0)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册