From 3eafa1fc54cf57713e58b6d694826ab9c1868ca4 Mon Sep 17 00:00:00 2001
From: Xianduo Li <30922914+lxd-cumt@users.noreply.github.com>
Date: Wed, 6 Sep 2023 14:36:36 +0800
Subject: [PATCH] Auto codegen for supporting calling new_ir api in static operants (#56955)

* support new ir primitive operator in static operants
* support more vjp code gen
* support more vjp code gen
* support more vjp code gen
* use code gen
* fix operants codegen
* support more vjp code gen
* Fix ci build error
* set FLAGS_tensor_operants_mode to static in generated_vjp for testing
* fix bugs
* change the order of ops_name of divide_grad
* replace FLAGS_enable_new_ir_in_executor by FLAGS_enable_new_ir_api in codegen and test_vjp_prim

---------

Co-authored-by: Charles-hit
Co-authored-by: 0x45f
---
 .../tensor_operants_gen.py                    | 82 +++++++++++++++----
 paddle/fluid/prim/utils/static/CMakeLists.txt |  2 +-
 .../rule/vjp/generated/generated_vjp.cc.j2    |  3 +
 paddle/fluid/primitive/rule/vjp/details.h     |  7 +-
 paddle/phi/core/extended_tensor.cc            |  2 +-
 test/prim/new_ir_prim/test_vjp_prim.py        |  6 +-
 6 files changed, 80 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py b/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py
index 8322b0ba2be..783066f0fc9 100644
--- a/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py
+++ b/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py
@@ -211,6 +211,11 @@ static_source_include = """// Generated by paddle/fluid/prim/api/auto_code_gener
 #include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
 #include "paddle/fluid/prim/utils/static/desc_tensor.h"
 
+#include "paddle/fluid/primitive/backend/backend.h"
+#include "paddle/fluid/primitive/type/lazy_tensor.h"
+
+PHI_DECLARE_bool(enable_new_ir_api);
+
 """
 
 
@@ -219,47 +224,88 @@ namespace paddle {
 namespace prim {
 
 using DescTensor = paddle::prim::DescTensor;
+using LazyTensor = paddle::primitive::LazyTensor;
 
 Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
-  return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::add<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
+    return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }
 
 Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
-  return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::subtract<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
+    return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }
 
 Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
-  return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::scale<LazyTensor>(x, y, 0.0f, true);
+  } else {
+    return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
+  }
 }
 
 Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
-  return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::divide<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
+    return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }
 
 Tensor StaticTensorOperants::add(const Scalar& x, const Tensor& y) {
-  return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::add<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
+    return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }
+
 Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
-  return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::subtract<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
+    return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }
 
 Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
-  return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::scale<LazyTensor>(y, x, 0.0f, true);
+  } else {
+    return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
+  }
 }
 
 Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
-  return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::divide<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
+    return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }
 
 Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
-  return paddle::prim::elementwise_pow<DescTensor>(x, y);
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, y);
+  } else {
+    return paddle::prim::elementwise_pow<DescTensor>(x, y);
+  }
 }
 
 Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) {
-  return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
+    return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }
-
 """
 
 
@@ -339,13 +385,21 @@ class PrimTensorAPI(BaseAPI):
 
     def gene_static_tensor_func_call(self):
         api_func_name = self.get_api_func_name()
-
+        backend_static_func_name = (
+            'paddle::primitive::backend::' + api_func_name + '<LazyTensor>'
+        )
         prim_static_func_name = (
             'paddle::prim::' + api_func_name + '<DescTensor>'
         )
-        prim_static_func_parameters = self.get_func_args()
+        static_func_parameters = self.get_func_args()
+
+        static_tensor_func_call = f"""if (FLAGS_enable_new_ir_api) {{
+    return {backend_static_func_name}({static_func_parameters});
+  }} else {{
+    return {prim_static_func_name}({static_func_parameters});
+  }}"""
 
-        return f"""return {prim_static_func_name}({prim_static_func_parameters});"""
+        return static_tensor_func_call
 
     def gene_static_tensor_operants_implementation(self):
         api_code = ""
diff --git a/paddle/fluid/prim/utils/static/CMakeLists.txt b/paddle/fluid/prim/utils/static/CMakeLists.txt
index aa72fadb591..483c3eabc05 100644
--- a/paddle/fluid/prim/utils/static/CMakeLists.txt
+++ b/paddle/fluid/prim/utils/static/CMakeLists.txt
@@ -6,4 +6,4 @@ cc_library(
 cc_library(
   static_tensor_operants
   SRCS static_tensor_operants.cc
-  DEPS static_prim_api)
+  DEPS static_prim_api primitive_backend_static_experimental)
diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2
index ab040254355..6d694337376 100644
--- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2
+++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2
@@ -10,7 +10,9 @@
 #include "paddle/fluid/primitive/type/lazy_tensor.h"
 #include "paddle/fluid/primitive/utils/utils.h"
 #include "paddle/ir/core/operation.h"
+#include "paddle/phi/core/flags.h"
 
+PHI_DECLARE_string(tensor_operants_mode);
 
 namespace paddle {
 namespace primitive {
@@ -95,6 +97,7 @@ for (size_t i=0; i< stop_gradients[0].size(); i++ ) {
 {% endmacro %}
 
 {% macro body_prim(api) %}
+FLAGS_tensor_operants_mode = "static";
 {% for i in range(api.outputs|length) %}
   {% if api.outputs[i].typename=='Tensor' %}
   paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{{i}}][0] : nullptr;
diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h
index e018cccdef7..12fb66127a2 100644
--- a/paddle/fluid/primitive/rule/vjp/details.h
+++ b/paddle/fluid/primitive/rule/vjp/details.h
@@ -39,10 +39,7 @@ void divide_grad(const Tensor& x,
                  Tensor* dy) {
   if (dy) {
     // dy = -(x/y^2) * dout
-    auto denominator =
-        elementwise_pow(y, full(y.shape(), 2.0, y.dtype(), y.place()));
-    auto dy_res = scale(
-        multiply(divide(x, denominator), out_grad), -1.0, 0.0, true);
+    auto dy_res = -(x / y.pow(2.0)) * out_grad;
     if (x.dims() != y.dims()) {
       // Maybe need reduce here
       phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
@@ -61,7 +58,7 @@ void divide_grad(const Tensor& x,
   if (dx) {
     // dx = (1/y) * dout
     auto one_tensor = full(phi::vectorize(y.dims()), 1.0, y.dtype());
-    auto dx_res = multiply(divide(one_tensor, y), out_grad);
+    auto dx_res = one_tensor / y * out_grad;
     if (y.dims() != x.dims()) {
       // Maybe need reduce here
       auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
diff --git a/paddle/phi/core/extended_tensor.cc b/paddle/phi/core/extended_tensor.cc
index e5b5c3773f8..31d0fb25c88 100644
--- a/paddle/phi/core/extended_tensor.cc
+++ b/paddle/phi/core/extended_tensor.cc
@@ -38,7 +38,7 @@ DataType ExtendedTensor::dtype() const {
 
 DataLayout ExtendedTensor::layout() const {
   PADDLE_THROW(phi::errors::Unavailable(
-      "ExtendedTensor does not support `dtype` method."));
+      "ExtendedTensor does not support `layout` method."));
 }
 
 bool ExtendedTensor::valid() const {
diff --git a/test/prim/new_ir_prim/test_vjp_prim.py b/test/prim/new_ir_prim/test_vjp_prim.py
index 2a29ae9f69f..22309a08823 100644
--- a/test/prim/new_ir_prim/test_vjp_prim.py
+++ b/test/prim/new_ir_prim/test_vjp_prim.py
@@ -63,6 +63,7 @@ class TestVjpPrim(unittest.TestCase):
     def test_divide_grad_prim_case1(self):
         newir_program = get_ir_divide_program()
         paddle.framework.core._set_prim_backward_enabled(True)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [False]]
@@ -83,9 +84,9 @@ class TestVjpPrim(unittest.TestCase):
             "pd.full",
             "pd.elementwise_pow",
             "pd.divide",
-            "pd.multiply",
             "pd.full",
             "pd.scale",
+            "pd.multiply",
             "pd.full_int_array",
             "pd.sum",
             "pd.full_int_array",
@@ -101,6 +102,7 @@ class TestVjpPrim(unittest.TestCase):
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
         paddle.framework.core._set_prim_backward_enabled(False)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
 
     def test_divide_grad_no_prim(self):
         newir_program = get_ir_divide_program()
@@ -123,6 +125,7 @@ class TestVjpPrim(unittest.TestCase):
     def test_sum_grad_prim(self):
         newir_program = get_ir_sum_program()
         paddle.framework.core._set_prim_backward_enabled(True)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
         dout = newir_program.block().ops[-3].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [True]]
@@ -147,6 +150,7 @@ class TestVjpPrim(unittest.TestCase):
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
         paddle.framework.core._set_prim_backward_enabled(False)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
 
     def test_sum_grad_no_prim(self):
         newir_program = get_ir_sum_program()
-- 
GitLab
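
For quick reference, a minimal standalone sketch (not part of the patch) of the flag-gated
dispatch that gene_static_tensor_func_call emits for each operant: the helper name
gene_flag_gated_call and its two string parameters are simplifications of the real
PrimTensorAPI method, which takes them from the parsed API description instead.

# Minimal sketch, assuming a free function in place of PrimTensorAPI.gene_static_tensor_func_call.
# It only rebuilds the generated C++ body as a string, mirroring the codegen change above.
def gene_flag_gated_call(api_func_name: str, func_args: str) -> str:
    backend_func = 'paddle::primitive::backend::' + api_func_name + '<LazyTensor>'
    prim_func = 'paddle::prim::' + api_func_name + '<DescTensor>'
    return f"""if (FLAGS_enable_new_ir_api) {{
    return {backend_func}({func_args});
  }} else {{
    return {prim_func}({func_args});
  }}"""

if __name__ == "__main__":
    # Prints the dispatch body generated for `pow`: route to the new IR primitive
    # backend (LazyTensor) when FLAGS_enable_new_ir_api is set, otherwise to the
    # legacy static prim API (DescTensor).
    print(gene_flag_gated_call("pow", "x, y"))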