From dd5f33e1e679499c7ac3cfb1764956e37906e6f8 Mon Sep 17 00:00:00 2001
From: wangruting
Date: Wed, 1 Feb 2023 07:24:35 +0000
Subject: [PATCH] original code

---
 .../prim/api/auto_code_generated/prim_base.py |   4 +
 .../api/generated/prim_api/static_prim_api.cc |  76 +++++
 .../manual/backward/composite_backward_api.h  | 310 +++++++++++++++++-
 paddle/fluid/prim/api/manual/utils/utils.h    | 109 +++++-
 4 files changed, 487 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/prim/api/auto_code_generated/prim_base.py b/paddle/fluid/prim/api/auto_code_generated/prim_base.py
index d1ad94a7c3..fee1318777 100644
--- a/paddle/fluid/prim/api/auto_code_generated/prim_base.py
+++ b/paddle/fluid/prim/api/auto_code_generated/prim_base.py
@@ -25,6 +25,10 @@ white_ops_list = [
     "divide",
     "sum",
     "exp",
+    "matmul",
+    "dot",
+    "transpose",
+    "add",
 ]
 
 inplace_out_type_map = {
diff --git a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc b/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc
index 30a82b4989..262dca6350 100644
--- a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc
+++ b/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc
@@ -38,6 +38,24 @@
 namespace paddle {
 namespace prim {
 
+template <>
+Tensor add<DescTensor>(const Tensor& x, const Tensor& y) {
+  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
+  framework::OpDesc* op = block->AppendOp();
+  op->SetType("elementwise_add");
+  op->SetInput("X",
+               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  op->SetInput("Y",
+               {std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
+  op->SetOutput(
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
+  op->CheckAttrs();
+  op->InferVarType(block);
+  op->InferShape(*block);
+  return out;
+}
+
 template <>
 Tensor pow<DescTensor>(const Tensor& x, const Scalar& y) {
   Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
@@ -77,6 +95,29 @@ Tensor scale<DescTensor>(const Tensor& x,
   return out;
 }
 
+template <>
+Tensor matmul<DescTensor>(const Tensor& x,
+                          const Tensor& y,
+                          bool transpose_x,
+                          bool transpose_y) {
+  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
+  framework::OpDesc* op = block->AppendOp();
+  op->SetType("matmul");
+  op->SetInput("X",
+               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  op->SetInput("Y",
+               {std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
+  op->SetOutput(
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
+  op->SetAttr("transpose_X", transpose_x);
+  op->SetAttr("transpose_Y", transpose_y);
+  op->CheckAttrs();
+  op->InferVarType(block);
+  op->InferShape(*block);
+  return out;
+}
+
 template <>
 Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
   // Grad infershape
@@ -236,6 +277,41 @@ Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
   return out;
 }
 
+template <>
+Tensor transpose<DescTensor>(const Tensor& x, const std::vector<int>& perm) {
+  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
+  framework::OpDesc* op = block->AppendOp();
+  op->SetType("transpose");
+  op->SetInput("X",
+               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  op->SetOutput(
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
+  op->SetAttr("axis", perm);
+  op->CheckAttrs();
+  op->InferVarType(block);
+  op->InferShape(*block);
+  return out;
+}
+
+template <>
+Tensor dot<DescTensor>(const Tensor& x, const Tensor& y) {
+  Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
+  framework::OpDesc* op = block->AppendOp();
+  op->SetType("dot");
+  op->SetInput("X",
+               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  op->SetInput("Y",
+               {std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
+  op->SetOutput(
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
+  op->CheckAttrs();
+  op->InferVarType(block);
+  op->InferShape(*block);
+  return out;
+}
+
 template <>
 Tensor exp<DescTensor>(const Tensor& x) {
   Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
index 99ef82d088..770af8a21f 100644
--- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+
 #include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
 #include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
 #include "paddle/fluid/prim/api/manual/utils/utils.h"
@@ -170,7 +171,7 @@ void divide_grad(const Tensor& x,
                  Tensor* dx,
                  Tensor* dy) {
   if (dy) {
-    // dy = -(x/y^2) * dout
+    // dy = -(x/y^2) * grad_out
     auto tmp0 = pow<T>(y, 2.0);
     auto tmp1 = divide<T>(x, tmp0);
     auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
@@ -191,7 +192,7 @@ void divide_grad(const Tensor& x,
     }
   }  // indicate we will compute dy
   if (dx) {
-    // dx = (1/y) * dout
+    // dx = (1/y) * grad_out
     auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
     auto tmp0 = divide<T>(one_tensor, y);
     auto dx_res = multiply<T>(tmp0, out_grad);
@@ -303,5 +304,310 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   }
 }
 
+template <typename T>
+void matmul_double_grad(const Tensor& x,
+                        const Tensor& y,
+                        const Tensor& grad_out,
+                        const paddle::optional<Tensor>& grad_x_grad,
+                        const paddle::optional<Tensor>& grad_y_grad,
+                        bool transpose_x,
+                        bool transpose_y,
+                        Tensor* x_grad,
+                        Tensor* y_grad,
+                        Tensor* grad_out_grad) {
+  // Get dims from the input x, y, output_grad
+  std::vector<int64_t> x_dims = vectorize(x.dims());
+  std::vector<int64_t> y_dims = vectorize(y.dims());
+  std::vector<int64_t> grad_out_dims = vectorize(grad_out.dims());
+
+  int x_ndim = x_dims.size();
+  int y_ndim = y_dims.size();
+  int ndim = grad_out_dims.size();
+
+  // Case1 : x's or y's dim = 1
+
+  bool is_broadcast = true;
+  if (x_ndim <= 2 || y_ndim <= 2) {
+    is_broadcast = false;
+  } else if (x_ndim != y_ndim) {
+    is_broadcast = true;
+  } else {
+    is_broadcast = !std::equal(
+        x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
+  }
+
+  if (!is_broadcast) {
+    // Case2: no broadcast or no batch size
+    Tensor x_help = x;
+    Tensor y_help = y;
+    Tensor grad_out_help = grad_out;
+
+    reshape_xyout_to_matrixsequence<T>(
+        &x_help, &y_help, &grad_out_help, transpose_x, transpose_y);
+
+    phi::DDim x_grad_dims;
+    if (x_grad) {
+      x_grad_dims = x_grad->dims();
+      if (x_grad_dims != x_help.dims()) {
+        *x_grad = reshape<T>(*x_grad, IntArray(phi::vectorize(x_help.dims())));
+      }
+    }
+
+    phi::DDim y_grad_dims;
+    if (y_grad) {
+      y_grad_dims = y_grad->dims();
+      if (y_grad_dims != y_help.dims()) {
+        *y_grad = reshape<T>(*y_grad, IntArray(phi::vectorize(y_help.dims())));
+      }
+    }
+
+    phi::DDim dgrad_out_dims;
+    if (grad_out_grad) {
+      dgrad_out_dims = grad_out_grad->dims();
+      if (dgrad_out_dims != grad_out_help.dims()) {
+        *grad_out_grad = reshape<T>(
+            *grad_out_grad, IntArray(phi::vectorize(grad_out_help.dims())));
+      }
+    }
+
+    bool dgrad_out_flag = false;
+    if (grad_x_grad) {
+      auto grad_x_grad_mat = grad_x_grad.get();
+      if (grad_x_grad_mat.dims() != x_help.dims()) {
+        grad_x_grad_mat = reshape<T>(grad_x_grad_mat,
+                                     IntArray(phi::vectorize(x_help.dims())));
+      }
+      if (y_grad) {
+        Tensor y_grad_tmp;
+        if (transpose_x && transpose_y) {
+          // y_grad = grad_out' * grad_x_grad'
+          // pass the tensor being produced so modify_dim_for_matmul can
+          // decide whether the batch dims need folding
+          auto tmp = modify_dim_for_matmul<T>(
+              grad_out, true, grad_x_grad_mat, y_grad, false);
+          y_grad_tmp =
+              matmul<T>(std::get<0>(tmp), std::get<1>(tmp), true, true);
+        } else if (transpose_x) {
+          // y_grad = grad_x_grad * grad_out
+          auto tmp = modify_dim_for_matmul<T>(
+              grad_x_grad_mat, false, grad_out, y_grad, true);
+          y_grad_tmp =
+              matmul<T>(std::get<0>(tmp), std::get<1>(tmp), false, false);
+        } else if (transpose_y) {
+          // y_grad = grad_out' * grad_x_grad
+          auto tmp = modify_dim_for_matmul<T>(
+              grad_out, true, grad_x_grad_mat, y_grad, true);
+          y_grad_tmp =
+              matmul<T>(std::get<0>(tmp), std::get<1>(tmp), true, false);
+        } else {
+          // y_grad = grad_x_grad' * grad_out
+          auto tmp = modify_dim_for_matmul<T>(
+              grad_x_grad_mat, true, grad_out, y_grad, true);
+          y_grad_tmp =
+              matmul<T>(std::get<0>(tmp), std::get<1>(tmp), true, false);
+        }
+        set_output<T>(y_grad_tmp, y_grad);
+      }
+
+      if (grad_out_grad) {
+        auto tmp = modify_dim_for_matmul<T>(
+            grad_x_grad_mat, true, y, grad_out_grad, false);
+        auto grad_out_grad_tmp = matmul<T>(
+            std::get<0>(tmp), std::get<1>(tmp), transpose_x, transpose_y);
+        set_output<T>(grad_out_grad_tmp, grad_out_grad);
+      }
+    } else if (!grad_x_grad && y_grad) {
+      auto y_grad_tmp = full<T>(phi::vectorize(y.dims()), Scalar(0.0));
+      set_output<T>(y_grad_tmp, y_grad);
+    }
+    if (grad_y_grad) {
+      auto grad_y_grad_mat = grad_y_grad.get();
+      if (grad_y_grad_mat.dims() != y_help.dims()) {
+        grad_y_grad_mat = reshape<T>(grad_y_grad_mat,
+                                     IntArray(phi::vectorize(y_help.dims())));
+      }
+      if (x_grad) {
+        Tensor x_grad_tmp;
+        if (transpose_x && transpose_y) {
+          // x_grad = grad_y_grad' * grad_out'
+          auto tmp = modify_dim_for_matmul<T>(
+              grad_y_grad_mat, true, grad_out, x_grad, false);
+          x_grad_tmp =
+              matmul<T>(std::get<0>(tmp), std::get<1>(tmp), true, true);
+        } else if (transpose_x) {
+          // x_grad = grad_y_grad * grad_out'
+          auto tmp = modify_dim_for_matmul<T>(
+              grad_y_grad_mat, false, grad_out, x_grad, false);
+          x_grad_tmp =
+              matmul<T>(std::get<0>(tmp), std::get<1>(tmp), false, true);
+        } else if (transpose_y) {
+          // x_grad = grad_out * grad_y_grad
+          auto tmp = modify_dim_for_matmul<T>(
+              grad_out, false, grad_y_grad_mat, x_grad, true);
+          x_grad_tmp =
+              matmul<T>(std::get<0>(tmp), std::get<1>(tmp), false, false);
+        } else {
+          // x_grad = grad_out * grad_y_grad'
+          auto tmp = modify_dim_for_matmul<T>(
+              grad_out, false, grad_y_grad_mat, x_grad, false);
+          x_grad_tmp =
+              matmul<T>(std::get<0>(tmp), std::get<1>(tmp), false, true);
+        }
+        set_output<T>(x_grad_tmp, x_grad);
+      }
+
+      if (grad_out_grad) {
+        auto tmp = modify_dim_for_matmul<T>(
+            x, true, grad_y_grad_mat, grad_out_grad, false);
+        auto grad_out_grad_tmp = matmul<T>(
+            std::get<0>(tmp), std::get<1>(tmp), transpose_x, transpose_y);
+        auto output_tmp = add<T>(grad_out_grad_tmp, *grad_out_grad);
+        set_output<T>(output_tmp, grad_out_grad);
+      }
+    } else if (!grad_y_grad && x_grad) {
+      auto x_grad_tmp = full<T>(phi::vectorize(x.dims()), Scalar(0.0));
+      set_output<T>(x_grad_tmp, x_grad);
+    }
+    if (grad_out_grad && !grad_x_grad && !grad_y_grad) {
+      auto grad_out_grad_tmp =
+          full<T>(phi::vectorize(grad_out.dims()), Scalar(0.0));
+      set_output<T>(grad_out_grad_tmp, grad_out_grad);
+    }
+
+    if (x_grad) {
+      if (x_grad_dims != x_help.dims()) {
+        *x_grad = reshape<T>(*x_grad, IntArray(phi::vectorize(x_grad_dims)));
+      }
+    }
+
+    if (y_grad) {
+      if (y_grad_dims != y_help.dims()) {
+        *y_grad = reshape<T>(*y_grad, IntArray(phi::vectorize(y_grad_dims)));
+      }
+    }
+
+    if (grad_out_grad) {
+      if (dgrad_out_dims != grad_out_help.dims()) {
+        *grad_out_grad = reshape<T>(*grad_out_grad,
+                                    IntArray(phi::vectorize(dgrad_out_dims)));
+      }
+    }
+
+  } else {
+    // Case3: broadcast. Reducing over the broadcast dimensions needs an
+    // extra reduce_sum and more memory, so this case should be avoided in
+    // practice where possible.
+    VLOG(3) << "Broadcast in matmul_double_grad needs an extra reduce_sum "
+               "and more memory, so this case should be avoided in practice.";
+
+    Tensor x_grad_help;
+    Tensor y_grad_help;
+    Tensor grad_out_grad_help;
+
+    if (transpose_x) {
+      if (transpose_y) {
+        if (x_grad && grad_y_grad) {
+          x_grad_help = matmul<T>(grad_y_grad.get(), grad_out, true, true);
+        }
+        if (y_grad && grad_x_grad) {
+          y_grad_help = matmul<T>(grad_out, grad_x_grad.get(), true, true);
+        }
+      } else {
+        if (x_grad && grad_y_grad) {
+          x_grad_help = matmul<T>(grad_y_grad.get(), grad_out, false, true);
+        }
+        if (y_grad && grad_x_grad) {
+          y_grad_help = matmul<T>(grad_x_grad.get(), grad_out, false, false);
+        }
+      }
+    } else {
+      if (transpose_y) {
+        if (x_grad && grad_y_grad) {
+          x_grad_help = matmul<T>(grad_out, grad_y_grad.get(), false, false);
+        }
+        if (y_grad && grad_x_grad) {
+          y_grad_help = matmul<T>(grad_out, grad_x_grad.get(), true, false);
+        }
+      } else {
+        if (x_grad && grad_y_grad) {
+          x_grad_help = matmul<T>(grad_out, grad_y_grad.get(), false, true);
+        }
+        if (y_grad && grad_x_grad) {
+          y_grad_help = matmul<T>(grad_x_grad.get(), grad_out, true, false);
+        }
+      }
+    }
+
+    // get help dims
+    const std::vector<int64_t> x_grad_help_dims =
+        vectorize(x_grad_help.dims());
+    const std::vector<int64_t> y_grad_help_dims =
+        vectorize(y_grad_help.dims());
+
+    std::vector<int64_t> x_grad_broadcast_dims(ndim);
+    std::vector<int64_t> y_grad_broadcast_dims(ndim);
+
+    std::fill(x_grad_broadcast_dims.data(),
+              x_grad_broadcast_dims.data() + ndim - x_ndim,
+              1);
+    std::fill(y_grad_broadcast_dims.data(),
+              y_grad_broadcast_dims.data() + ndim - y_ndim,
+              1);
+    std::copy(x_dims.data(),
+              x_dims.data() + x_ndim,
+              x_grad_broadcast_dims.data() + ndim - x_ndim);
+    std::copy(y_dims.data(),
+              y_dims.data() + y_ndim,
+              y_grad_broadcast_dims.data() + ndim - y_ndim);
+
+    std::vector<int64_t> x_grad_reduce_dims;
+    std::vector<int64_t> y_grad_reduce_dims;
+    for (int ix_grad = 0; ix_grad <= ndim - 3; ix_grad++) {
+      if (x_grad_help_dims[ix_grad] != 1 &&
+          x_grad_broadcast_dims[ix_grad] == 1) {
+        x_grad_reduce_dims.push_back(ix_grad);
+      }
+      if (y_grad_help_dims[ix_grad] != 1 &&
+          y_grad_broadcast_dims[ix_grad] == 1) {
+        y_grad_reduce_dims.push_back(ix_grad);
+      }
+    }
+    // Reduce sum to get grad by ReduceSum
+    if (x_grad && x_grad_help.initialized()) {
+      if (!x_grad_reduce_dims.empty()) {
+        x_grad_help = sum<T>(x_grad_help, IntArray(x_grad_reduce_dims));
+      }
+      x_grad_help =
+          reshape<T>(x_grad_help, IntArray(phi::vectorize(x.dims())));
+    } else if (x_grad && !x_grad_help.initialized()) {
+      x_grad_help = full<T>(phi::vectorize(x.dims()), Scalar(0.0));
+    }
+    if (x_grad) {
+      set_output<T>(x_grad_help, x_grad);
+    }
+
+    if (y_grad && y_grad_help.initialized()) {
+      if (!y_grad_reduce_dims.empty()) {
+        y_grad_help = sum<T>(y_grad_help, IntArray(y_grad_reduce_dims));
+      }
+      y_grad_help =
+          reshape<T>(y_grad_help, IntArray(phi::vectorize(y.dims())));
+    } else if (y_grad && !y_grad_help.initialized()) {
+      y_grad_help = full<T>(phi::vectorize(y.dims()), Scalar(0.0));
+    }
+    if (y_grad) {
+      set_output<T>(y_grad_help, y_grad);
+    }
+
+    if (grad_out_grad) {
+      // Calculate the gradient of OutputGrad(Out)
+      if (grad_x_grad) {
+        grad_out_grad_help =
+            matmul<T>(grad_x_grad.get(), y, transpose_x, transpose_y);
+      }
+      if (grad_y_grad) {
+        auto grad_out_grad_help_2 =
+            matmul<T>(x, grad_y_grad.get(), transpose_x, transpose_y);
+        grad_out_grad_help =
+            grad_out_grad_help.initialized()
+                ? add<T>(grad_out_grad_help, grad_out_grad_help_2)
+                : grad_out_grad_help_2;
+      }
+      set_output<T>(grad_out_grad_help, grad_out_grad);
+    }
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/fluid/prim/api/manual/utils/utils.h b/paddle/fluid/prim/api/manual/utils/utils.h
index 20b02f2df9..6e7b82ac34 100644
--- a/paddle/fluid/prim/api/manual/utils/utils.h
+++ b/paddle/fluid/prim/api/manual/utils/utils.h
@@ -17,30 +17,34 @@
 #include <vector>
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
+#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
 
 namespace paddle {
 namespace prim {
 // We put some api like utils here
+using Tensor = paddle::experimental::Tensor;
 template <typename T>
-paddle::experimental::Tensor empty(const paddle::experimental::IntArray& shape,
-                                   paddle::experimental::DataType dype,
-                                   const paddle::Place& place);
+Tensor empty(const paddle::experimental::IntArray& shape,
+             paddle::experimental::DataType dype,
+             const paddle::Place& place);
 
 template <typename T>
-paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x,
-                                        paddle::experimental::DataType dtype,
-                                        const paddle::Place& place);
+Tensor empty_like(const Tensor& x,
+                  paddle::experimental::DataType dtype,
+                  const paddle::Place& place);
+
+// copy tensor for output ptr; in static mode this needs an assign op
 template <typename T>
-void by_pass(const paddle::experimental::Tensor& x,
-             paddle::experimental::Tensor* out);
+void by_pass(const Tensor& x, Tensor* out);
 
+// set output ptr impl with tmp ptr impl; in dygraph OutGradMeta should be set
 template <typename T>
-void set_output(const paddle::experimental::Tensor& x_tmp,
-                paddle::experimental::Tensor* x);
+void set_output(const Tensor& x_tmp, Tensor* x);
 
 // These method don't need to be specified
 static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
@@ -78,5 +82,90 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
   return get_reduce_dims_from_out(out_dims, x_dims);
 }
 
+template <typename T>
+std::tuple<Tensor, Tensor> modify_dim_for_matmul(const Tensor& a,
+                                                 bool is_fold_init_dims_a,
+                                                 const Tensor& b,
+                                                 const Tensor* out,
+                                                 bool is_fold_init_dims_b) {
+  Tensor a_out = a;
+  Tensor b_out = b;
+  bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
+                      out->dims().size() == 2;
+  if (need_combine) {
+    auto a_dims = a.dims();
+    auto b_dims = b.dims();
+    if (is_fold_init_dims_a) {
+      if (a_dims.size() == 3) {
+        std::vector<int64_t> a_shape = {a_dims[0] * a_dims[1], a_dims[2]};
+        a_out = reshape<T>(a_out, IntArray(a_shape));
+      }
+    } else {
+      if (a_dims.size() == 3) {
+        // fold head and last dims: [d0, d1, d2] -> transpose to [d1, d0, d2]
+        // -> reshape to [d1, d0 * d2]
+        a_out = transpose<T>(a, std::vector<int>({1, 0, 2}));
+        std::vector<int64_t> a_shape = {a_dims[1], a_dims[0] * a_dims[2]};
+        a_out = reshape<T>(a_out, IntArray(a_shape));
+      }
+    }
+
+    if (is_fold_init_dims_b) {
+      if (b_dims.size() == 3) {
+        std::vector<int64_t> b_shape = {b_dims[0] * b_dims[1], b_dims[2]};
+        b_out = reshape<T>(b_out, IntArray(b_shape));
+      }
+    } else {
+      if (b_dims.size() == 3) {
+        b_out = transpose<T>(b, std::vector<int>({1, 0, 2}));
+        std::vector<int64_t> b_shape = {b_dims[1], b_dims[0] * b_dims[2]};
+        b_out = reshape<T>(b_out, IntArray(b_shape));
+      }
+    }
+  }
+  std::tuple<Tensor, Tensor> output(a_out, b_out);
+  return output;
+}
+
+template <typename T>
+void reshape_tensor_to_matrixsequence(
+    Tensor* x, const phi::funcs::MatDescriptor& descriptor) {
+  int64_t h, w;
+  h = descriptor.height_;
+  w = descriptor.width_;
+  if (descriptor.trans_) {
+    std::swap(w, h);
+  }
+  if (descriptor.batch_size_) {
+    *x = reshape<T>(*x,
+                    std::vector<int64_t>({descriptor.batch_size_, h, w}));
+  } else {
+    *x = reshape<T>(*x, std::vector<int64_t>({h, w}));
+  }
+}
+
+template <typename T>
+void reshape_xyout_to_matrixsequence(
+    Tensor* x, Tensor* y, Tensor* out, bool trans_x, bool trans_y) {
+  if (x->dims().size() == 1) {
+    *x = reshape<T>(*x, std::vector<int64_t>({1, x->dims()[0]}));
+  }
+  if (y->dims().size() == 1) {
+    *y = reshape<T>(*y, std::vector<int64_t>({y->dims()[0], 1}));
+  }
+  auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x->dims(), 0, trans_x);
+  auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y->dims(), 0, trans_y);
+  if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
+    *out = reshape<T>(
+        *out, std::vector<int64_t>({mat_dim_x.height_, mat_dim_y.width_}));
+  } else {
+    *out = reshape<T>(*out,
+                      std::vector<int64_t>({(std::max)(mat_dim_x.batch_size_,
+                                                       mat_dim_y.batch_size_),
+                                            mat_dim_x.height_,
+                                            mat_dim_y.width_}));
+  }
+
+  reshape_tensor_to_matrixsequence<T>(x, mat_dim_x);
+  reshape_tensor_to_matrixsequence<T>(y, mat_dim_y);
+}
+
 }  // namespace prim
 }  // namespace paddle
-- 
GitLab
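
Note on the composite rule (not part of the patch): for the plain 2-D, non-transposed, non-broadcast case, matmul_double_grad above composes out of primitive matmul/add calls the identities ddOut = ddX*Y + X*ddY, dX = dOut*ddY^T and dY = ddX^T*dOut. The following minimal, self-contained C++ sketch illustrates those identities with plain nested-vector matrices; the helper names and the concrete values are illustrative assumptions only and do not use any Paddle API.

// Standalone sketch: checks the double-grad formulas that the composite
// matmul_double_grad implements for the simple 2-D, non-transposed case.
//   ddOut = ddX * Y + X * ddY
//   dX    = dOut * ddY^T
//   dY    = ddX^T * dOut
#include <cstdio>
#include <vector>

using Mat = std::vector<std::vector<double>>;

// naive row-major matrix product: (m x k) * (k x n) -> (m x n)
Mat matmul(const Mat& a, const Mat& b) {
  size_t m = a.size(), k = a[0].size(), n = b[0].size();
  Mat c(m, std::vector<double>(n, 0.0));
  for (size_t i = 0; i < m; ++i)
    for (size_t p = 0; p < k; ++p)
      for (size_t j = 0; j < n; ++j) c[i][j] += a[i][p] * b[p][j];
  return c;
}

Mat transpose(const Mat& a) {
  Mat t(a[0].size(), std::vector<double>(a.size()));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t j = 0; j < a[0].size(); ++j) t[j][i] = a[i][j];
  return t;
}

Mat add(const Mat& a, const Mat& b) {
  Mat c = a;
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t j = 0; j < a[0].size(); ++j) c[i][j] += b[i][j];
  return c;
}

void print(const char* name, const Mat& a) {
  std::printf("%s:\n", name);
  for (const auto& row : a) {
    for (double v : row) std::printf(" %7.2f", v);
    std::printf("\n");
  }
}

int main() {
  // X: 2x3, Y: 3x2, so Out/dOut are 2x2; ddX and ddY mirror X and Y.
  Mat x = {{1, 2, 3}, {4, 5, 6}};
  Mat y = {{1, 0}, {0, 1}, {1, 1}};
  Mat dout = {{1, 1}, {1, 1}};
  Mat ddx = {{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}};
  Mat ddy = {{0.5, 0}, {0, 0.5}, {0.5, 0.5}};

  Mat ddout = add(matmul(ddx, y), matmul(x, ddy));  // grad_out_grad
  Mat dx = matmul(dout, transpose(ddy));            // x_grad
  Mat dy = matmul(transpose(ddx), dout);            // y_grad

  print("ddOut = ddX*Y + X*ddY", ddout);
  print("dX = dOut*ddY^T", dx);
  print("dY = ddX^T*dOut", dy);
  return 0;
}

Compiling this with any C++11 compiler and comparing the printed dX/dY/ddOut against finite differences (or a framework's matmul double grad) is a quick way to confirm the transpose placement used in the composite rule above.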