未验证 提交 4e578c2b 编写于 作者: B baoachun 提交者: GitHub

update seqconv_eltadd_relu_fuse_pass ut (#37907)

* update seqconv_eltadd_relu_fuse_pass ut

* update ut

* update ut

* update ut
上级 aadc8674
...@@ -924,6 +924,7 @@ PDNode *patterns::SeqConvEltAddRelu::operator()( ...@@ -924,6 +924,7 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
seqconv_input->assert_is_op_input("sequence_conv", "X"); seqconv_input->assert_is_op_input("sequence_conv", "X");
auto *seqconv_op = pattern->NewNode(seqconv_repr()) auto *seqconv_op = pattern->NewNode(seqconv_repr())
->assert_is_op("sequence_conv") ->assert_is_op("sequence_conv")
->assert_has_n_inputs(2)
->assert_op_attr<bool>("paddingTrainable", false) ->assert_op_attr<bool>("paddingTrainable", false)
->assert_op_attr<int>("contextStride", 1); ->assert_op_attr<int>("contextStride", 1);
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,128 +12,102 @@ ...@@ -12,128 +12,102 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from auto_scan_test import PassAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import unittest
import numpy as np import numpy as np
from inference_pass_test import InferencePassTest import paddle.inference as paddle_infer
import paddle.fluid as fluid from functools import partial
import paddle.fluid.core as core from typing import Optional, List, Callable, Dict, Any, Set
from paddle.fluid.core import AnalysisConfig import unittest
from paddle.fluid.core import PassVersionChecker
import hypothesis
from hypothesis import given, settings, seed, example, assume
class SeqconvEltaddReluFusePassTest(InferencePassTest): import hypothesis.strategies as st
def setUp(self): from functools import reduce
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[100, 100], dtype="float32")
param_attr = fluid.ParamAttr( class TestSeqconvEltaddReluFusePass(PassAutoScanTest):
initializer=fluid.initializer.Xavier(uniform=False), def is_program_valid(self, program_config: ProgramConfig) -> bool:
learning_rate=0.001) return True
conv_out = fluid.layers.sequence_conv(
input=data, def sample_program_config(self, draw):
num_filters=16, contextLength = draw(st.sampled_from([1, 2, 3, 4]))
filter_size=4, contextStart = draw(st.sampled_from([1, 2, 3]))
padding_start=0, contextStride = draw(st.sampled_from([1]))
act="relu", paddingTrainable = False
bias_attr=param_attr) axis = draw(st.sampled_from([1]))
batch_size = draw(st.integers(min_value=1, max_value=4))
np_data = np.random.random((80, 100)).astype('float32')
x_lod_tensor = fluid.create_lod_tensor(np_data, [[10, 20, 30, 20]], def generate_input():
fluid.CPUPlace()) shape = [batch_size, 128, 6, 120]
self.feeds = {"data": x_lod_tensor} return np.random.random(shape).astype(np.float32)
self.fetch_list = [conv_out]
self.enable_mkldnn = True def generate_weight(shape):
return np.random.random(shape).astype(np.float32)
def test_check_output(self):
self.check_output() im2sequence_op = OpConfig(
self.assertTrue( type="im2sequence",
PassVersionChecker.IsCompatible('seqconv_eltadd_relu_fuse_pass')) inputs={"X": ["input_data"]},
outputs={"Out": ["seq_out"]},
attrs={
class SeqconvEltaddReluFusePassTestPaddingStartPositive(InferencePassTest): "kernels": [6, 1],
def setUp(self): "out_stride": [1, 1],
with fluid.program_guard(self.main_program, self.startup_program): "paddings": [0, 0, 0, 0],
data = fluid.data(name="data", shape=[-1, 4], dtype="float32") "strides": [1, 1]
param_attr = fluid.ParamAttr( })
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001) sequence_conv_op = OpConfig(
conv_out = fluid.layers.sequence_conv( type="sequence_conv",
input=data, inputs={"X": ["seq_out"],
num_filters=16, "Filter": ["conv_weight"]},
filter_size=3, outputs={"Out": ["conv_out"]},
padding_start=2, attrs={
act="relu", "contextLength": contextLength,
bias_attr=param_attr) "contextStart": contextStart,
"contextStride": contextStride,
np_data = np.array([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], "paddingTrainable": paddingTrainable
[4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6], })
[7, 7, 7, 7]]).astype('float32')
x_lod_tensor = fluid.create_lod_tensor(np_data, [[5, 2]], elementwise_add_op = OpConfig(
fluid.CPUPlace()) type="elementwise_add",
self.feeds = {"data": x_lod_tensor} inputs={"X": ["conv_out"],
self.fetch_list = [conv_out] "Y": ["elt_weight"]},
self.enable_mkldnn = True outputs={"Out": ["elt_output"]},
attrs={'axis': axis})
def test_check_output(self):
self.check_output() relu_op = OpConfig(
self.assertTrue( type="relu",
PassVersionChecker.IsCompatible('seqconv_eltadd_relu_fuse_pass')) inputs={"X": ["elt_output"]},
outputs={"Out": ["relu_output"]},
attrs={})
class SeqconvEltaddReluFusePassTestPaddingStartNegative(InferencePassTest):
def setUp(self): model_net = [
with fluid.program_guard(self.main_program, self.startup_program): im2sequence_op, sequence_conv_op, elementwise_add_op, relu_op
data = fluid.data(name="data", shape=[100, 100], dtype="float32") ]
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False), program_config = ProgramConfig(
learning_rate=0.001) ops=model_net,
conv_out = fluid.layers.sequence_conv( weights={
input=data, "conv_weight": TensorConfig(data_gen=partial(
num_filters=16, generate_weight, [768 * contextLength, 16])),
filter_size=4, "elt_weight":
padding_start=-1, TensorConfig(data_gen=partial(generate_weight, [16]))
act="relu", },
bias_attr=param_attr) inputs={
"input_data": TensorConfig(data_gen=partial(generate_input))
np_data = np.random.random((80, 100)).astype('float32') },
x_lod_tensor = fluid.create_lod_tensor(np_data, [[10, 20, 30, 20]], outputs=["relu_output"])
fluid.CPUPlace())
self.feeds = {"data": x_lod_tensor} return program_config
self.fetch_list = [conv_out]
self.enable_mkldnn = True def sample_predictor_configs(self, program_config):
config = self.create_inference_config()
def test_check_output(self): yield config, ["im2sequence", "fusion_seqconv_eltadd_relu"], (1e-5,
self.check_output() 1e-5)
self.assertTrue(
PassVersionChecker.IsCompatible('seqconv_eltadd_relu_fuse_pass')) def test(self):
self.run_and_statis(
quant=False, passes=["seqconv_eltadd_relu_fuse_pass"])
class SeqconvEltaddReluFusePassTestPaddingStartNone(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data", shape=[100, 100], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.sequence_conv(
input=data,
num_filters=16,
filter_size=4,
act="relu",
bias_attr=param_attr)
np_data = np.random.random((80, 100)).astype('float32')
x_lod_tensor = fluid.create_lod_tensor(np_data, [[10, 20, 30, 20]],
fluid.CPUPlace())
self.feeds = {"data": x_lod_tensor}
self.fetch_list = [conv_out]
self.enable_mkldnn = True
def test_check_output(self):
self.check_output()
self.assertTrue(
PassVersionChecker.IsCompatible('seqconv_eltadd_relu_fuse_pass'))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册