未验证 提交 ae622479 编写于 作者: G Guo Sheng 提交者: GitHub

[Zero-Dim] Support 0D for kron (#49847)

* Add unittest for zero-dim kron.

* Fix zero dim kron.
上级 af23efe0
......@@ -20,15 +20,15 @@ namespace phi {
template <typename T>
struct KronGradElemFunctor {
KronGradElemFunctor(const T* dout,
const T* A,
const T* B,
T* dout_a,
T* dout_b,
const int64_t* stride_dout,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* shape_b,
KronGradElemFunctor(const T *dout,
const T *A,
const T *B,
T *dout_a,
T *dout_b,
const int64_t *stride_dout,
const int64_t *stride_a,
const int64_t *stride_b,
const int64_t *shape_b,
const int64_t numel_a,
const int64_t numel_b,
const int ndims)
......@@ -69,15 +69,15 @@ struct KronGradElemFunctor {
}
private:
const T* dout_;
const T* A_;
const T* B_;
T* dout_a_;
T* dout_b_;
const int64_t* stride_dout_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* shape_b_;
const T *dout_;
const T *A_;
const T *B_;
T *dout_a_;
T *dout_b_;
const int64_t *stride_dout_;
const int64_t *stride_a_;
const int64_t *stride_b_;
const int64_t *shape_b_;
const int64_t numel_a_;
const int64_t numel_b_;
const int ndims_;
......@@ -85,15 +85,15 @@ struct KronGradElemFunctor {
template <typename T>
struct KronGradElemFunctor<dtype::complex<T>> {
KronGradElemFunctor(const dtype::complex<T>* dout,
const dtype::complex<T>* A,
const dtype::complex<T>* B,
dtype::complex<T>* dout_a,
dtype::complex<T>* dout_b,
const int64_t* stride_dout,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* shape_b,
KronGradElemFunctor(const dtype::complex<T> *dout,
const dtype::complex<T> *A,
const dtype::complex<T> *B,
dtype::complex<T> *dout_a,
dtype::complex<T> *dout_b,
const int64_t *stride_dout,
const int64_t *stride_a,
const int64_t *stride_b,
const int64_t *shape_b,
const int64_t numel_a,
const int64_t numel_b,
const int ndims)
......@@ -136,15 +136,15 @@ struct KronGradElemFunctor<dtype::complex<T>> {
}
private:
const dtype::complex<T>* dout_;
const dtype::complex<T>* A_;
const dtype::complex<T>* B_;
dtype::complex<T>* dout_a_;
dtype::complex<T>* dout_b_;
const int64_t* stride_dout_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* shape_b_;
const dtype::complex<T> *dout_;
const dtype::complex<T> *A_;
const dtype::complex<T> *B_;
dtype::complex<T> *dout_a_;
dtype::complex<T> *dout_b_;
const int64_t *stride_dout_;
const int64_t *stride_a_;
const int64_t *stride_b_;
const int64_t *shape_b_;
const int64_t numel_a_;
const int64_t numel_b_;
const int ndims_;
......@@ -152,29 +152,31 @@ struct KronGradElemFunctor<dtype::complex<T>> {
template <typename Context, typename T>
struct KronGradOpFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& dout,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* dx,
DenseTensor* dy) {
void operator()(const Context &dev_ctx,
const DenseTensor &dout,
const DenseTensor &x,
const DenseTensor &y,
DenseTensor *dx,
DenseTensor *dy) {
int ndims = dout.dims().size();
int64_t numel = dout.numel();
int64_t numel_x = x.numel();
int64_t numel_y = y.numel();
const phi::DDim& dim_x = x.dims();
const phi::DDim& dim_y = y.dims();
const phi::DDim& dim_dout = dout.dims();
const phi::DDim &dim_x = x.dims();
const phi::DDim &dim_y = y.dims();
const phi::DDim &dim_dout = dout.dims();
const phi::DDim stride_x =
dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x);
const phi::DDim stride_y =
dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y);
const phi::DDim stride_dout =
dim_dout.size() == 0 ? phi::DDim(dim_dout) : phi::stride(dim_dout);
const phi::DDim stride_x = phi::stride(dim_x);
const phi::DDim stride_y = phi::stride(dim_y);
const phi::DDim stride_dout = phi::stride(dim_dout);
const int64_t* p_stride_x = nullptr;
const int64_t* p_stride_y = nullptr;
const int64_t* p_stride_dout = nullptr;
const int64_t* p_shape_y = nullptr;
const int64_t *p_stride_x = nullptr;
const int64_t *p_stride_y = nullptr;
const int64_t *p_stride_dout = nullptr;
const int64_t *p_shape_y = nullptr;
#if defined(__NVCC__) || defined(__HIPCC__)
thrust::device_vector<int64_t> d_stride_x(ndims);
thrust::device_vector<int64_t> d_stride_y(ndims);
......@@ -199,14 +201,14 @@ struct KronGradOpFunctor {
// dout_x: dout * kron(ones(X), Y) re-aranged in shape (numel_x, numel_y)
// dout_y: dout * kron(X, ones(Y)) re-aranged in shaoe (numel_y, numel_x)
DenseTensor dout_x;
T* p_dout_x = nullptr;
T *p_dout_x = nullptr;
if (dx) {
dout_x.Resize({numel_x, numel_y});
dev_ctx.template Alloc<T>(&dout_x);
p_dout_x = dout_x.data<T>();
}
DenseTensor dout_y;
T* p_dout_y = nullptr;
T *p_dout_y = nullptr;
if (dy) {
dout_y.Resize({numel_y, numel_x});
dev_ctx.template Alloc<T>(&dout_y);
......@@ -240,7 +242,7 @@ struct KronGradOpFunctor {
dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1});
}
#else
auto* place = dev_ctx.eigen_device();
auto *place = dev_ctx.eigen_device();
Eigen::array<int, 1> reduce_dim = {1};
if (dx) {
auto eigen_dout_x = EigenMatrix<T>::Reshape(dout_x, 1);
......@@ -257,12 +259,12 @@ struct KronGradOpFunctor {
};
template <typename T, typename Context>
void KronGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
void KronGradKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out_grad,
DenseTensor *x_grad,
DenseTensor *y_grad) {
if (x_grad) {
ctx.template Alloc<T>(x_grad);
}
......@@ -274,8 +276,8 @@ void KronGradKernel(const Context& ctx,
DenseTensor xx = UnsqueezeTo(x, ndims);
DenseTensor yy = UnsqueezeTo(y, ndims);
DenseTensor* pdxx = nullptr;
DenseTensor* pdyy = nullptr;
DenseTensor *pdxx = nullptr;
DenseTensor *pdyy = nullptr;
DenseTensor dxx;
DenseTensor dyy;
if (x_grad) {
......
......@@ -27,8 +27,8 @@
namespace phi {
inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) {
const phi::DDim& shape = src.dims();
inline DenseTensor UnsqueezeTo(const DenseTensor &src, int ndims) {
const phi::DDim &shape = src.dims();
int rank = shape.size();
DenseTensor res;
res.ShareDataWith(src);
......@@ -52,13 +52,13 @@ inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) {
template <typename T>
struct KronElemFunctor {
KronElemFunctor(const T* a,
const T* b,
T* out,
const int64_t* shape_b,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* stride_out,
KronElemFunctor(const T *a,
const T *b,
T *out,
const int64_t *shape_b,
const int64_t *stride_a,
const int64_t *stride_b,
const int64_t *stride_out,
int ndims)
: a_(a),
b_(b),
......@@ -86,31 +86,34 @@ struct KronElemFunctor {
}
private:
const T* a_;
const T* b_;
T* out_;
const int64_t* shape_b_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* stride_out_;
const T *a_;
const T *b_;
T *out_;
const int64_t *shape_b_;
const int64_t *stride_a_;
const int64_t *stride_b_;
const int64_t *stride_out_;
const int ndims_;
};
template <typename Context, typename T>
struct KronOpFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
void operator()(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
DenseTensor *out) {
int ndims = out->dims().size();
int64_t numel = out->numel();
const phi::DDim& dim_x = x.dims();
const phi::DDim& dim_y = y.dims();
const phi::DDim& dim_out = out->dims();
const phi::DDim stride_x = phi::stride(dim_x);
const phi::DDim stride_y = phi::stride(dim_y);
const phi::DDim stride_out = phi::stride(dim_out);
const phi::DDim &dim_x = x.dims();
const phi::DDim &dim_y = y.dims();
const phi::DDim &dim_out = out->dims();
const phi::DDim stride_x =
dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x);
const phi::DDim stride_y =
dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y);
const phi::DDim stride_out =
dim_out.size() == 0 ? phi::DDim(dim_out) : phi::stride(dim_out);
const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
*p_stride_out = nullptr, *p_shape_y = nullptr;
......@@ -150,10 +153,10 @@ struct KronOpFunctor {
};
template <typename T, typename Context>
void KronKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
void KronKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &y,
DenseTensor *out) {
ctx.template Alloc<T>(out);
int ndims = out->dims().size();
......
......@@ -293,6 +293,7 @@ binary_api_list = [
paddle.fmax,
paddle.fmin,
paddle.complex,
paddle.kron,
]
binary_int_api_list = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册