Unverified commit 3eafa1fc, authored by Xianduo Li, committed by GitHub

Auto codegen to support calling the new IR API in static operants (#56955)

* support new ir primitive operator in static operants

* support more vjp code gen

* support more vjp code gen

* support more vjp code gen

* use code gen

* fix operants codegen

* support more vjp code gen

* Fix ci build error

* set FLAGS_tensor_operants_mode to static in generated_vjp for testing

* fix bugs

* change the order of ops_name of divide_grad

* replace FLAGS_enable_new_ir_in_executor by FLAGS_enable_new_ir_api in codegen and test_vjp_prim

---------
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Parent c62902ee
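In short: the auto-generated StaticTensorOperants methods now branch at runtime on FLAGS_enable_new_ir_api, lowering to the new IR primitive backend (LazyTensor) when the flag is on and falling back to the legacy prim API (DescTensor) otherwise. A condensed sketch of the pattern, mirroring the generated code in the diff below (not a verbatim copy):

// Dispatch pattern emitted by the codegen for each tensor operant.
Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
  if (FLAGS_enable_new_ir_api) {
    // New IR path: build the op on LazyTensor through the primitive backend.
    return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, y);
  } else {
    // Legacy path: build the op on DescTensor through the prim API.
    return paddle::prim::elementwise_pow<DescTensor>(x, y);
  }
}

The Scalar overloads follow the same shape, first materializing the scalar with full<LazyTensor> or full<DescTensor> (or using scale for multiply), as shown in the generated code below.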
@@ -211,6 +211,11 @@ static_source_include = """// Generated by paddle/fluid/prim/api/auto_code_gener
 #include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
 #include "paddle/fluid/prim/utils/static/desc_tensor.h"
+#include "paddle/fluid/primitive/backend/backend.h"
+#include "paddle/fluid/primitive/type/lazy_tensor.h"
+
+PHI_DECLARE_bool(enable_new_ir_api);
 """
@@ -219,47 +224,88 @@ namespace paddle {
 namespace prim {

 using DescTensor = paddle::prim::DescTensor;
+using LazyTensor = paddle::primitive::LazyTensor;

 Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::add<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
     return paddle::prim::add<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }

 Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::subtract<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
     return paddle::prim::subtract<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }

 Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::scale<LazyTensor>(x, y, 0.0f, true);
+  } else {
     return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
+  }
 }

 Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::divide<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
     return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }

 Tensor StaticTensorOperants::add(const Scalar& x, const Tensor& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::add<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
     return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }

 Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::subtract<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
     return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }

 Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::scale<LazyTensor>(y, x, 0.0f, true);
+  } else {
     return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
+  }
 }

 Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::divide<LazyTensor>(paddle::primitive::backend::full<LazyTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  } else {
     return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
+  }
 }

 Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, y);
+  } else {
     return paddle::prim::elementwise_pow<DescTensor>(x, y);
+  }
 }

 Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) {
+  if (FLAGS_enable_new_ir_api) {
+    return paddle::primitive::backend::elementwise_pow<LazyTensor>(x, paddle::primitive::backend::full<LazyTensor>(x.shape(), y, x.dtype(), x.place()));
+  } else {
     return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+  }
 }
 """
@@ -339,13 +385,21 @@ class PrimTensorAPI(BaseAPI):

     def gene_static_tensor_func_call(self):
         api_func_name = self.get_api_func_name()
+        backend_static_func_name = (
+            'paddle::primitive::backend::' + api_func_name + '<LazyTensor>'
+        )
         prim_static_func_name = (
             'paddle::prim::' + api_func_name + '<DescTensor>'
         )
-        prim_static_func_parameters = self.get_func_args()
+        static_func_parameters = self.get_func_args()

+        static_tensor_func_call = f"""if (FLAGS_enable_new_ir_api) {{
+    return {backend_static_func_name}({static_func_parameters});
+  }} else {{
+    return {prim_static_func_name}({static_func_parameters});
+  }}"""

-        return f"""return {prim_static_func_name}({prim_static_func_parameters});"""
+        return static_tensor_func_call

     def gene_static_tensor_operants_implementation(self):
         api_code = ""
......
@@ -6,4 +6,4 @@ cc_library(
 cc_library(
   static_tensor_operants
   SRCS static_tensor_operants.cc
-  DEPS static_prim_api)
+  DEPS static_prim_api primitive_backend_static_experimental)
@@ -10,7 +10,9 @@
 #include "paddle/fluid/primitive/type/lazy_tensor.h"
 #include "paddle/fluid/primitive/utils/utils.h"
 #include "paddle/ir/core/operation.h"
+#include "paddle/phi/core/flags.h"
+
+PHI_DECLARE_string(tensor_operants_mode);

 namespace paddle {
 namespace primitive {
@@ -95,6 +97,7 @@ for (size_t i=0; i< stop_gradients[0].size(); i++ ) {
 {% endmacro %}

 {% macro body_prim(api) %}
+  FLAGS_tensor_operants_mode = "static";
 {% for i in range(api.outputs|length) %}
   {% if api.outputs[i].typename=='Tensor' %}
   paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{{i}}][0] : nullptr;
......
@@ -39,10 +39,7 @@ void divide_grad(const Tensor& x,
                  Tensor* dy) {
   if (dy) {
     // dy = -(x/y^2) * dout
-    auto denominator =
-        elementwise_pow<T>(y, full<T>(y.shape(), 2.0, y.dtype(), y.place()));
-    auto dy_res = scale<T>(
-        multiply<T>(divide<T>(x, denominator), out_grad), -1.0, 0.0, true);
+    auto dy_res = -(x / y.pow(2.0)) * out_grad;
     if (x.dims() != y.dims()) {
       // Maybe need reduce here
       phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
@@ -61,7 +58,7 @@ void divide_grad(const Tensor& x,
   if (dx) {
     // dx = (1/y) * dout
     auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
-    auto dx_res = multiply<T>(divide<T>(one_tensor, y), out_grad);
+    auto dx_res = one_tensor / y * out_grad;
     if (y.dims() != x.dims()) {
       // Maybe need reduce here
       auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
......
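For reference, the simplified divide_grad computes the standard quotient-rule gradients stated in the inline comments: for out = x / y,

  dx = (1 / y) * out_grad
  dy = -(x / y^2) * out_grad

followed by a reduce over broadcast dimensions when x.dims() != y.dims(). The operator forms -(x / y.pow(2.0)) * out_grad and one_tensor / y * out_grad dispatch through the tensor operants, which is presumably why the generated VJP body above sets FLAGS_tensor_operants_mode to "static".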
@@ -38,7 +38,7 @@ DataType ExtendedTensor::dtype() const {

 DataLayout ExtendedTensor::layout() const {
   PADDLE_THROW(phi::errors::Unavailable(
-      "ExtendedTensor does not support `dtype` method."));
+      "ExtendedTensor does not support `layout` method."));
 }

 bool ExtendedTensor::valid() const {
......
@@ -63,6 +63,7 @@ class TestVjpPrim(unittest.TestCase):
     def test_divide_grad_prim_case1(self):
         newir_program = get_ir_divide_program()
         paddle.framework.core._set_prim_backward_enabled(True)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [False]]
@@ -83,9 +84,9 @@ class TestVjpPrim(unittest.TestCase):
             "pd.full",
             "pd.elementwise_pow",
             "pd.divide",
-            "pd.multiply",
             "pd.full",
             "pd.scale",
+            "pd.multiply",
             "pd.full_int_array",
             "pd.sum",
             "pd.full_int_array",
@@ -101,6 +102,7 @@ class TestVjpPrim(unittest.TestCase):
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
         paddle.framework.core._set_prim_backward_enabled(False)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

     def test_divide_grad_no_prim(self):
         newir_program = get_ir_divide_program()
@@ -123,6 +125,7 @@ class TestVjpPrim(unittest.TestCase):
     def test_sum_grad_prim(self):
         newir_program = get_ir_sum_program()
         paddle.framework.core._set_prim_backward_enabled(True)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
         dout = newir_program.block().ops[-3].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [True]]
@@ -147,6 +150,7 @@ class TestVjpPrim(unittest.TestCase):
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
         paddle.framework.core._set_prim_backward_enabled(False)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

     def test_sum_grad_no_prim(self):
         newir_program = get_ir_sum_program()
......