diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 6370d3380361c0c3b3434bfb9b73f429174f4cfa..453cfb85554ec53549bcbfd1f9be4566deb47d54 100755
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() {
       // IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
       .AddAttr("data_format")
       .IsOptional()
-      .IsStringIn({"NCHW", "AnyLayout"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("relu"))
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
index 9f70c829e1fb5ea8c13c45bdacbc408d8cc40173..5d325037ad20ed7d1a27d6af0b302304bf8ff8e6 100755
--- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
@@ -115,7 +115,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("elementwise_add"))
@@ -129,7 +129,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
       .IsTensor()
       .End()
       .AddAttr("axis")
-      .IsIntIn({1})
+      .IsIntIn({1, 3})
       .End();
 }
 
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
index 0947f4756ad78c0e2b4b255e96f75ea440b022ef..5fbfef08b7209bc695f90ff9188b8e9a7db029a7 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
@@ -59,7 +59,7 @@ ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("concat"))
diff --git a/paddle/fluid/operators/distribution_helper.h b/paddle/fluid/operators/distribution_helper.h
index a13ae570906871b33e841b5bcb449ead9351c637..f3bce38e3a7d8c9bf77ad02fdd22dfb2fa623828 100644
--- a/paddle/fluid/operators/distribution_helper.h
+++ b/paddle/fluid/operators/distribution_helper.h
@@ -28,6 +28,10 @@ limitations under the License.
 */
 
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/pten/core/hostdevice.h"
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/pten/kernels/primitive/kernel_primitives.h"
+#endif
+
 #if !defined(_WIN32)
 #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
 #else
@@ -91,6 +95,8 @@ struct normal_transform {
 
 #if defined(__NVCC__) || defined(__HIPCC__)
 
+namespace kps = pten::kps;
+
 /*********************** Distribution Function *************************/
 template <typename T>
 struct uniform_distribution;
@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp>
 __global__ void DistributionKernel(size_t size, uint64_t seed,
                                    uint64_t offset, DistOp dist,
                                    TransformOp trans, T *out_data) {
-  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
-  int32_t returns_count = DistOp::kReturnsCount;
+  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
+  static constexpr int kCount = DistOp::kReturnsCount;
 #if defined(__NVCC__)
   curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, offset, &state);
+  curand_init(seed, idx + THREAD_ID_X, offset, &state);
+  using SType = curandStatePhilox4_32_10_t;
 #else
   hiprandStatePhilox4_32_10_t state;
-  hiprand_init(seed, idx, offset, &state);
+  hiprand_init(seed, idx + THREAD_ID_X, offset, &state);
+  using SType = hiprandStatePhilox4_32_10_t;
 #endif
-  size_t total_thread = gridDim.x * blockDim.x;
-  for (size_t i = idx; i < size; i += total_thread * returns_count) {
-    auto random_tuple = dist(&state);
-    for (size_t j = 0; j < returns_count; j++) {
-      size_t index = i + j * total_thread;
-      if (index < size) {
-        auto random = (&random_tuple.x)[j];
-        out_data[index] = static_cast<T>(trans(random));
-      }
-    }
+  size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
+  T args[kCount];
+  T result[kCount];
+  for (size_t i = idx; i < size; i += total_thread * kCount) {
+    kps::ElementwiseRandom<SType, T, kCount>(&args[0], dist, &state);
+    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
+                                                           trans);
+    kps::WriteData<T, T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
+                                             1, total_thread, 1);
   }
 }
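
For context on the distribution_helper.h hunk above: the rewritten DistributionKernel keeps the original numerical recipe (each thread seeds a Philox state, draws DistOp::kReturnsCount values per iteration, transforms them, and stores them at a grid-wide stride); it only routes each step through a kernel-primitives call. The following standalone sketch shows the same pattern in raw CUDA, with no kps:: machinery. It is illustrative only, not Paddle code; the kernel name, the fixed kCount of 4, and the (0, 1] to (-1, 1] transform are assumptions made for the example.

#include <cstddef>
#include <cstdint>
#include <curand_kernel.h>

// Illustrative sketch only, not Paddle code. Mirrors the loop structure of
// the patched DistributionKernel with raw CUDA in place of the kps:: calls.
__global__ void UniformSketchKernel(size_t size, uint64_t seed,
                                    uint64_t offset, float *out) {
  constexpr int kCount = 4;  // curand_uniform4 yields 4 values per call,
                             // playing the role of DistOp::kReturnsCount
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + threadIdx.x, offset, &state);
  size_t total_thread = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t i = idx; i < size; i += total_thread * kCount) {
    // kps::ElementwiseRandom step: draw kCount values into registers.
    float4 rand4 = curand_uniform4(&state);
    float vals[kCount] = {rand4.x, rand4.y, rand4.z, rand4.w};
#pragma unroll
    for (int j = 0; j < kCount; ++j) {
      // kps::WriteData step: value j goes out at a grid-wide stride.
      size_t index = i + threadIdx.x + j * total_thread;
      if (index < size) {
        // kps::ElementwiseUnary step: map (0, 1] to (-1, 1].
        out[index] = 2.0f * vals[j] - 1.0f;
      }
    }
  }
}

Writing value j at offset i + threadIdx.x + j * total_thread keeps adjacent threads on adjacent addresses for every one of the kCount stores, which is the coalescing pattern the kps::WriteData call with stride total_thread expresses in the patch.
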
*/ #include "paddle/fluid/platform/for_range.h" #include "paddle/pten/core/hostdevice.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/pten/kernels/primitive/kernel_primitives.h" +#endif + #if !defined(_WIN32) #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #else @@ -91,6 +95,8 @@ struct normal_transform { #if defined(__NVCC__) || defined(__HIPCC__) +namespace kps = pten::kps; + /*********************** Distribution Function *************************/ template struct uniform_distribution; @@ -176,25 +182,26 @@ template __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset, DistOp dist, TransformOp trans, T *out_data) { - size_t idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - int32_t returns_count = DistOp::kReturnsCount; + size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); + static constexpr int kCount = DistOp::kReturnsCount; #if defined(__NVCC__) curandStatePhilox4_32_10_t state; - curand_init(seed, idx, offset, &state); + curand_init(seed, idx + THREAD_ID_X, offset, &state); + using SType = curandStatePhilox4_32_10_t; #else hiprandStatePhilox4_32_10_t state; - hiprand_init(seed, idx, offset, &state); + hiprand_init(seed, idx + THREAD_ID_X, offset, &state); + using SType = hiprandStatePhilox4_32_10_t; #endif - size_t total_thread = gridDim.x * blockDim.x; - for (size_t i = idx; i < size; i += total_thread * returns_count) { - auto random_tuple = dist(&state); - for (size_t j = 0; j < returns_count; j++) { - size_t index = i + j * total_thread; - if (index < size) { - auto random = (&random_tuple.x)[j]; - out_data[index] = static_cast(trans(random)); - } - } + size_t total_thread = GRID_NUM_X * BLOCK_NUM_X; + T args[kCount]; + T result[kCount]; + for (size_t i = idx; i < size; i += total_thread * kCount) { + kps::ElementwiseRandom(&args[0], dist, &state); + kps::ElementwiseUnary(&result[0], &args[0], + trans); + kps::WriteData(out_data + i, &result[0], size - i, + 1, total_thread, 1); } } diff --git a/paddle/pten/kernels/primitive/compute_primitives.h b/paddle/pten/kernels/primitive/compute_primitives.h index a8ed0816227635e3f0159bf00c3c3c6243c7daa9..02a2f7baf780b47330dc3ad4b6d3bef79e9920c1 100644 --- a/paddle/pten/kernels/primitive/compute_primitives.h +++ b/paddle/pten/kernels/primitive/compute_primitives.h @@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) { } } +template +__device__ __forceinline__ void ElementwiseRandom(OutT* out, + OpFunc compute, + StateType* state) { + auto random_tuple = compute(state); +#pragma unroll + for (int i = 0; i < ReturnsCount; i++) { + out[i] = static_cast((&random_tuple.x)[i]); + } +} + +// attention please set share_size = blockDim.x; +// data and b are the register pointer +#define shared_size 64 +template +__device__ __forceinline__ void Cumsum(OutT* out, + const InT* in, + OpFunc compute) { + __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32]; + int tidx = threadIdx.x; + temp[tidx + tidx / 32] = in[0]; + temp[shared_size + tidx + (shared_size + tidx) / 32] = in[1]; + for (int stride = 1; stride <= blockDim.x; stride *= 2) { + __syncthreads(); + int index = (tidx + 1) * 2 * stride - 1; + if (index < (blockDim.x * 2)) { + temp[index + index / 32] += temp[index - stride + (index - stride) / 32]; + } + } + for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) { + __syncthreads(); + int index = (tidx + 1) * 2 * stride - 1; + if ((index + stride) < (blockDim.x * 2)) { + temp[index + stride + (stride + 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py
index d9b3b8e601755592609c4008b17cbefac81a7caf..d029bcd6a7f17b670ec544d9aa0279d5b4177310 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py
@@ -53,8 +53,6 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
         data_format = prog_config.ops[0].attrs["data_format"]
         filter_shape = prog_config.weights["filter"].shape
         input_shape = prog_config.inputs["input_x"].shape
-        if data_format != "NCHW":
-            return False
         if padding_algorithm == "VALID":
             if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
                ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
@@ -80,8 +78,8 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
         x_shape = draw(
             st.lists(
                 st.integers(
-                    min_value=1, max_value=100), min_size=4, max_size=4))
-        x_shape[1] = draw(st.integers(min_value=1, max_value=10))
+                    min_value=5, max_value=100), min_size=4, max_size=4))
+        x_shape[1] = draw(st.integers(min_value=5, max_value=10))
 
         # 2. Generate legal attr:data_format of conv2d
         data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
@@ -90,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
         f_shape = draw(
             st.lists(
                 st.integers(
-                    min_value=1, max_value=7), min_size=4, max_size=4))
+                    min_value=1, max_value=5), min_size=4, max_size=4))
         if data_format == "NCHW":
             f_shape[1] = x_shape[1]
         else:
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
index 40fd9a418b9b1391bb9ebf6f36d1e73f55d48c14..a0213c5b1f4dfd5f542c0a8fab32e9a8c93a66c6 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
@@ -53,8 +53,6 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
         data_format = prog_config.ops[0].attrs["data_format"]
         filter_shape = prog_config.weights["filter"].shape
         input_shape = prog_config.inputs["input_x"].shape
-        if data_format != "NCHW":
-            return False
         if padding_algorithm == "VALID":
             if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
                ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
@@ -80,8 +78,8 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
         x_shape = draw(
             st.lists(
                 st.integers(
-                    min_value=1, max_value=100), min_size=4, max_size=4))
-        x_shape[1] = draw(st.integers(min_value=1, max_value=10))
+                    min_value=5, max_value=100), min_size=4, max_size=4))
+        x_shape[1] = draw(st.integers(min_value=5, max_value=10))
 
         # 2. Generate legal attr:data_format of conv2d
         data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
@@ -90,7 +88,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
         f_shape = draw(
             st.lists(
                 st.integers(
-                    min_value=1, max_value=7), min_size=4, max_size=4))
+                    min_value=1, max_value=4), min_size=4, max_size=4))
         if data_format == "NCHW":
             f_shape[1] = x_shape[1]
         else:
@@ -100,7 +98,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
         strides = draw(
             st.lists(
                 st.integers(
-                    min_value=1, max_value=5), min_size=2, max_size=2))
+                    min_value=1, max_value=4), min_size=2, max_size=2))
 
         # 5. Generate legal attr:padding_algorithm of conv2d
         padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
@@ -109,7 +107,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
         padding = draw(
             st.lists(
                 st.integers(
-                    min_value=1, max_value=5), min_size=4, max_size=4))
+                    min_value=1, max_value=4), min_size=4, max_size=4))
 
         # 7. Generate legal attr:groups of conv2d
         groups = draw(st.integers(min_value=1, max_value=3))
@@ -118,7 +116,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
         dilations = draw(
             st.lists(
                 st.integers(
-                    min_value=1, max_value=5), min_size=2, max_size=2))
+                    min_value=1, max_value=4), min_size=2, max_size=2))
 
         # 9. Generate legal shape of input:bias of elementwise_add
         bias_shape = [f_shape[0]]
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
index 3c823c73d37943cb9040522b21bdf3f642c10bea..6654fbba264e0ee007ea16d0c3b8131e84da01ad 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
 
 class TestConvConcatReluMkldnnFusePass(PassAutoScanTest):
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        attrs = [
-            program_config.ops[i].attrs
-            for i in range(len(program_config.ops))
-        ]
-        # If the problem has been fixed, the judgment
-        # needs to be deleted!!!
-        if attrs[0]['data_format'] == "NHWC":
-            return False
-
         return True
 
     def sample_program_config(self, draw):
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
index aa779f6ecbc4f99ba0ae20f4a6a1a7f674f46781..33df428388882f2e536ecebb15a1f5dae6a6afc5 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
 
 class TestConvGeluMkldnnFusePass(PassAutoScanTest):
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        attrs = [
-            program_config.ops[i].attrs
-            for i in range(len(program_config.ops))
-        ]
-        # If the problem has been fixed, the judgment
-        # needs to be deleted!!!
-        if attrs[0]['data_format'] == "NHWC":
-            return False
-
         return True
 
     def sample_program_config(self, draw):
@@ -108,19 +99,6 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
         config = self.create_inference_config(use_mkldnn=True)
         yield config, ["conv2d"], (1e-5, 1e-5)
 
-    # If the problem has been fixed, the judgment
-    # needs to be deleted!!!
-    def add_ignore_pass_case(self):
-        def teller1(program_config, predictor_config):
-            if program_config.ops[0].attrs['data_format'] == "NHWC":
-                return True
-            return False
-
-        self.add_ignore_check_case(
-            teller1, SkipReasons.PASS_ACCURACY_ERROR,
-            "The output format of conv2d is wrong when data_format attribute is NHWC"
-        )
-
     def test(self):
         self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"])
 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
index a0c4e183930a570e25bb6aed893868f69a4edcd8..2eb071d6eb83bd89762145d4212732b6d6ad54f7 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
 
 class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest):
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        attrs = [
-            program_config.ops[i].attrs
-            for i in range(len(program_config.ops))
-        ]
-        # If the problem has been fixed, the judgment
-        # needs to be deleted!!!
-        if attrs[0]['data_format'] == "NHWC":
-            return False
-
         return True
 
     def sample_program_config(self, draw):
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py
index 17bfb625fd37bc547a2c3eaece361f7cce5d6615..990489c32136a20e88f29591db7193aa194d00a9 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
 
 class TestConvHardSwishMkldnnFusePass(PassAutoScanTest):
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        attrs = [
-            program_config.ops[i].attrs
-            for i in range(len(program_config.ops))
-        ]
-        # If the problem has been fixed, the judgment
-        # needs to be deleted!!!
-        if attrs[0]['data_format'] == "NHWC":
-            return False
-
         return True
 
     def sample_program_config(self, draw):
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py
index 5df7cb8d8cec35fd62c373e46d5b799bea46a2da..c5cedac226149d8dd30e0b25d0b7681b587cf3c5 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py
@@ -32,9 +32,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
             for i in range(len(program_config.ops))
         ]
 
-        # If the problem has been fixed, the judgment
-        # needs to be deleted!!!
-        if attrs[0]['data_format'] == "NHWC":
+        if attrs[0]['data_format'] == "NCHW" and attrs[1]["axis"] == 3:
+            return False
+        if attrs[0]['data_format'] == "NHWC" and attrs[1]["axis"] == 1:
             return False
 
         return True
@@ -46,7 +46,7 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
         groups = draw(st.sampled_from([1, 2, 4, 8]))
         paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
         strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
-        axis = draw(st.sampled_from([1]))
+        axis = draw(st.sampled_from([1, 3]))
         batch_size = draw(st.integers(min_value=1, max_value=4))
 
         def generate_input():
@@ -110,7 +110,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
 
     def test(self):
         self.run_and_statis(
-            quant=False, passes=["conv_transpose_bias_mkldnn_fuse_pass"])
+            quant=False,
+            max_duration=300,
+            passes=["conv_transpose_bias_mkldnn_fuse_pass"])
 
 
 if __name__ == "__main__":