From 2b5466f9944c5ddd78ff8e7191b7169c9e131088 Mon Sep 17 00:00:00 2001
From: Charles-hit <56987902+Charles-hit@users.noreply.github.com>
Date: Fri, 18 Aug 2023 14:43:44 +0800
Subject: [PATCH] [PRIM][IR]Support prim in new ir (#56342)

* support ir api for prim
* convert vector of int to intarray
* support ir api for prim
* support vjp prim mode in new ir
* remove useless code
* add test for prim
* modify utils
* remove useless code

---------

Co-authored-by: cyber-pioneer
---
 paddle/fluid/framework/type_info.cc           |   5 +-
 .../fluid/ir/dialect/op_generator/api_gen.py  |   2 +
 .../fluid/ir/dialect/op_generator/op_gen.py   |  12 +-
 .../dialect/op_generator/op_interface_gen.py  |  14 +-
 .../op_generator/vjp_interface_gen_op_list.py |   3 +-
 paddle/fluid/ir/dialect/pd_op_vjp_manual.cc   |  31 ++-
 paddle/fluid/primitive/CMakeLists.txt         |   1 +
 .../fluid/primitive/backend/eager_backend.cc  |   4 +-
 .../fluid/primitive/backend/eager_backend.h   |   4 +-
 .../fluid/primitive/backend/static_backend.cc | 213 +++++++++++++++---
 .../fluid/primitive/backend/static_backend.h  |  56 ++++-
 paddle/fluid/primitive/primitive/primitive.h  |  63 +++++-
 .../fluid/primitive/rule/vjp/CMakeLists.txt   |   6 +-
 paddle/fluid/primitive/rule/vjp/details.h     | 137 +++++++++++
 paddle/fluid/primitive/rule/vjp/vjp.cc        |  90 ++++++--
 paddle/fluid/primitive/rule/vjp/vjp.h         |  28 +--
 .../type/{static_tensor.h => lazy_tensor.h}   |  16 +-
 paddle/fluid/primitive/utils/CMakeLists.txt   |  10 +
 paddle/fluid/primitive/utils/eager_utils.cc   |  26 +++
 paddle/fluid/primitive/utils/static_utils.cc  |  25 ++
 paddle/fluid/primitive/utils/utils.h          |  91 ++++++++
 test/prim/new_ir_prim/test_vjp_prim.py        | 167 ++++++++++++++
 22 files changed, 903 insertions(+), 101 deletions(-)
 create mode 100644 paddle/fluid/primitive/rule/vjp/details.h
 rename paddle/fluid/primitive/type/{static_tensor.h => lazy_tensor.h} (77%)
 create mode 100644 paddle/fluid/primitive/utils/CMakeLists.txt
 create mode 100644 paddle/fluid/primitive/utils/eager_utils.cc
 create mode 100644 paddle/fluid/primitive/utils/static_utils.cc
 create mode 100644 paddle/fluid/primitive/utils/utils.h
 create mode 100644 test/prim/new_ir_prim/test_vjp_prim.py

diff --git a/paddle/fluid/framework/type_info.cc b/paddle/fluid/framework/type_info.cc
index 8ab4dc3cc47..442800d035f 100644
--- a/paddle/fluid/framework/type_info.cc
+++ b/paddle/fluid/framework/type_info.cc
@@ -17,7 +17,7 @@ limitations under the License.
*/ #include "paddle/fluid/framework/raw_tensor.h" #include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" -#include "paddle/fluid/primitive/type/static_tensor.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" namespace phi { @@ -41,8 +41,7 @@ template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; -template class TypeInfoTraits; +template class TypeInfoTraits; template class TypeInfoTraits; diff --git a/paddle/fluid/ir/dialect/op_generator/api_gen.py b/paddle/fluid/ir/dialect/op_generator/api_gen.py index 7680ddfb122..6e313267059 100644 --- a/paddle/fluid/ir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/api_gen.py @@ -95,6 +95,8 @@ API_LIST = [ 'expand', 'tile', 'add_grad', + 'divide_grad', + 'sum_grad', ] OP_RESULT = 'ir::OpResult' VECTOR_TYPE = 'ir::VectorType' diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 4d0f28ea075..da1d7cbdde0 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -21,10 +21,13 @@ from op_interface_gen import ( gen_exclusive_interface_str, gen_op_infer_meta_str, gen_op_vjp_str, - vjp_interface_gen_op_list, ) from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str +from vjp_interface_gen_op_list import ( + vjp_interface_declare_gen_op_list, + vjp_interface_implementation_gen_op_list, +) # ===================================== # String Template for h file code gen @@ -112,7 +115,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g #include "paddle/phi/infermeta/backward.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" -#include "paddle/fluid/primitive/type/static_tensor.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" #include "paddle/ir/core/op_base.h" {input} @@ -756,7 +759,7 @@ def OpGenerator( if ( op_info.backward_name - and op_info.op_phi_name[0] in vjp_interface_gen_op_list + and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list ): op_interfaces += ["VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str(op_info) @@ -1055,7 +1058,8 @@ def OpGenerator( # TODO(chenzhiyang) add vjp gen code if ( op_info.backward_name - and op_info.op_phi_name[0] in vjp_interface_gen_op_list + and op_info.op_phi_name[0] + in vjp_interface_implementation_gen_op_list ): op_vjp_str = gen_op_vjp_str( op_class_name, diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py index 7c04fa14033..8762c6328e1 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py @@ -13,7 +13,7 @@ # limitations under the License. 
# generator interfaces -from vjp_interface_gen_op_list import vjp_interface_gen_op_list +from vjp_interface_gen_op_list import vjp_interface_declare_gen_op_list OP_INFER_SHAPE_TEMPLATE = """ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ @@ -23,13 +23,13 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ """ OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """ - {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" + {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """ - Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" + Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ - std::vector {output_grad_name}(std::make_shared(out_grads[{idx1}]));""" + std::vector {output_grad_name}(std::make_shared(out_grads[{idx1}]));""" OP_VJP_ATTRIBUTE_TEMPLATE = """ {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();""" @@ -39,7 +39,7 @@ OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """ OP_VJP_CALL_VJP_TEMPLATE = """ std::vector> tensor_res = - primitive::experimental::{op_phi_name}_vjp( + primitive::{op_phi_name}_vjp( {inputs_list}stop_gradients);""" OP_VJP_STOPGRADIENT_TEMPLATE = """ @@ -48,7 +48,7 @@ OP_VJP_STOPGRADIENT_TEMPLATE = """ res[i].resize(tensor_res[i].size()); for (size_t j = 0; j < tensor_res[i].size(); ++j) {{ if(tensor_res[i][j].defined()){{ - res[i][j] = std::static_pointer_cast(tensor_res[i][j].impl())->getValue().dyn_cast(); + res[i][j] = std::static_pointer_cast(tensor_res[i][j].impl())->getValue().dyn_cast(); }} }} }}""" @@ -166,6 +166,6 @@ def gen_exclusive_interface_str(op_info): exclusive_interface_str += ( " static void InferMeta( phi::InferMetaContext *infer_meta );" ) - if op_info.op_phi_name[0] in vjp_interface_gen_op_list: + if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list: exclusive_interface_str += "\n static std::vector> Vjp(ir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py index 6d788eb85a5..fd7d61897d8 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -21,4 +21,5 @@ # TODO(wanghao107) # remove this file and support Vjp methods # code gen. 
-vjp_interface_gen_op_list = ["tanh", "mean", "add"] +vjp_interface_declare_gen_op_list = ["tanh", "mean", "divide", "sum", "add"] +vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"] diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index 95eb4678607..b41cbdab519 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/fluid/ir/dialect/pd_op.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" -#include "paddle/fluid/primitive/type/static_tensor.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" #include "paddle/ir/core/op_base.h" #include "paddle/phi/common/int_array.h" @@ -23,5 +23,32 @@ // this file will be generated in pd_op.cc namespace paddle { -namespace dialect {} // namespace dialect +namespace dialect { +std::vector> SumOp::Vjp( + ir::Operation* op, + const std::vector>& out_grads, + const std::vector>& stop_gradients) { + SumOp op_obj = op->dyn_cast(); + Tensor x(std::make_shared(op_obj.x())); + Tensor out_grad(std::make_shared(out_grads[0][0])); + + IntArray axis = op_obj.axis() + .GetDefiningOp() + ->attribute("value") + .dyn_cast() + .data(); + bool keepdim = op->attribute("keepdim").dyn_cast().data(); + bool reduce_all = false; + std::vector> tensor_res = primitive::sum_vjp( + x, out_grad, axis, keepdim, reduce_all, stop_gradients); + std::vector> res(1, std::vector(1)); + if (tensor_res[0][0].defined()) { + res[0][0] = + std::static_pointer_cast(tensor_res[0][0].impl()) + ->getValue() + .dyn_cast(); + } + return res; +} +} // namespace dialect } // namespace paddle diff --git a/paddle/fluid/primitive/CMakeLists.txt b/paddle/fluid/primitive/CMakeLists.txt index 5134cb01349..aab7919dfe4 100644 --- a/paddle/fluid/primitive/CMakeLists.txt +++ b/paddle/fluid/primitive/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(utils) add_subdirectory(backend) add_subdirectory(rule) diff --git a/paddle/fluid/primitive/backend/eager_backend.cc b/paddle/fluid/primitive/backend/eager_backend.cc index 5c06c0143f6..ca2184c49a6 100644 --- a/paddle/fluid/primitive/backend/eager_backend.cc +++ b/paddle/fluid/primitive/backend/eager_backend.cc @@ -19,8 +19,6 @@ namespace paddle { namespace primitive { -namespace backend { -namespace experimental {} // namespace experimental -} // namespace backend +namespace backend {} // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/eager_backend.h b/paddle/fluid/primitive/backend/eager_backend.h index 1522bd1dfc3..094487bb2b1 100644 --- a/paddle/fluid/primitive/backend/eager_backend.h +++ b/paddle/fluid/primitive/backend/eager_backend.h @@ -21,8 +21,6 @@ namespace paddle { namespace primitive { -namespace backend { -namespace experimental {} // namespace experimental -} // namespace backend +namespace backend {} // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc index 7eaeb599326..539237c2277 100644 --- a/paddle/fluid/primitive/backend/static_backend.cc +++ b/paddle/fluid/primitive/backend/static_backend.cc @@ -15,65 +15,177 @@ #include "paddle/fluid/primitive/backend/static_backend.h" #include "paddle/fluid/ir/dialect/pd_api.h" #include "paddle/fluid/primitive/primitive/primitive.h" -#include "paddle/fluid/primitive/type/static_tensor.h" +#include 
"paddle/fluid/primitive/type/lazy_tensor.h" namespace paddle { namespace primitive { namespace backend { -namespace experimental { -using StaticTensor = paddle::primitive::experimental::StaticTensor; +using LazyTensor = paddle::primitive::LazyTensor; template <> -Tensor tanh_grad(const Tensor& out, const Tensor& grad_out) { - ir::OpResult out_res = std::static_pointer_cast(out.impl()) +Tensor tanh_grad(const Tensor& out, const Tensor& grad_out) { + ir::OpResult out_res = std::static_pointer_cast(out.impl()) ->getValue() .dyn_cast(); ir::OpResult grad_out_res = - std::static_pointer_cast(grad_out.impl()) + std::static_pointer_cast(grad_out.impl()) ->getValue() .dyn_cast(); ir::OpResult op_res = paddle::dialect::tanh_grad(out_res, grad_out_res); - return Tensor( - std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> -Tensor mean_grad(const Tensor& x, - const Tensor& out_grad, - const IntArray& axis, - bool keepdim, - bool reduce_all) { - ir::OpResult x_res = std::static_pointer_cast(x.impl()) +Tensor mean_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) ->getValue() .dyn_cast(); ir::OpResult out_grad_res = - std::static_pointer_cast(out_grad.impl()) + std::static_pointer_cast(out_grad.impl()) ->getValue() .dyn_cast(); ir::OpResult op_res = paddle::dialect::mean_grad( x_res, out_grad_res, axis.GetData(), keepdim, reduce_all); - return Tensor( - std::make_shared(op_res)); + return Tensor(std::make_shared(op_res)); } template <> -std::tuple add_grad(const Tensor& x, - const Tensor& y, - const Tensor& out_grad, - int axis) { - ir::OpResult x_res = std::static_pointer_cast(x.impl()) +Tensor divide(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) ->getValue() .dyn_cast(); - ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::divide(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor add(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::add(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor multiply(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::multiply(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor elementwise_pow(const Tensor& x, const Tensor& y) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::elementwise_pow(x_res, y_res); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor scale(const Tensor& x, + const Scalar& scale, + float bias, + bool bias_after_scale) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = + paddle::dialect::scale(x_res, scale.to(), bias, bias_after_scale); + return Tensor(std::make_shared(op_res)); +} + 
+template <> +Tensor sum(const Tensor& x, + const IntArray& axis, + phi::DataType dtype, + bool keepdim) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = + paddle::dialect::sum(x_res, axis.GetData(), dtype, keepdim); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor full(const IntArray& shape, + const Scalar& value, + phi::DataType dtype, + phi::Place place) { + ir::OpResult op_res = + paddle::dialect::full(shape.GetData(), value.to(), dtype, place); + return Tensor(std::make_shared(op_res)); +} + +template <> +std::tuple reshape(const Tensor& x, + const IntArray& shape) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + std::tuple op_res = + paddle::dialect::reshape(x_res, shape.GetData()); + return std::make_tuple( + Tensor(std::make_shared(std::get<0>(op_res))), + Tensor(std::make_shared(std::get<1>(op_res)))); +} + +template <> +Tensor expand(const Tensor& x, const IntArray& shape) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::expand(x_res, shape.GetData()); + return Tensor(std::make_shared(op_res)); +} + +template <> +Tensor tile(const Tensor& x, const IntArray& repeat_times) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::tile(x_res, repeat_times.GetData()); + return Tensor(std::make_shared(op_res)); +} + +template <> +std::tuple add_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + int axis) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) ->getValue() .dyn_cast(); ir::OpResult out_grad_res = - std::static_pointer_cast(out_grad.impl()) + std::static_pointer_cast(out_grad.impl()) ->getValue() .dyn_cast(); @@ -81,12 +193,55 @@ std::tuple add_grad(const Tensor& x, paddle::dialect::add_grad(x_res, y_res, out_grad_res, axis); return std::make_tuple( - Tensor(std::make_shared( - std::get<0>(op_res))), - Tensor(std::make_shared( - std::get<1>(op_res)))); + Tensor(std::make_shared(std::get<0>(op_res))), + Tensor(std::make_shared(std::get<1>(op_res)))); +} + +template <> +std::tuple divide_grad(const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult y_res = std::static_pointer_cast(y.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult out_res = std::static_pointer_cast(out.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult out_grad_res = + std::static_pointer_cast(out_grad.impl()) + ->getValue() + .dyn_cast(); + + std::tuple op_res = + paddle::dialect::divide_grad(x_res, y_res, out_res, out_grad_res, axis); + + return std::make_tuple( + Tensor(std::make_shared(std::get<0>(op_res))), + Tensor(std::make_shared(std::get<1>(op_res)))); +} + +template <> +Tensor sum_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all) { + ir::OpResult x_res = std::static_pointer_cast(x.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult out_grad_res = + std::static_pointer_cast(out_grad.impl()) + ->getValue() + .dyn_cast(); + ir::OpResult op_res = paddle::dialect::sum_grad( + x_res, out_grad_res, axis.GetData(), keepdim, reduce_all); + return Tensor(std::make_shared(op_res)); } -} // namespace experimental } // 
namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h index 063532d8cd2..1e484aa35e6 100644 --- a/paddle/fluid/primitive/backend/static_backend.h +++ b/paddle/fluid/primitive/backend/static_backend.h @@ -23,7 +23,6 @@ namespace paddle { namespace primitive { namespace backend { -namespace experimental { using Tensor = paddle::Tensor; using IntArray = paddle::experimental::IntArray; @@ -43,7 +42,60 @@ std::tuple add_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis); -} // namespace experimental + +template +Tensor divide(const Tensor& x, const Tensor& y); + +template +Tensor add(const Tensor& x, const Tensor& y); + +template +Tensor multiply(const Tensor& x, const Tensor& y); + +template +Tensor elementwise_pow(const Tensor& x, const Tensor& y); + +template +Tensor scale(const Tensor& x, + const Scalar& scale = 1.0, + float bias = 0.0, + bool bias_after_scale = true); + +template +Tensor sum(const Tensor& x, + const IntArray& axis = {}, + phi::DataType dtype = phi::DataType::UNDEFINED, + bool keepdim = false); + +template +Tensor full(const IntArray& shape, + const Scalar& value, + phi::DataType dtype = phi::DataType::FLOAT32, + phi::Place place = phi::CPUPlace()); + +template +std::tuple reshape(const Tensor& x, const IntArray& shape); + +template +Tensor expand(const Tensor& x, const IntArray& shape); + +template +Tensor tile(const Tensor& x, const IntArray& repeat_times = {}); + +template +std::tuple divide_grad(const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis); + +template +Tensor sum_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all); + } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/primitive/primitive.h b/paddle/fluid/primitive/primitive/primitive.h index a15334851c8..80510ed921c 100644 --- a/paddle/fluid/primitive/primitive/primitive.h +++ b/paddle/fluid/primitive/primitive/primitive.h @@ -18,12 +18,71 @@ namespace paddle { namespace primitive { -namespace experimental { // why exist this file? // We provide this file to divide // the primitive ops set in the backend. // It will be called by the vjp composite // rules and composite ops rules. 
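// An illustrative sketch (not part of this patch), assuming the wrappers
// declared just below: a composite backward rule can be written once against
// the template parameter T and reused by both the eager backend and the
// static backend (where T is LazyTensor). The helper name divide_dx_sketch
// is hypothetical; it mirrors the dx branch of details::divide_grad that is
// added later in this patch.
template <typename T>
Tensor divide_dx_sketch(const Tensor& y, const Tensor& out_grad) {
  // dx = (1 / y) * dout, composed purely from the primitive wrappers.
  auto one = full<T>(y.shape(), 1.0, y.dtype(), y.place());
  return multiply<T>(divide<T>(one, y), out_grad);
}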
-} // namespace experimental +using Tensor = paddle::Tensor; +using IntArray = paddle::experimental::IntArray; + +template +Tensor divide(const Tensor& x, const Tensor& y) { + return backend::divide(x, y); +} + +template +Tensor add(const Tensor& x, const Tensor& y) { + return backend::add(x, y); +} + +template +Tensor multiply(const Tensor& x, const Tensor& y) { + return backend::multiply(x, y); +} + +template +Tensor elementwise_pow(const Tensor& x, const Tensor& y) { + return backend::elementwise_pow(x, y); +} + +template +Tensor scale(const Tensor& x, + const Scalar& scale = 1.0, + float bias = 0.0, + bool bias_after_scale = true) { + return backend::scale(x, scale, bias, bias_after_scale); +} + +template +Tensor sum(const Tensor& x, + const IntArray& axis = {}, + phi::DataType dtype = phi::DataType::UNDEFINED, + bool keepdim = false) { + return backend::sum(x, axis, dtype, keepdim); +} + +template +Tensor full(const IntArray& shape, + const Scalar& value, + phi::DataType dtype = phi::DataType::FLOAT32, + phi::Place place = phi::CPUPlace()) { + return backend::full(shape, value, dtype, place); +} + +template +std::tuple reshape(const Tensor& x, const IntArray& shape) { + return backend::reshape(x, shape); +} + +template +Tensor expand(const Tensor& x, const IntArray& shape) { + return backend::expand(x, shape); +} + +template +Tensor tile(const Tensor& x, const IntArray& repeat_times = {}) { + return backend::tile(x, repeat_times); +} } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt index eb72b0c9ecc..3243228d112 100644 --- a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt +++ b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt @@ -1,7 +1,7 @@ -file(GLOB VJP_SRCS "*.cc") - +file(GLOB VJP_SRCS "vjp.cc") cc_library( primitive_vjp_experimental SRCS ${VJP_SRCS} - DEPS primitive_backend_static_experimental) + DEPS primitive_backend_static_experimental static_global_utils + primitive_static_utils_experimental) add_dependencies(primitive_vjp_experimental pd_dialect) diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h new file mode 100644 index 00000000000..6ee9c5880b6 --- /dev/null +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -0,0 +1,137 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif + +#include +#include + +#include "paddle/fluid/primitive/primitive/primitive.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" + +namespace paddle { +namespace primitive { +namespace details { + +template +void divide_grad(const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis, + Tensor* dx, + Tensor* dy) { + if (dy) { + // dy = -(x/y^2) * dout + auto denominator = + elementwise_pow(y, full(y.shape(), 2.0, y.dtype(), y.place())); + auto dy_res = scale( + multiply(divide(x, denominator), out_grad), -1.0, 0.0, true); + if (x.dims() != y.dims()) { + // Maybe need reduce here + phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + if (!reduce_dim.size()) { + set_output(dy_res, dy); + } else { + auto dy_reduce_res = + sum(dy_res, phi::vectorize(reduce_dim), y.dtype(), false); + auto reshape_res = reshape(dy_reduce_res, phi::vectorize(y.dims())); + auto dy_tmp = std::get<0>(reshape_res); + set_output(dy_tmp, dy); + } + } else { + set_output(dy_res, dy); + } + } // indicate we will compute dy + if (dx) { + // dx = (1/y) * dout + auto one_tensor = full(phi::vectorize(y.dims()), 1.0, y.dtype()); + auto dx_res = multiply(divide(one_tensor, y), out_grad); + if (y.dims() != x.dims()) { + // Maybe need reduce here + auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + if (!reduce_dim.size()) { + set_output(dx_res, dx); + } else { + auto dx_reduce_res = + sum(dx_res, phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_reduce_reshape_res = + reshape(dx_reduce_res, phi::vectorize(x.dims())); + auto dx_tmp = std::get<0>(dx_reduce_reshape_res); + set_output(dx_tmp, dx); + } + + } else { + set_output(dx_res, dx); + } + } // indicate we will compute dx +} + +template +void sum_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all, + Tensor* x_grad) { + if (!x_grad) { + return; + } + std::vector x_dim = phi::vectorize(x.dims()); + int64_t axis_size = axis.size(); + int64_t x_dim_size = x_dim.size(); + reduce_all = false; + if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { + reduce_all = true; + } else { + reduce_all = false; + } + auto x_grad_tmp = Tensor(); + if (x_dim_size == 1) { + x_grad_tmp = expand(out_grad, IntArray(x_dim)); + } else { + if (!keepdim) { + auto axis_ = std::vector(); + if (reduce_all) { + for (int64_t i = 0; i < x_dim_size; i++) { + axis_.push_back(i); + } + } else { + axis_ = axis.GetData(); + for (int64_t i = 0; i < axis_size; i++) { + if (axis[i] < 0) { + axis_[i] = axis[i] + x_dim_size; + } + } + } + auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); + auto out_grad_reshape_res = reshape(out_grad, out_grad_shape); + auto out_grad_ = std::get<0>(out_grad_reshape_res); + x_grad_tmp = expand(out_grad, IntArray(x_dim)); + } else { + x_grad_tmp = expand(out_grad, IntArray(x_dim)); + } + } + + set_output(x_grad_tmp, x_grad); +} + +} // namespace details +} // namespace primitive +} // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc index c8287a6b3a5..59fabfc87cf 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/vjp.cc @@ -14,15 +14,17 @@ #include "paddle/fluid/primitive/rule/vjp/vjp.h" #include "paddle/fluid/ir/dialect/pd_api.h" +#include "paddle/fluid/prim/utils/static/static_global_utils.h" #include 
"paddle/fluid/primitive/backend/static_backend.h" -#include "paddle/fluid/primitive/type/static_tensor.h" +#include "paddle/fluid/primitive/rule/vjp/details.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" #include "paddle/ir/core/operation.h" // TODO(wanghao107): // op's vjp will be auto generated. namespace paddle { namespace primitive { -namespace experimental { std::vector> tanh_vjp( const Tensor& out, @@ -31,16 +33,13 @@ std::vector> tanh_vjp( std::vector> vjp_res( 1, std::vector(1)); // get tanh_grad res. - Tensor op_res = - backend::experimental::tanh_grad( - out, grad_out); + Tensor op_res = backend::tanh_grad(out, grad_out); // set op stop_gradient info // TODO(wanghao107): Replace with more generic code. // Support set stop_gradients for all ops. ir::Operation* grad_op = - std::static_pointer_cast( - op_res.impl()) + std::static_pointer_cast(op_res.impl()) ->getValue() .dyn_cast() .owner(); @@ -76,16 +75,14 @@ std::vector> mean_vjp( std::vector> vjp_res( 1, std::vector(1)); // get mean_grad res. - Tensor op_res = - backend::experimental::mean_grad( - x, out_grad, axis, keepdim, reduce_all); + Tensor op_res = backend::mean_grad( + x, out_grad, axis, keepdim, reduce_all); // set op stop_gradient info // TODO(wanghao107): Replace with more generic code. // Support set stop_gradients for all ops. ir::Operation* grad_op = - std::static_pointer_cast( - op_res.impl()) + std::static_pointer_cast(op_res.impl()) ->getValue() .dyn_cast() .owner(); @@ -119,20 +116,18 @@ std::vector> add_vjp( const std::vector>& stop_gradients) { std::vector> vjp_res( 2, std::vector(1)); - // get mean_grad res. + // get add_grad res. std::tuple op_res = - backend::experimental::add_grad( - x, y, out_grad, axis); + backend::add_grad(x, y, out_grad, axis); // set op stop_gradient info // TODO(wanghao107): Replace with more generic code. // Support set stop_gradients for all ops. - ir::Operation* grad_op = - std::static_pointer_cast( - std::get<0>(op_res).impl()) - ->getValue() - .dyn_cast() - .owner(); + ir::Operation* grad_op = std::static_pointer_cast( + std::get<0>(op_res).impl()) + ->getValue() + .dyn_cast() + .owner(); std::vector ir_stop_gradients(2); for (size_t i = 0; i < 2; i++) { if (stop_gradients[i][0]) { @@ -152,6 +147,57 @@ std::vector> add_vjp( vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0]; return vjp_res; } -} // namespace experimental + +std::vector> divide_vjp( + const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis, + const std::vector>& stop_gradients) { + std::vector> vjp_res( + 2, std::vector(1)); + if (!paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) { + // get divide_grad res. + std::tuple op_res = + backend::divide_grad(x, y, out, out_grad, axis); + // construct vjp result by op result and stop_gradients info + vjp_res[0][0] = !stop_gradients[0][0] ? std::get<0>(op_res) : vjp_res[0][0]; + vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0]; + } else { + // get divide_grad prim mode res. + Tensor* dx = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; + Tensor* dy = !stop_gradients[1][0] ? 
&vjp_res[1][0] : nullptr; + details::divide_grad(x, y, out, out_grad, axis, dx, dy); + } + return vjp_res; +} + +std::vector> sum_vjp( + const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all, + const std::vector>& stop_gradients) { + std::vector> vjp_res( + 1, std::vector(1)); + if (!paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) { + // get sum_grad res. + Tensor op_res = backend::sum_grad( + x, out_grad, axis, keepdim, reduce_all); + // construct vjp result by op result and stop_gradients info + if (!stop_gradients[0][0]) { + vjp_res[0][0] = op_res; + } + } else { + // get divide_grad prim mode res. + Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; + details::sum_grad( + x, out_grad, axis, keepdim, reduce_all, x_grad); + } + return vjp_res; +} + } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/vjp.h b/paddle/fluid/primitive/rule/vjp/vjp.h index 8ef03c39c6e..94b1f9d67cc 100644 --- a/paddle/fluid/primitive/rule/vjp/vjp.h +++ b/paddle/fluid/primitive/rule/vjp/vjp.h @@ -14,13 +14,6 @@ #pragma once -#ifndef _USE_MATH_DEFINES -#define _USE_MATH_DEFINES -#endif - -#include -#include - #include "paddle/fluid/primitive/primitive/primitive.h" #include "paddle/ir/core/value.h" #include "paddle/phi/api/include/tensor.h" @@ -28,7 +21,6 @@ namespace paddle { namespace primitive { -namespace experimental { using IntArray = paddle::experimental::IntArray; // TODO(wanghao107): @@ -53,11 +45,21 @@ std::vector> add_vjp( int axis, const std::vector>& stop_gradients); -namespace details { -// NOTE: this namespace will store -// primitive ops grad composite rules. +std::vector> divide_vjp( + const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis, + const std::vector>& stop_gradients); + +std::vector> sum_vjp( + const Tensor& x, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all, + const std::vector>& stop_gradients); -} // namespace details -} // namespace experimental } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/type/static_tensor.h b/paddle/fluid/primitive/type/lazy_tensor.h similarity index 77% rename from paddle/fluid/primitive/type/static_tensor.h rename to paddle/fluid/primitive/type/lazy_tensor.h index b13ee48e79c..b4387d11f5c 100644 --- a/paddle/fluid/primitive/type/static_tensor.h +++ b/paddle/fluid/primitive/type/lazy_tensor.h @@ -22,34 +22,36 @@ namespace paddle { namespace primitive { -namespace experimental { -class StaticTensor : public phi::ExtendedTensor, - public phi::TypeInfoTraits { +class LazyTensor : public phi::ExtendedTensor, + public phi::TypeInfoTraits { public: - explicit StaticTensor(ir::Value value) + explicit LazyTensor(ir::Value value) : value_(value), dims_(value.type().dyn_cast().dims()) {} - static const char* name() { return "StaticTensor"; } + static const char* name() { return "LazyTensor"; } const phi::DDim& dims() const override { return dims_; } int64_t numel() const override { return product(dims()); } DataType dtype() const override { - return paddle::dialect::TransToPhiDataType(value_.type()); + return paddle::dialect::TransToPhiDataType( + value_.type().dyn_cast().dtype()); } ir::Value getValue() const { return value_; } + const phi::Place& place() const override { return place_; } + bool initialized() const override { return value_.impl() != nullptr; } private: ir::Value value_; mutable phi::DDim dims_; + phi::Place place_; }; -} 
// namespace experimental } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/utils/CMakeLists.txt b/paddle/fluid/primitive/utils/CMakeLists.txt new file mode 100644 index 00000000000..044198c827f --- /dev/null +++ b/paddle/fluid/primitive/utils/CMakeLists.txt @@ -0,0 +1,10 @@ +if(WITH_PYTHON OR NOT ON_INFER) + cc_library( + primitive_eager_utils_experimental + SRCS eager_utils.cc + DEPS phi common_infer_shape_functions) +endif() +cc_library( + primitive_static_utils_experimental + SRCS static_utils.cc + DEPS phi common_infer_shape_functions) diff --git a/paddle/fluid/primitive/utils/eager_utils.cc b/paddle/fluid/primitive/utils/eager_utils.cc new file mode 100644 index 00000000000..e9ad10407e3 --- /dev/null +++ b/paddle/fluid/primitive/utils/eager_utils.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" +#include "paddle/fluid/primitive/utils/utils.h" + +namespace paddle { +namespace primitive { +template <> +void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) { + x->set_impl(x_tmp.impl()); + x->set_autograd_meta(x_tmp.mutable_autograd_meta()); +} + +} // namespace primitive +} // namespace paddle diff --git a/paddle/fluid/primitive/utils/static_utils.cc b/paddle/fluid/primitive/utils/static_utils.cc new file mode 100644 index 00000000000..40cbbc8d21e --- /dev/null +++ b/paddle/fluid/primitive/utils/static_utils.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" + +namespace paddle { +namespace primitive { +template <> +void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) { + x->set_impl(x_tmp.impl()); +} + +} // namespace primitive +} // namespace paddle diff --git a/paddle/fluid/primitive/utils/utils.h b/paddle/fluid/primitive/utils/utils.h new file mode 100644 index 00000000000..e1765357aa9 --- /dev/null +++ b/paddle/fluid/primitive/utils/utils.h @@ -0,0 +1,91 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/ddim.h" + +namespace paddle { +namespace primitive { + +template +void set_output(const Tensor& x_tmp, Tensor* x); + +// This fucction compute unsqueeze dims for reshape to replace unsqueeze. +static std::vector get_unsqueeze_dims( + const Tensor& origin, const std::vector& axis) { + auto origin_dims = origin.shape(); + auto total_shape_size = origin_dims.size() + axis.size(); + std::vector result; + size_t j = 0, k = 0; + for (size_t i = 0; i < total_shape_size; ++i) { + if (j < axis.size() && axis[j] == int64_t(i)) { + result.push_back(1); + j++; + } else { + PADDLE_ENFORCE_LT( + k, + origin_dims.size(), + platform::errors::OutOfRange("Your index [%lu] exceeds the number of " + "elements in origin_dims[%lu].", + k, + origin_dims.size())); + result.push_back(origin_dims[k]); + k++; + } + } + return result; +} + +// These method don't need to be specified +static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims, + const phi::DDim& in_dims) { + std::vector result; + int bat = dout_dims.size() - in_dims.size(); + for (int i = 0; i < bat; ++i) { + result.push_back(i); + } + for (int i = 0; i < in_dims.size(); ++i) { + if (in_dims[i] == 1) { + result.push_back(i + bat); + } else { + PADDLE_ENFORCE_EQ( + in_dims[i], + dout_dims[i + bat], + platform::errors::InvalidArgument( + "ReduceDims dimension mismatch. Operands could " + "not be broadcast together with the shape of dout = [%s] and " + "the shape of in_dims = [%s]. Received [%d] in X is not equal to " + "[%d] in Y at i:%d.", + dout_dims, + in_dims, + dout_dims[i + bat], + in_dims[i], + i)); + } + } + return phi::make_ddim(result); +} + +static phi::DDim get_reduce_dims(const phi::DDim& x_dims, + const phi::DDim& y_dims) { + auto out_dims = paddle::operators::details::BroadcastTwoDims(x_dims, y_dims); + return get_reduce_dims_from_out(out_dims, x_dims); +} + +} // namespace primitive +} // namespace paddle diff --git a/test/prim/new_ir_prim/test_vjp_prim.py b/test/prim/new_ir_prim/test_vjp_prim.py new file mode 100644 index 00000000000..8c2fd4ebd76 --- /dev/null +++ b/test/prim/new_ir_prim/test_vjp_prim.py @@ -0,0 +1,167 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import paddle
+from paddle import ir
+from paddle.fluid.core import call_vjp
+
+paddle.enable_static()
+
+
+def get_ir_program_0():
+    main_program, start_program = (
+        paddle.static.Program(),
+        paddle.static.Program(),
+    )
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.tensor.fill_constant(
+            shape=[1, 4], dtype='float32', value=2.0
+        )
+        x.stop_gradient = False
+        y = paddle.tensor.fill_constant(shape=[4], dtype='float32', value=1.0)
+        y.stop_gradient = False
+        dout = paddle.tensor.fill_constant(
+            shape=[1, 4], dtype='float32', value=1.0
+        )
+        dout.stop_gradient = False
+        out = paddle.divide(x, y)
+    newir_program = ir.translate_to_new_ir(main_program.desc)
+    return newir_program
+
+
+def get_ir_program_1():
+    main_program, start_program = (
+        paddle.static.Program(),
+        paddle.static.Program(),
+    )
+    with paddle.static.program_guard(main_program, start_program):
+        x = paddle.tensor.fill_constant(
+            shape=[4, 5], dtype='float32', value=2.0
+        )
+        x.stop_gradient = False
+        dout = paddle.tensor.fill_constant(
+            shape=[1], dtype='float32', value=1.0
+        )
+        dout.stop_gradient = False
+        out = paddle.sum(x)
+    newir_program = ir.translate_to_new_ir(main_program.desc)
+    return newir_program
+
+
+class TestVjpPrim(unittest.TestCase):
+    def test_divide_grad_prim_case1(self):
+        newir_program = get_ir_program_0()
+        paddle.fluid.core._set_prim_backward_enabled(True)
+        dout = newir_program.block().ops[-2].result(0)
+        out_grads = [[dout]]
+        stop_gradients = [[False], [False]]
+        divide_op = newir_program.block().ops[-1]
+        with paddle.ir.core.program_guard(newir_program):
+            grad_outs = call_vjp(divide_op, out_grads, stop_gradients)
+        reshape_op2 = newir_program.block().ops[-1]
+        reshape_op1 = newir_program.block().ops[-8]
+        self.assertEqual(len(grad_outs), 2)
+        self.assertEqual(len(newir_program.block().ops), 21)
+        self.assertEqual(reshape_op2.result(0), grad_outs[0][0])
+        self.assertEqual(reshape_op1.result(0), grad_outs[1][0])
+        all_op_names = [
+            "pd.full",
+            "pd.full",
+            "pd.full",
+            "pd.divide",
+            "pd.full",
+            "pd.elementwise_pow",
+            "pd.divide",
+            "pd.multiply",
+            "pd.full",
+            "pd.scale",
+            "pd.full_int_array",
+            "pd.sum",
+            "pd.full_int_array",
+            "pd.reshape",
+            "pd.full",
+            "pd.divide",
+            "pd.multiply",
+            "pd.full_int_array",
+            "pd.sum",
+            "pd.full_int_array",
+            "pd.reshape",
+        ]
+        for idx, op in enumerate(newir_program.block().ops):
+            self.assertEqual(op.name(), all_op_names[idx])
+
+    def test_divide_grad_no_prim(self):
+        newir_program = get_ir_program_0()
+        paddle.fluid.core._set_prim_backward_enabled(False)
+        dout = newir_program.block().ops[-2].result(0)
+        out_grads = [[dout]]
+        stop_gradients = [[False], [False]]
+        divide_op = newir_program.block().ops[-1]
+        with paddle.ir.core.program_guard(newir_program):
+            grad_outs = call_vjp(divide_op, out_grads, stop_gradients)
+        self.assertEqual(len(grad_outs), 2)
+        self.assertEqual(
+            grad_outs[0][0].get_defining_op().name(), "pd.divide_grad"
+        )
+        self.assertEqual(
+            grad_outs[1][0].get_defining_op().name(), "pd.divide_grad"
+        )
+        self.assertEqual(len(newir_program.block().ops), 5)
+
+    def test_sum_grad_prim(self):
+        newir_program = get_ir_program_1()
+        paddle.fluid.core._set_prim_backward_enabled(True)
+        dout = newir_program.block().ops[-2].result(0)
+        out_grads = [[dout]]
+        stop_gradients = [[False]]
+        sum_op = newir_program.block().ops[-1]
+        with paddle.ir.core.program_guard(newir_program):
+            grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
+        expand_op = newir_program.block().ops[-1]
+        self.assertEqual(len(grad_outs), 1)
+        self.assertEqual(len(newir_program.block().ops), 8)
+        self.assertEqual(expand_op.result(0), grad_outs[0][0])
+        all_op_names = [
+            "pd.full",
+            "pd.full",
+            "pd.full_int_array",
+            "pd.sum",
+            "pd.full_int_array",
+            "pd.reshape",
+            "pd.full_int_array",
+            "pd.expand",
+        ]
+        for idx, op in enumerate(newir_program.block().ops):
+            self.assertEqual(op.name(), all_op_names[idx])
+
+    def test_sum_grad_no_prim(self):
+        newir_program = get_ir_program_1()
+        paddle.fluid.core._set_prim_backward_enabled(False)
+        dout = newir_program.block().ops[-2].result(0)
+        out_grads = [[dout]]
+        stop_gradients = [[False]]
+        sum_op = newir_program.block().ops[-1]
+        with paddle.ir.core.program_guard(newir_program):
+            grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
+        self.assertEqual(len(grad_outs), 1)
+        self.assertEqual(
+            grad_outs[0][0].get_defining_op().name(), "pd.sum_grad"
+        )
+        self.assertEqual(len(newir_program.block().ops), 6)
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
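For reference (a reading aid, not part of the patch): the composite rule added in paddle/fluid/primitive/rule/vjp/details.h implements the analytic gradients of elementwise division, with the sum/reshape steps only handling broadcast shapes:

\[
z = \frac{x}{y}, \qquad
\frac{\partial L}{\partial x} = \frac{1}{y}\,\frac{\partial L}{\partial z}, \qquad
\frac{\partial L}{\partial y} = -\,\frac{x}{y^{2}}\,\frac{\partial L}{\partial z}.
\]

Likewise, sum_grad reshapes the output gradient to a keepdim-compatible shape and expands it back to the input shape, which is why the prim-mode test expects the op sequence full_int_array/reshape followed by full_int_array/expand.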