From e457c29852df5fe418b4743fac4cd19b069fedcf Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Wed, 30 Aug 2023 13:41:13 +0800 Subject: [PATCH] [Prim][NewIR] Support prim all in new IR (#56614) * support prim all in new ir * process makefile * fix rule bug * polish case * fix flag * fix rules bug --- python/paddle/decomposition/decomp.py | 4 +- python/paddle/decomposition/rules.py | 5 +- .../incubate/autograd/composite_rules.py | 9 +- test/prim/new_ir_prim/CMakeLists.txt | 16 +++- test/prim/new_ir_prim/test_decomp_op.py | 3 + test/prim/new_ir_prim/test_prim_program.py | 94 +++++++++++++++++++ 6 files changed, 121 insertions(+), 10 deletions(-) create mode 100644 test/prim/new_ir_prim/test_prim_program.py diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index 9bd288dacd4..940249d0ae5 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -27,7 +27,7 @@ def _build_tensor_tuple(xs): return (xs,) elif isinstance(xs, typing.Sequence): return tuple(xs) - return TypeError(f"Type {type(xs)} is not supported") + return TypeError(f"Type {type(xs)} is not supported.") def _prepare_python_api_arguments(op): @@ -125,6 +125,8 @@ def decompose( Returns: dst_vars (list): A list contains all vars which replace origin ones in src_vars. """ + if not core._is_fwd_prim_enabled(): + return src_vars if not isinstance(program, Program): raise TypeError(f"Expect type Program, but got type {type(program)}.") block = program.block() diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py index ec8959cc960..4184dbdbea6 100644 --- a/python/paddle/decomposition/rules.py +++ b/python/paddle/decomposition/rules.py @@ -20,8 +20,9 @@ from .register import register_decomp def mean(x, axis, keepdim): """define composite rule of op mean""" x_shape = x.shape - axes = axis or tuple(range(0, len(x_shape))) - axes = (axes,) if isinstance(axes, int) else axes + if axis in (None, []): + axis = tuple(range(0, len(x_shape))) + axes = (axis,) if isinstance(axis, int) else axis sum_x = sum(x, axis=axes, keepdim=keepdim) value_to_fill = 1 for axis in axes: diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index caedc31a3c1..1ca8ab62a83 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -171,11 +171,11 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis): out = difference * rsqrt_var if scale is not None: - if x.shape[begin_norm_axis:] is not scale.shape: + if x.shape[begin_norm_axis:] != scale.shape: scale = reshape(scale, x.shape[begin_norm_axis:]) out = out * scale if bias is not None: - if x.shape[begin_norm_axis:] is not bias.shape: + if x.shape[begin_norm_axis:] != bias.shape: bias = reshape(bias, x.shape[begin_norm_axis:]) out = out + bias @@ -266,8 +266,9 @@ def mean_composite(x, axis, keepdim): is_amp = True x = cast(x, "float32") - axes = axis or list(range(0, len(x.shape))) - axes = [axes] if isinstance(axes, int) else axes + if axis in (None, []): + axis = tuple(range(0, len(x.shape))) + axes = (axis,) if isinstance(axis, int) else axis sum_x = sum(x, axis=axes, keepdim=keepdim) ele_nums_list = [x.shape[axis] for axis in axes] if ele_nums_list == []: diff --git a/test/prim/new_ir_prim/CMakeLists.txt b/test/prim/new_ir_prim/CMakeLists.txt index 393bc869d9b..85611d846cb 100644 --- a/test/prim/new_ir_prim/CMakeLists.txt +++ b/test/prim/new_ir_prim/CMakeLists.txt @@ -1,10 +1,20 @@ +set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program) + +foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES}) + py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1 + FLAGS_enable_new_ir_api=true) +endforeach() + file( - GLOB TEST_INTERP_CASES + GLOB TEST_PRIM_TRANS_NEW_IR_CASES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") -string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") +string(REPLACE ".py" "" TEST_PRIM_TRANS_NEW_IR_CASES + "${TEST_PRIM_TRANS_NEW_IR_CASES}") + +list(REMOVE_ITEM TEST_PRIM_TRANS_NEW_IR_CASES ${TEST_PRIM_PURE_NEW_IR_CASES}) -foreach(target ${TEST_INTERP_CASES}) +foreach(target ${TEST_PRIM_TRANS_NEW_IR_CASES}) py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1 FLAGS_enable_new_ir_in_executor=true) endforeach() diff --git a/test/prim/new_ir_prim/test_decomp_op.py b/test/prim/new_ir_prim/test_decomp_op.py index f90e0fe2439..413008f814f 100644 --- a/test/prim/new_ir_prim/test_decomp_op.py +++ b/test/prim/new_ir_prim/test_decomp_op.py @@ -17,6 +17,7 @@ import unittest import paddle from paddle import ir from paddle.decomposition import decompose +from paddle.framework import core paddle.enable_static() @@ -44,7 +45,9 @@ class TestBuildOp(unittest.TestCase): y = newir_program.block().ops[-2].results() orig_shape = y[0].shape paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + core._set_prim_forward_enabled(True) y_new = decompose(newir_program, y) + core._set_prim_forward_enabled(False) new_shape = y_new[0].shape assert ( orig_shape == new_shape diff --git a/test/prim/new_ir_prim/test_prim_program.py b/test/prim/new_ir_prim/test_prim_program.py new file mode 100644 index 00000000000..594f65baa9b --- /dev/null +++ b/test/prim/new_ir_prim/test_prim_program.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.autograd.backward import grad +from paddle.decomposition import decompose +from paddle.framework import core + +paddle.enable_static() + + +class TestPrimMode(unittest.TestCase): + def setUp(self): + np.random.seed(2023) + self.shape_x = [8, 16, 32, 64] + self.shape_y = [8, 16, 32, 64] + self.x = np.random.random(self.shape_x).astype("float32") + self.y = np.random.random(self.shape_y).astype("float32") + + def base_net(self, flag=None): + if flag == "forward": + core._set_prim_forward_enabled(True) + elif flag == "backward": + core._set_prim_backward_enabled(True) + elif flag == "all": + core._set_prim_all_enabled(True) + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data('x', self.shape_x, dtype='float32') + y = paddle.static.data('y', self.shape_y, dtype='float32') + x.stop_gradient = False + y.stop_gradient = False + divide_out = paddle.divide(x, y) + sum_out = paddle.mean(divide_out, axis=0) + [new_out] = decompose(main_program, [sum_out]) + gradients = grad(new_out, (x, y)) + + exe = paddle.static.Executor() + [fwd, dx, dy] = exe.run( + feed={'x': self.x, 'y': self.y}, fetch_list=[new_out, gradients] + ) + + whole_ops = [op.name() for op in main_program.block().ops] + if flag == "forward": + core._set_prim_forward_enabled(False) + assert 'pd.mean' not in whole_ops and 'pd.divide_grad' in whole_ops + elif flag == "backward": + core._set_prim_backward_enabled(False) + assert 'pd.mean' in whole_ops and 'pd.divide_grad' not in whole_ops + elif flag == "all": + core._set_prim_all_enabled(False) + assert ( + 'pd.mean' not in whole_ops and 'pd.divide_grad' not in whole_ops + ) + else: + assert 'pd.mean' in whole_ops and 'pd.divide_grad' in whole_ops + return fwd, dx, dy + + def test_prim_forward(self): + res_ref = self.base_net() + res = self.base_net("forward") + for ref, actual in zip(res_ref, res): + np.testing.assert_equal(ref, actual) + + def test_prim_backward(self): + res_ref = self.base_net() + res = self.base_net("backward") + for ref, actual in zip(res_ref, res): + np.testing.assert_allclose(ref, actual, rtol=1e-6) + + def test_prim_all(self): + res_ref = self.base_net() + res = self.base_net("all") + for ref, actual in zip(res_ref, res): + np.testing.assert_allclose(ref, actual, rtol=1e-6) + + +if __name__ == "__main__": + unittest.main() -- GitLab