diff --git a/paddle/phi/kernels/impl/kron_grad_kernel_impl.h b/paddle/phi/kernels/impl/kron_grad_kernel_impl.h index 30297b53eabb99c4fcccc5c3c7faa04f86d4bb93..4829ae0a9f0c954dc366ac8c06c022fa0bf43d39 100644 --- a/paddle/phi/kernels/impl/kron_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/kron_grad_kernel_impl.h @@ -20,15 +20,15 @@ namespace phi { template struct KronGradElemFunctor { - KronGradElemFunctor(const T* dout, - const T* A, - const T* B, - T* dout_a, - T* dout_b, - const int64_t* stride_dout, - const int64_t* stride_a, - const int64_t* stride_b, - const int64_t* shape_b, + KronGradElemFunctor(const T *dout, + const T *A, + const T *B, + T *dout_a, + T *dout_b, + const int64_t *stride_dout, + const int64_t *stride_a, + const int64_t *stride_b, + const int64_t *shape_b, const int64_t numel_a, const int64_t numel_b, const int ndims) @@ -69,15 +69,15 @@ struct KronGradElemFunctor { } private: - const T* dout_; - const T* A_; - const T* B_; - T* dout_a_; - T* dout_b_; - const int64_t* stride_dout_; - const int64_t* stride_a_; - const int64_t* stride_b_; - const int64_t* shape_b_; + const T *dout_; + const T *A_; + const T *B_; + T *dout_a_; + T *dout_b_; + const int64_t *stride_dout_; + const int64_t *stride_a_; + const int64_t *stride_b_; + const int64_t *shape_b_; const int64_t numel_a_; const int64_t numel_b_; const int ndims_; @@ -85,15 +85,15 @@ struct KronGradElemFunctor { template struct KronGradElemFunctor> { - KronGradElemFunctor(const dtype::complex* dout, - const dtype::complex* A, - const dtype::complex* B, - dtype::complex* dout_a, - dtype::complex* dout_b, - const int64_t* stride_dout, - const int64_t* stride_a, - const int64_t* stride_b, - const int64_t* shape_b, + KronGradElemFunctor(const dtype::complex *dout, + const dtype::complex *A, + const dtype::complex *B, + dtype::complex *dout_a, + dtype::complex *dout_b, + const int64_t *stride_dout, + const int64_t *stride_a, + const int64_t *stride_b, + const int64_t *shape_b, const int64_t numel_a, const int64_t numel_b, const int ndims) @@ -136,15 +136,15 @@ struct KronGradElemFunctor> { } private: - const dtype::complex* dout_; - const dtype::complex* A_; - const dtype::complex* B_; - dtype::complex* dout_a_; - dtype::complex* dout_b_; - const int64_t* stride_dout_; - const int64_t* stride_a_; - const int64_t* stride_b_; - const int64_t* shape_b_; + const dtype::complex *dout_; + const dtype::complex *A_; + const dtype::complex *B_; + dtype::complex *dout_a_; + dtype::complex *dout_b_; + const int64_t *stride_dout_; + const int64_t *stride_a_; + const int64_t *stride_b_; + const int64_t *shape_b_; const int64_t numel_a_; const int64_t numel_b_; const int ndims_; @@ -152,29 +152,31 @@ struct KronGradElemFunctor> { template struct KronGradOpFunctor { - void operator()(const Context& dev_ctx, - const DenseTensor& dout, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* dx, - DenseTensor* dy) { + void operator()(const Context &dev_ctx, + const DenseTensor &dout, + const DenseTensor &x, + const DenseTensor &y, + DenseTensor *dx, + DenseTensor *dy) { int ndims = dout.dims().size(); int64_t numel = dout.numel(); int64_t numel_x = x.numel(); int64_t numel_y = y.numel(); - const phi::DDim& dim_x = x.dims(); - const phi::DDim& dim_y = y.dims(); - const phi::DDim& dim_dout = dout.dims(); + const phi::DDim &dim_x = x.dims(); + const phi::DDim &dim_y = y.dims(); + const phi::DDim &dim_dout = dout.dims(); + const phi::DDim stride_x = + dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x); + const phi::DDim stride_y = + dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y); + const phi::DDim stride_dout = + dim_dout.size() == 0 ? phi::DDim(dim_dout) : phi::stride(dim_dout); - const phi::DDim stride_x = phi::stride(dim_x); - const phi::DDim stride_y = phi::stride(dim_y); - const phi::DDim stride_dout = phi::stride(dim_dout); - - const int64_t* p_stride_x = nullptr; - const int64_t* p_stride_y = nullptr; - const int64_t* p_stride_dout = nullptr; - const int64_t* p_shape_y = nullptr; + const int64_t *p_stride_x = nullptr; + const int64_t *p_stride_y = nullptr; + const int64_t *p_stride_dout = nullptr; + const int64_t *p_shape_y = nullptr; #if defined(__NVCC__) || defined(__HIPCC__) thrust::device_vector d_stride_x(ndims); thrust::device_vector d_stride_y(ndims); @@ -199,14 +201,14 @@ struct KronGradOpFunctor { // dout_x: dout * kron(ones(X), Y) re-aranged in shape (numel_x, numel_y) // dout_y: dout * kron(X, ones(Y)) re-aranged in shaoe (numel_y, numel_x) DenseTensor dout_x; - T* p_dout_x = nullptr; + T *p_dout_x = nullptr; if (dx) { dout_x.Resize({numel_x, numel_y}); dev_ctx.template Alloc(&dout_x); p_dout_x = dout_x.data(); } DenseTensor dout_y; - T* p_dout_y = nullptr; + T *p_dout_y = nullptr; if (dy) { dout_y.Resize({numel_y, numel_x}); dev_ctx.template Alloc(&dout_y); @@ -240,7 +242,7 @@ struct KronGradOpFunctor { dev_ctx, dout_y, dy, kps::IdentityFunctor(), {1}); } #else - auto* place = dev_ctx.eigen_device(); + auto *place = dev_ctx.eigen_device(); Eigen::array reduce_dim = {1}; if (dx) { auto eigen_dout_x = EigenMatrix::Reshape(dout_x, 1); @@ -257,12 +259,12 @@ struct KronGradOpFunctor { }; template -void KronGradKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - DenseTensor* x_grad, - DenseTensor* y_grad) { +void KronGradKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out_grad, + DenseTensor *x_grad, + DenseTensor *y_grad) { if (x_grad) { ctx.template Alloc(x_grad); } @@ -274,8 +276,8 @@ void KronGradKernel(const Context& ctx, DenseTensor xx = UnsqueezeTo(x, ndims); DenseTensor yy = UnsqueezeTo(y, ndims); - DenseTensor* pdxx = nullptr; - DenseTensor* pdyy = nullptr; + DenseTensor *pdxx = nullptr; + DenseTensor *pdyy = nullptr; DenseTensor dxx; DenseTensor dyy; if (x_grad) { diff --git a/paddle/phi/kernels/impl/kron_kernel_impl.h b/paddle/phi/kernels/impl/kron_kernel_impl.h index 47c76f59df23bfee68a2660b76a09df747048378..e1fcb49949a748493a500b7c7a753480efbbc0d3 100644 --- a/paddle/phi/kernels/impl/kron_kernel_impl.h +++ b/paddle/phi/kernels/impl/kron_kernel_impl.h @@ -27,8 +27,8 @@ namespace phi { -inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) { - const phi::DDim& shape = src.dims(); +inline DenseTensor UnsqueezeTo(const DenseTensor &src, int ndims) { + const phi::DDim &shape = src.dims(); int rank = shape.size(); DenseTensor res; res.ShareDataWith(src); @@ -52,13 +52,13 @@ inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) { template struct KronElemFunctor { - KronElemFunctor(const T* a, - const T* b, - T* out, - const int64_t* shape_b, - const int64_t* stride_a, - const int64_t* stride_b, - const int64_t* stride_out, + KronElemFunctor(const T *a, + const T *b, + T *out, + const int64_t *shape_b, + const int64_t *stride_a, + const int64_t *stride_b, + const int64_t *stride_out, int ndims) : a_(a), b_(b), @@ -86,31 +86,34 @@ struct KronElemFunctor { } private: - const T* a_; - const T* b_; - T* out_; - const int64_t* shape_b_; - const int64_t* stride_a_; - const int64_t* stride_b_; - const int64_t* stride_out_; + const T *a_; + const T *b_; + T *out_; + const int64_t *shape_b_; + const int64_t *stride_a_; + const int64_t *stride_b_; + const int64_t *stride_out_; const int ndims_; }; template struct KronOpFunctor { - void operator()(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { + void operator()(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + DenseTensor *out) { int ndims = out->dims().size(); int64_t numel = out->numel(); - const phi::DDim& dim_x = x.dims(); - const phi::DDim& dim_y = y.dims(); - const phi::DDim& dim_out = out->dims(); - const phi::DDim stride_x = phi::stride(dim_x); - const phi::DDim stride_y = phi::stride(dim_y); - const phi::DDim stride_out = phi::stride(dim_out); + const phi::DDim &dim_x = x.dims(); + const phi::DDim &dim_y = y.dims(); + const phi::DDim &dim_out = out->dims(); + const phi::DDim stride_x = + dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x); + const phi::DDim stride_y = + dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y); + const phi::DDim stride_out = + dim_out.size() == 0 ? phi::DDim(dim_out) : phi::stride(dim_out); const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr, *p_stride_out = nullptr, *p_shape_y = nullptr; @@ -150,10 +153,10 @@ struct KronOpFunctor { }; template -void KronKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { +void KronKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &y, + DenseTensor *out) { ctx.template Alloc(out); int ndims = out->dims().size(); diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 5d9de5dd1b64d667505bd4c317a52f2ed30a81ac..0b3a6c20ec1a9ab6a89c2a53b3479b29760624f2 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -293,6 +293,7 @@ binary_api_list = [ paddle.fmax, paddle.fmin, paddle.complex, + paddle.kron, ] binary_int_api_list = [