diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index 3e70104897004fef4300077d6da4c33d2610aab8..cda8788f566f7786b32644f4c5481b48993204f7 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -83,7 +83,7 @@
   optional : bias, x_max

 - op : fused_bias_act
-  args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_dtype = "default", int rows = -1, int cols = -1, float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0)
+  args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_dtype = "default", float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0)
   output : Tensor(out)
   infer_meta :
     func: FusedBiasActInferMeta
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 67c98f7decb89f47967e4a1283ee3ca16b1d730a..5259590e9284f444e4855f6c62edc17ed26ff44b 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -1342,8 +1342,6 @@ void FusedBiasActInferMeta(const MetaTensor& x,
                            const MetaTensor& smooth,
                            const std::string& act_method,
                            const std::string& compute_dtype,
-                           int rows,
-                           int cols,
                            float quant_scale,
                            int quant_round_type,
                            float quant_max_bound,
@@ -1358,10 +1356,14 @@ void FusedBiasActInferMeta(const MetaTensor& x,

   auto dim = x_dims[1];

   PADDLE_ENFORCE_GT(
-      rows, 0, phi::errors::InvalidArgument("The size of Attr(rows) must > 0"));
+      x_dims[0],
+      0,
+      phi::errors::InvalidArgument("The size of Attr(rows) must > 0"));
   PADDLE_ENFORCE_GT(
-      cols, 0, phi::errors::InvalidArgument("The size of Attr(cols) must > 0"));
+      x_dims[1],
+      0,
+      phi::errors::InvalidArgument("The size of Attr(cols) must > 0"));

   if (act_method == "geglu" || act_method == "swiglu") {
     PADDLE_ENFORCE_EQ(
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index b52573c830695d687dfae81ff6d5b58ff0c6516b..10ab89a81560a8ecd3aa85abcf7ef7b65d3f0e6d 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -286,8 +286,6 @@ void FusedBiasActInferMeta(const MetaTensor& x,
                            const MetaTensor& smooth,
                            const std::string& act_method,
                            const std::string& compute_dtype,
-                           int rows,
-                           int cols,
                            float quant_scale,
                            int quant_round_type,
                            float quant_max_bound,
diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu
index ff722a0dfdff48ea207682911406b657498e6dce..8f75d91fc682d51885113f8d7d2d06472ffa480f 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu
@@ -438,14 +438,14 @@ void FusedBiasActKernel(const Context &dev_ctx,
                         const paddle::optional<DenseTensor> &smooth,
                         const std::string &act_method,
                         const std::string &compute_dtype,
-                        int rows,
-                        int cols,
                         float quant_scale,
                         int quant_round_type,
                         float quant_max_bound,
                         float quant_min_bound,
                         DenseTensor *out) {
 #ifndef PADDLE_WITH_HIP
+  int rows = x.dims()[0];
+  int cols = x.dims()[1];
   if (x.dtype() == phi::DataType::INT32) {
     if (compute_dtype == "bf16") {
       DispatchWithDtype(
diff --git a/test/legacy_test/test_fused_bias_act_op.py b/test/legacy_test/test_fused_bias_act_op.py
index 7d90fc36bbb313ad6b27c74fdb2b9de8f7dbc805..53ed9cc3306930b564d41d9d54f2b22762e9cd4d 100644
--- a/test/legacy_test/test_fused_bias_act_op.py
+++ b/test/legacy_test/test_fused_bias_act_op.py
@@ -73,8 +73,6 @@ def fused_act_bias_wrapper(
     smooth=None,
     act_method='gelu',
     compute_dtype='default',
-    rows=0,
-    cols=0,
     quant_scale=-1,
     quant_round_type=0,
     quant_max_bound=0,
@@ -88,8 +86,6 @@ def fused_act_bias_wrapper(
         smooth,
         act_method,
         compute_dtype,
-        rows,
-        cols,
         quant_scale,
         quant_round_type,
         quant_max_bound,
@@ -140,8 +136,6 @@ class TestFusedBiasActOp(unittest.TestCase):
         return fused_act_bias_wrapper(
             x=x,
             bias=bias,
-            rows=self.rows,
-            cols=self.cols,
             act_method=self.act_method,
             compute_dtype=self.compute_dtype,
         )
@@ -197,8 +191,6 @@ class TestFastGeluFP16(TestFusedBiasActOp):
         out = fused_act_bias_wrapper(
             x=x,
             bias=bias,
-            rows=self.rows,
-            cols=self.cols,
             act_method=self.act_method,
         )
         self.use_fast_math(False)
@@ -284,8 +276,6 @@ class TestQuantFP32(TestFusedBiasActOp):
             smooth=smooth,
             act_method=self.act_method,
             compute_dtype=self.compute_dtype,
-            rows=self.rows,
-            cols=self.cols,
             quant_scale=self.quant_scale,
             quant_round_type=self.quant_round_type,
             quant_max_bound=self.quant_max_bound,
@@ -332,8 +322,6 @@ class TestDequantFP32(TestQuantFP32):
             dequant_scales=dequant_scales,
             act_method=self.act_method,
             compute_dtype=self.compute_dtype,
-            rows=self.rows,
-            cols=self.cols,
         )
         return out

@@ -441,8 +429,6 @@ class TestFusedBiasActOpBF16(unittest.TestCase):
             bias=bias,
             act_method=self.act_method,
             compute_dtype=self.compute_dtype,
-            rows=self.rows,
-            cols=self.cols,
         )
         return out

@@ -565,8 +551,6 @@ class TestQuantBF16(TestFusedBiasActOpBF16):
             smooth=smooth,
             act_method=self.act_method,
             compute_dtype=self.compute_dtype,
-            rows=self.rows,
-            cols=self.cols,
             quant_scale=self.quant_scale,
             quant_round_type=self.quant_round_type,
             quant_max_bound=self.quant_max_bound,
@@ -678,8 +662,6 @@ class TestAssert(unittest.TestCase):
             out = fused_act_bias_wrapper(
                 x=paddle.to_tensor(x),
                 bias=paddle.to_tensor(bias),
-                rows=self.rows,
-                cols=self.cols,
             )
         except ValueError as e:
             pass
@@ -696,8 +678,6 @@ class TestAssert(unittest.TestCase):
             out = fused_act_bias_wrapper(
                 x=paddle.to_tensor(x),
                 bias=paddle.to_tensor(bias),
-                rows=self.rows,
-                cols=self.cols,
                 compute_dtype='fp16',
             )
         except ValueError as e:
@@ -715,8 +695,6 @@ class TestAssert(unittest.TestCase):
             out = fused_act_bias_wrapper(
                 x=paddle.to_tensor(x),
                 bias=paddle.to_tensor(bias),
-                rows=self.rows,
-                cols=self.cols,
                 compute_dtype='fp16',
                 act_method=act_method,
             )
@@ -765,8 +743,6 @@ class TestWithoutBias(unittest.TestCase):
         return fused_act_bias_wrapper(
             x=x,
             bias=None,
-            rows=self.rows,
-            cols=self.cols,
             act_method=self.act_method,
         )
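A minimal usage sketch, not part of the patch itself: how the `fused_act_bias_wrapper` helper defined in `test/legacy_test/test_fused_bias_act_op.py` is called after this change. The `rows`/`cols` arguments are gone; the kernel now reads them from `x.dims()[0]` and `x.dims()[1]`. The shapes and dtypes below are illustrative assumptions.

```python
import paddle

# Hypothetical shapes; rows/cols are now inferred from x inside the
# kernel, so only the tensor itself is passed (the fused_bias_act
# kernel is GPU-only, per the PADDLE_WITH_HIP guard above).
rows, cols = 20, 512
x = paddle.rand([rows, cols]).astype('float16')
bias = paddle.rand([cols]).astype('float16')

# After this patch there are no rows=/cols= keyword arguments.
out = fused_act_bias_wrapper(
    x=x,
    bias=bias,
    act_method='gelu',
    compute_dtype='fp16',
)
```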