未验证 提交 ae622479 编写于 作者: G Guo Sheng 提交者: GitHub

[Zero-Dim] Support 0D for kron (#49847)

* Add unittest for zero-dim kron.

* Fix zero dim kron.
上级 af23efe0
...@@ -20,15 +20,15 @@ namespace phi { ...@@ -20,15 +20,15 @@ namespace phi {
template <typename T> template <typename T>
struct KronGradElemFunctor { struct KronGradElemFunctor {
KronGradElemFunctor(const T* dout, KronGradElemFunctor(const T *dout,
const T* A, const T *A,
const T* B, const T *B,
T* dout_a, T *dout_a,
T* dout_b, T *dout_b,
const int64_t* stride_dout, const int64_t *stride_dout,
const int64_t* stride_a, const int64_t *stride_a,
const int64_t* stride_b, const int64_t *stride_b,
const int64_t* shape_b, const int64_t *shape_b,
const int64_t numel_a, const int64_t numel_a,
const int64_t numel_b, const int64_t numel_b,
const int ndims) const int ndims)
...@@ -69,15 +69,15 @@ struct KronGradElemFunctor { ...@@ -69,15 +69,15 @@ struct KronGradElemFunctor {
} }
private: private:
const T* dout_; const T *dout_;
const T* A_; const T *A_;
const T* B_; const T *B_;
T* dout_a_; T *dout_a_;
T* dout_b_; T *dout_b_;
const int64_t* stride_dout_; const int64_t *stride_dout_;
const int64_t* stride_a_; const int64_t *stride_a_;
const int64_t* stride_b_; const int64_t *stride_b_;
const int64_t* shape_b_; const int64_t *shape_b_;
const int64_t numel_a_; const int64_t numel_a_;
const int64_t numel_b_; const int64_t numel_b_;
const int ndims_; const int ndims_;
...@@ -85,15 +85,15 @@ struct KronGradElemFunctor { ...@@ -85,15 +85,15 @@ struct KronGradElemFunctor {
template <typename T> template <typename T>
struct KronGradElemFunctor<dtype::complex<T>> { struct KronGradElemFunctor<dtype::complex<T>> {
KronGradElemFunctor(const dtype::complex<T>* dout, KronGradElemFunctor(const dtype::complex<T> *dout,
const dtype::complex<T>* A, const dtype::complex<T> *A,
const dtype::complex<T>* B, const dtype::complex<T> *B,
dtype::complex<T>* dout_a, dtype::complex<T> *dout_a,
dtype::complex<T>* dout_b, dtype::complex<T> *dout_b,
const int64_t* stride_dout, const int64_t *stride_dout,
const int64_t* stride_a, const int64_t *stride_a,
const int64_t* stride_b, const int64_t *stride_b,
const int64_t* shape_b, const int64_t *shape_b,
const int64_t numel_a, const int64_t numel_a,
const int64_t numel_b, const int64_t numel_b,
const int ndims) const int ndims)
...@@ -136,15 +136,15 @@ struct KronGradElemFunctor<dtype::complex<T>> { ...@@ -136,15 +136,15 @@ struct KronGradElemFunctor<dtype::complex<T>> {
} }
private: private:
const dtype::complex<T>* dout_; const dtype::complex<T> *dout_;
const dtype::complex<T>* A_; const dtype::complex<T> *A_;
const dtype::complex<T>* B_; const dtype::complex<T> *B_;
dtype::complex<T>* dout_a_; dtype::complex<T> *dout_a_;
dtype::complex<T>* dout_b_; dtype::complex<T> *dout_b_;
const int64_t* stride_dout_; const int64_t *stride_dout_;
const int64_t* stride_a_; const int64_t *stride_a_;
const int64_t* stride_b_; const int64_t *stride_b_;
const int64_t* shape_b_; const int64_t *shape_b_;
const int64_t numel_a_; const int64_t numel_a_;
const int64_t numel_b_; const int64_t numel_b_;
const int ndims_; const int ndims_;
...@@ -152,29 +152,31 @@ struct KronGradElemFunctor<dtype::complex<T>> { ...@@ -152,29 +152,31 @@ struct KronGradElemFunctor<dtype::complex<T>> {
template <typename Context, typename T> template <typename Context, typename T>
struct KronGradOpFunctor { struct KronGradOpFunctor {
void operator()(const Context& dev_ctx, void operator()(const Context &dev_ctx,
const DenseTensor& dout, const DenseTensor &dout,
const DenseTensor& x, const DenseTensor &x,
const DenseTensor& y, const DenseTensor &y,
DenseTensor* dx, DenseTensor *dx,
DenseTensor* dy) { DenseTensor *dy) {
int ndims = dout.dims().size(); int ndims = dout.dims().size();
int64_t numel = dout.numel(); int64_t numel = dout.numel();
int64_t numel_x = x.numel(); int64_t numel_x = x.numel();
int64_t numel_y = y.numel(); int64_t numel_y = y.numel();
const phi::DDim& dim_x = x.dims(); const phi::DDim &dim_x = x.dims();
const phi::DDim& dim_y = y.dims(); const phi::DDim &dim_y = y.dims();
const phi::DDim& dim_dout = dout.dims(); const phi::DDim &dim_dout = dout.dims();
const phi::DDim stride_x =
dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x);
const phi::DDim stride_y =
dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y);
const phi::DDim stride_dout =
dim_dout.size() == 0 ? phi::DDim(dim_dout) : phi::stride(dim_dout);
const phi::DDim stride_x = phi::stride(dim_x); const int64_t *p_stride_x = nullptr;
const phi::DDim stride_y = phi::stride(dim_y); const int64_t *p_stride_y = nullptr;
const phi::DDim stride_dout = phi::stride(dim_dout); const int64_t *p_stride_dout = nullptr;
const int64_t *p_shape_y = nullptr;
const int64_t* p_stride_x = nullptr;
const int64_t* p_stride_y = nullptr;
const int64_t* p_stride_dout = nullptr;
const int64_t* p_shape_y = nullptr;
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
thrust::device_vector<int64_t> d_stride_x(ndims); thrust::device_vector<int64_t> d_stride_x(ndims);
thrust::device_vector<int64_t> d_stride_y(ndims); thrust::device_vector<int64_t> d_stride_y(ndims);
...@@ -199,14 +201,14 @@ struct KronGradOpFunctor { ...@@ -199,14 +201,14 @@ struct KronGradOpFunctor {
// dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y) // dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
// dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x) // dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
DenseTensor dout_x; DenseTensor dout_x;
T* p_dout_x = nullptr; T *p_dout_x = nullptr;
if (dx) { if (dx) {
dout_x.Resize({numel_x, numel_y}); dout_x.Resize({numel_x, numel_y});
dev_ctx.template Alloc<T>(&dout_x); dev_ctx.template Alloc<T>(&dout_x);
p_dout_x = dout_x.data<T>(); p_dout_x = dout_x.data<T>();
} }
DenseTensor dout_y; DenseTensor dout_y;
T* p_dout_y = nullptr; T *p_dout_y = nullptr;
if (dy) { if (dy) {
dout_y.Resize({numel_y, numel_x}); dout_y.Resize({numel_y, numel_x});
dev_ctx.template Alloc<T>(&dout_y); dev_ctx.template Alloc<T>(&dout_y);
...@@ -240,7 +242,7 @@ struct KronGradOpFunctor { ...@@ -240,7 +242,7 @@ struct KronGradOpFunctor {
dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1}); dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1});
} }
#else #else
auto* place = dev_ctx.eigen_device(); auto *place = dev_ctx.eigen_device();
Eigen::array<int, 1> reduce_dim = {1}; Eigen::array<int, 1> reduce_dim = {1};
if (dx) { if (dx) {
auto eigen_dout_x = EigenMatrix<T>::Reshape(dout_x, 1); auto eigen_dout_x = EigenMatrix<T>::Reshape(dout_x, 1);
...@@ -257,12 +259,12 @@ struct KronGradOpFunctor { ...@@ -257,12 +259,12 @@ struct KronGradOpFunctor {
}; };
template <typename T, typename Context> template <typename T, typename Context>
void KronGradKernel(const Context& ctx, void KronGradKernel(const Context &ctx,
const DenseTensor& x, const DenseTensor &x,
const DenseTensor& y, const DenseTensor &y,
const DenseTensor& out_grad, const DenseTensor &out_grad,
DenseTensor* x_grad, DenseTensor *x_grad,
DenseTensor* y_grad) { DenseTensor *y_grad) {
if (x_grad) { if (x_grad) {
ctx.template Alloc<T>(x_grad); ctx.template Alloc<T>(x_grad);
} }
...@@ -274,8 +276,8 @@ void KronGradKernel(const Context& ctx, ...@@ -274,8 +276,8 @@ void KronGradKernel(const Context& ctx,
DenseTensor xx = UnsqueezeTo(x, ndims); DenseTensor xx = UnsqueezeTo(x, ndims);
DenseTensor yy = UnsqueezeTo(y, ndims); DenseTensor yy = UnsqueezeTo(y, ndims);
DenseTensor* pdxx = nullptr; DenseTensor *pdxx = nullptr;
DenseTensor* pdyy = nullptr; DenseTensor *pdyy = nullptr;
DenseTensor dxx; DenseTensor dxx;
DenseTensor dyy; DenseTensor dyy;
if (x_grad) { if (x_grad) {
......
...@@ -27,8 +27,8 @@ ...@@ -27,8 +27,8 @@
namespace phi { namespace phi {
inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) { inline DenseTensor UnsqueezeTo(const DenseTensor &src, int ndims) {
const phi::DDim& shape = src.dims(); const phi::DDim &shape = src.dims();
int rank = shape.size(); int rank = shape.size();
DenseTensor res; DenseTensor res;
res.ShareDataWith(src); res.ShareDataWith(src);
...@@ -52,13 +52,13 @@ inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) { ...@@ -52,13 +52,13 @@ inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) {
template <typename T> template <typename T>
struct KronElemFunctor { struct KronElemFunctor {
KronElemFunctor(const T* a, KronElemFunctor(const T *a,
const T* b, const T *b,
T* out, T *out,
const int64_t* shape_b, const int64_t *shape_b,
const int64_t* stride_a, const int64_t *stride_a,
const int64_t* stride_b, const int64_t *stride_b,
const int64_t* stride_out, const int64_t *stride_out,
int ndims) int ndims)
: a_(a), : a_(a),
b_(b), b_(b),
...@@ -86,31 +86,34 @@ struct KronElemFunctor { ...@@ -86,31 +86,34 @@ struct KronElemFunctor {
} }
private: private:
const T* a_; const T *a_;
const T* b_; const T *b_;
T* out_; T *out_;
const int64_t* shape_b_; const int64_t *shape_b_;
const int64_t* stride_a_; const int64_t *stride_a_;
const int64_t* stride_b_; const int64_t *stride_b_;
const int64_t* stride_out_; const int64_t *stride_out_;
const int ndims_; const int ndims_;
}; };
template <typename Context, typename T> template <typename Context, typename T>
struct KronOpFunctor { struct KronOpFunctor {
void operator()(const Context& dev_ctx, void operator()(const Context &dev_ctx,
const DenseTensor& x, const DenseTensor &x,
const DenseTensor& y, const DenseTensor &y,
DenseTensor* out) { DenseTensor *out) {
int ndims = out->dims().size(); int ndims = out->dims().size();
int64_t numel = out->numel(); int64_t numel = out->numel();
const phi::DDim& dim_x = x.dims(); const phi::DDim &dim_x = x.dims();
const phi::DDim& dim_y = y.dims(); const phi::DDim &dim_y = y.dims();
const phi::DDim& dim_out = out->dims(); const phi::DDim &dim_out = out->dims();
const phi::DDim stride_x = phi::stride(dim_x); const phi::DDim stride_x =
const phi::DDim stride_y = phi::stride(dim_y); dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x);
const phi::DDim stride_out = phi::stride(dim_out); const phi::DDim stride_y =
dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y);
const phi::DDim stride_out =
dim_out.size() == 0 ? phi::DDim(dim_out) : phi::stride(dim_out);
const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr, const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
*p_stride_out = nullptr, *p_shape_y = nullptr; *p_stride_out = nullptr, *p_shape_y = nullptr;
...@@ -150,10 +153,10 @@ struct KronOpFunctor { ...@@ -150,10 +153,10 @@ struct KronOpFunctor {
}; };
template <typename T, typename Context> template <typename T, typename Context>
void KronKernel(const Context& ctx, void KronKernel(const Context &ctx,
const DenseTensor& x, const DenseTensor &x,
const DenseTensor& y, const DenseTensor &y,
DenseTensor* out) { DenseTensor *out) {
ctx.template Alloc<T>(out); ctx.template Alloc<T>(out);
int ndims = out->dims().size(); int ndims = out->dims().size();
......
...@@ -293,6 +293,7 @@ binary_api_list = [ ...@@ -293,6 +293,7 @@ binary_api_list = [
paddle.fmax, paddle.fmax,
paddle.fmin, paddle.fmin,
paddle.complex, paddle.complex,
paddle.kron,
] ]
binary_int_api_list = [ binary_int_api_list = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册