未验证 提交 e012930a 编写于 作者: C chentianyu03 提交者: GitHub

complex gradient matmul (#29966)

* dot op support complex types

* matmul support complex types

* add test case

* matmul broadcast gradient support complex

* move conjFunctor to complex_functor.h
上级 b0bd93de
......@@ -17,49 +17,13 @@
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
using EnableComplex =
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value>::type;
template <typename T, typename Enable = void>
struct ConjFunctor;
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[idx] = T(input_[idx].real, -input_[idx].imag);
}
const T* input_;
int64_t numel_;
T* output_;
};
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; }
const T* input_;
int64_t numel_;
T* output_;
};
template <typename DeviceContext, typename T>
class ConjKernel : public framework::OpKernel<T> {
public:
......@@ -74,7 +38,7 @@ class ConjKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
ConjFunctor<T> functor(x_data, numel, out_data);
math::ConjFunctor<T> functor(x_data, numel, out_data);
for_range(functor);
}
};
......
......@@ -152,9 +152,17 @@ REGISTER_OP_CPU_KERNEL(
dot, ops::DotKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -17,12 +17,17 @@
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(dot, ops::DotKernel<plat::CUDADeviceContext, float>,
ops::DotKernel<plat::CUDADeviceContext, double>,
ops::DotKernel<plat::CUDADeviceContext, int>,
ops::DotKernel<plat::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(dot_grad,
ops::DotGradKernel<plat::CUDADeviceContext, float>,
ops::DotGradKernel<plat::CUDADeviceContext, double>,
ops::DotGradKernel<plat::CUDADeviceContext, int>,
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
dot, ops::DotKernel<plat::CUDADeviceContext, float>,
ops::DotKernel<plat::CUDADeviceContext, double>,
ops::DotKernel<plat::CUDADeviceContext, int>,
ops::DotKernel<plat::CUDADeviceContext, int64_t>,
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
dot_grad, ops::DotGradKernel<plat::CUDADeviceContext, float>,
ops::DotGradKernel<plat::CUDADeviceContext, double>,
ops::DotGradKernel<plat::CUDADeviceContext, int>,
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
......@@ -16,95 +16,233 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using complex64 = platform::complex64;
using complex128 = platform::complex128;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, typename R>
struct P {
void operator()(T a, R b);
};
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotGradFunction {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = EigenMatrix<T>::From(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_x->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = EigenMatrix<T>::From(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
}
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
}
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
}
}
}
#endif
}
}
};
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
......@@ -165,8 +303,8 @@ class DotGradKernel : public framework::OpKernel<T> {
if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
DotGradFunction<DeviceContext, T>(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
}
};
......
......@@ -135,6 +135,43 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
int64_t numel_;
};
template <typename T>
using EnableComplex =
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value>::type;
template <typename T, typename Enable = void>
struct ConjFunctor;
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[idx] = T(input_[idx].real, -input_[idx].imag);
}
const T* input_;
int64_t numel_;
T* output_;
};
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; }
const T* input_;
int64_t numel_;
T* output_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/dot_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#ifdef __NVCC__
......@@ -468,6 +469,61 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
template <typename DeviceContext, typename T>
struct ConjHelper {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
dst.set_layout(src.layout());
dst.ShareDataWith(src);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex64> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex64>();
auto* dst_data = dst.mutable_data<paddle::platform::complex64>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex64)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex64> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex128> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex128>();
auto* dst_data = dst.mutable_data<paddle::platform::complex128>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex128)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex128> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
......@@ -519,6 +575,8 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto x = *ctx.Input<framework::Tensor>("X");
auto y = *ctx.Input<framework::Tensor>("Y");
auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor y_conj(y.type());
framework::Tensor x_conj(y.type());
// get dims
std::vector<std::int64_t> x_dims = vectorize(x.dims());
......@@ -537,7 +595,7 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
if (dout.numel() == 1) {
DotGradFunction<DeviceContext, T>(&x, &y, &dout, dx, dy, ctx);
DotGradFunction<DeviceContext, T>()(&x, &y, &dout, dx, dy, ctx);
return;
}
}
......@@ -562,6 +620,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
// for complex
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(y, y_conj);
}
framework::DDim dy_dims;
......@@ -570,19 +632,23 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
// for complex
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
}
if (transpose_x && transpose_y) {
CalcInputGrad(ctx, y, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x, true, false, dy);
CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy);
} else if (transpose_x) {
CalcInputGrad(ctx, y, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x, false, false, dout, false, true, dy);
CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy);
} else if (transpose_y) {
CalcInputGrad(ctx, dout, false, false, y, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x, false, true, dy);
CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy);
} else {
CalcInputGrad(ctx, dout, false, false, y, true, false, dx);
CalcInputGrad(ctx, x, true, true, dout, false, true, dy);
CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx);
CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy);
}
if (dx) {
......@@ -602,40 +668,44 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
Tensor dx_help, dy_help;
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
if (transpose_x) {
if (transpose_y) {
// X'Y': dA = Y'G', dB = G'X'
if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, true, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, true, ctx);
} else {
// X'Y: dX = YG', dY = XG
if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, false, false, ctx);
}
} else {
if (transpose_y) {
// XY': dX = GY, dY = G'X
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, false, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, false, ctx);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, true, false, ctx);
}
}
......
......@@ -101,5 +101,127 @@ class TestDygraph(unittest.TestCase):
paddle.dot(x1, y1).numpy(), np.array([[17], [58]])))
class TestComplexDotOp(OpTest):
def setUp(self):
self.op_type = "dot"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(100).astype(
self.dtype) + 1J * np.random.random(100).astype(self.dtype)
self.y = np.random.random(100).astype(
self.dtype) + 1J * np.random.random(100).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones(1, self.dtype) + 1J * np.ones(1, self.dtype)
self.grad_x = self.grad_out * np.conj(self.y)
self.grad_y = self.grad_out * np.conj(self.x)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
class TestComplexDotOp2D(OpTest):
def setUp(self):
self.op_type = "dot"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(2, 100)).astype(self.dtype) + 1J * np.random.random(
(2, 100)).astype(self.dtype)
self.y = np.random.random(
(2, 100)).astype(self.dtype) + 1J * np.random.random(
(2, 100)).astype(self.dtype)
self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1)
def init_grad_input_output(self):
self.grad_out = np.ones((2, 1), self.dtype) + 1J * np.ones(
(2, 1), self.dtype)
self.grad_x = self._get_grad(self.grad_out, self.y)
self.grad_y = self._get_grad(self.grad_out, self.x)
def _get_grad(self, grad_out, input):
grad = np.empty((0, input.shape[1]))
for i in range(grad_out.shape[0]):
grad = np.append(grad, [grad_out[i] * np.conj(input[i])], axis=0)
return grad
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -405,5 +405,126 @@ class TestMatMulV2API(unittest.TestCase):
result = paddle.matmul(x, y)
class TestComplexMatMulOp(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(10, 10)).astype(self.dtype) + 1J * np.random.random(
(10, 10)).astype(self.dtype)
self.y = np.random.random(
(10, 10)).astype(self.dtype) + 1J * np.random.random(
(10, 10)).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones(
(10, 10), self.dtype)
self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
class TestComplexMatMulOpBroadcast(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(10, 2, 5)).astype(self.dtype) + 1J * np.random.random(
(10, 2, 5)).astype(self.dtype)
self.y = np.random.random(
(5, 20)).astype(self.dtype) + 1J * np.random.random(
(5, 20)).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones((10, 2, 20), self.dtype) + 1J * np.ones(
(10, 2, 20), self.dtype)
self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
self.grad_y = np.sum(np.matmul(
np.conj(self.x).transpose(0, 2, 1), self.grad_out),
axis=0)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -59,6 +59,7 @@ NEED_TO_FIX_OP_LIST = [
'lstmp',
'margin_rank_loss',
'matmul',
'matmul_v2',
'mul',
'multiplex',
'rank_loss',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册