未验证 提交 e012930a 编写于 作者: C chentianyu03 提交者: GitHub

complex gradient matmul (#29966)

* dot op support complex types

* matmul support complex types

* add test case

* matmul broadcast gradient support complex

* move conjFunctor to complex_functor.h
上级 b0bd93de
...@@ -17,49 +17,13 @@ ...@@ -17,49 +17,13 @@
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T>
using EnableComplex =
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value>::type;
template <typename T, typename Enable = void>
struct ConjFunctor;
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[idx] = T(input_[idx].real, -input_[idx].imag);
}
const T* input_;
int64_t numel_;
T* output_;
};
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; }
const T* input_;
int64_t numel_;
T* output_;
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ConjKernel : public framework::OpKernel<T> { class ConjKernel : public framework::OpKernel<T> {
public: public:
...@@ -74,7 +38,7 @@ class ConjKernel : public framework::OpKernel<T> { ...@@ -74,7 +38,7 @@ class ConjKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel); platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
ConjFunctor<T> functor(x_data, numel, out_data); math::ConjFunctor<T> functor(x_data, numel, out_data);
for_range(functor); for_range(functor);
} }
}; };
......
...@@ -152,9 +152,17 @@ REGISTER_OP_CPU_KERNEL( ...@@ -152,9 +152,17 @@ REGISTER_OP_CPU_KERNEL(
dot, ops::DotKernel<paddle::platform::CPUDeviceContext, float>, dot, ops::DotKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotKernel<paddle::platform::CPUDeviceContext, double>, ops::DotKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int>, ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>, dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>, ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>, ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
...@@ -17,12 +17,17 @@ ...@@ -17,12 +17,17 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(dot, ops::DotKernel<plat::CUDADeviceContext, float>, REGISTER_OP_CUDA_KERNEL(
ops::DotKernel<plat::CUDADeviceContext, double>, dot, ops::DotKernel<plat::CUDADeviceContext, float>,
ops::DotKernel<plat::CUDADeviceContext, int>, ops::DotKernel<plat::CUDADeviceContext, double>,
ops::DotKernel<plat::CUDADeviceContext, int64_t>); ops::DotKernel<plat::CUDADeviceContext, int>,
REGISTER_OP_CUDA_KERNEL(dot_grad, ops::DotKernel<plat::CUDADeviceContext, int64_t>,
ops::DotGradKernel<plat::CUDADeviceContext, float>, ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
ops::DotGradKernel<plat::CUDADeviceContext, double>, ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
ops::DotGradKernel<plat::CUDADeviceContext, int>, REGISTER_OP_CUDA_KERNEL(
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>); dot_grad, ops::DotGradKernel<plat::CUDADeviceContext, float>,
ops::DotGradKernel<plat::CUDADeviceContext, double>,
ops::DotGradKernel<plat::CUDADeviceContext, int>,
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
...@@ -16,95 +16,233 @@ ...@@ -16,95 +16,233 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using complex64 = platform::complex64;
using complex128 = platform::complex128;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, typename R>
struct P {
void operator()(T a, R b);
};
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotGradFunction {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y, struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
const Tensor* tensor_dout, Tensor* tensor_dx, void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
Tensor* tensor_dy, const Tensor* tensor_dout, Tensor* tensor_dx,
const paddle::framework::ExecutionContext& ctx) { Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__ #ifdef __NVCC__
if (1 == tensor_dout->dims().size()) { if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout); auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) { if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y); auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx); auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device(); auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel()); Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) { paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
auto x = framework::EigenVector<T>::Flatten(*tensor_x); tensor_y->numel());
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy); math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device(); tensor_dx->data<T>());
Eigen::DSizes<int, 1> size(tensor_dy->numel()); for_range(functor);
dy.device(dev) = x * dout.broadcast(size); auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = EigenMatrix<T>::From(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_x->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = EigenMatrix<T>::From(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
} }
} else { #else
auto dout = EigenMatrix<T>::From(*tensor_dout); const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) { if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace()); auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y); const auto* data_y = tensor_y->data<T>();
auto dx = EigenMatrix<T>::From(*tensor_dx); const framework::DDim& dim = tensor_x->dims();
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device(); size_t N = static_cast<size_t>(framework::product(dim));
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size); auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
}
} }
if (tensor_dy) { if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace()); auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x); const auto* data_x = tensor_x->data<T>();
auto dy = EigenMatrix<T>::From(*tensor_dy); const framework::DDim& dim = tensor_y->dims();
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device(); size_t N = static_cast<size_t>(framework::product(dim));
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size); auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
}
} }
#endif
} }
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else #else
const auto* data_dout = tensor_dout->data<T>(); const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) { if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace()); auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>(); const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims(); const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim)); size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1]; auto step = dim[dim.size() - 1];
int s = -1; int s = -1;
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s; if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s]; data_dx[i] = data_y[i] * data_dout[s];
}
} }
}
if (tensor_dy) { if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace()); auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>(); const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims(); const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim)); size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1]; auto step = dim[dim.size() - 1];
int s = -1; int s = -1;
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s; if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s]; data_dy[i] = data_x[i] * data_dout[s];
}
} }
}
#endif #endif
} }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> { class DotKernel : public framework::OpKernel<T> {
...@@ -165,8 +303,8 @@ class DotGradKernel : public framework::OpKernel<T> { ...@@ -165,8 +303,8 @@ class DotGradKernel : public framework::OpKernel<T> {
if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace()); if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace()); if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
DotGradFunction<DeviceContext, T>(tensor_x, tensor_y, tensor_dout, DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx); tensor_dx, tensor_dy, ctx);
} }
}; };
......
...@@ -135,6 +135,43 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> { ...@@ -135,6 +135,43 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
int64_t numel_; int64_t numel_;
}; };
template <typename T>
using EnableComplex =
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value>::type;
template <typename T>
using DisableComplex = typename std::enable_if<
!std::is_same<T, platform::complex64>::value &&
!std::is_same<T, platform::complex128>::value>::type;
template <typename T, typename Enable = void>
struct ConjFunctor;
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[idx] = T(input_[idx].real, -input_[idx].imag);
}
const T* input_;
int64_t numel_;
T* output_;
};
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
ConjFunctor(const T* input, int64_t numel, T* output)
: input_(input), numel_(numel), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const { output_[idx] = input_[idx]; }
const T* input_;
int64_t numel_;
T* output_;
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/dot_op.h" #include "paddle/fluid/operators/dot_op.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#ifdef __NVCC__ #ifdef __NVCC__
...@@ -468,6 +469,61 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, ...@@ -468,6 +469,61 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
ReshapeTensorIntoMatrixSequence(y, mat_dim_y); ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
} }
template <typename DeviceContext, typename T>
struct ConjHelper {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
dst.set_layout(src.layout());
dst.ShareDataWith(src);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex64> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex64>();
auto* dst_data = dst.mutable_data<paddle::platform::complex64>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex64)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex64> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex128> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex128>();
auto* dst_data = dst.mutable_data<paddle::platform::complex128>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex128)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex128> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> { class MatMulV2GradKernel : public framework::OpKernel<T> {
public: public:
...@@ -519,6 +575,8 @@ class MatMulV2GradKernel : public framework::OpKernel<T> { ...@@ -519,6 +575,8 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto x = *ctx.Input<framework::Tensor>("X"); auto x = *ctx.Input<framework::Tensor>("X");
auto y = *ctx.Input<framework::Tensor>("Y"); auto y = *ctx.Input<framework::Tensor>("Y");
auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor y_conj(y.type());
framework::Tensor x_conj(y.type());
// get dims // get dims
std::vector<std::int64_t> x_dims = vectorize(x.dims()); std::vector<std::int64_t> x_dims = vectorize(x.dims());
...@@ -537,7 +595,7 @@ class MatMulV2GradKernel : public framework::OpKernel<T> { ...@@ -537,7 +595,7 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx) dx->mutable_data<T>(ctx.GetPlace()); if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace()); if (dy) dy->mutable_data<T>(ctx.GetPlace());
if (dout.numel() == 1) { if (dout.numel() == 1) {
DotGradFunction<DeviceContext, T>(&x, &y, &dout, dx, dy, ctx); DotGradFunction<DeviceContext, T>()(&x, &y, &dout, dx, dy, ctx);
return; return;
} }
} }
...@@ -562,6 +620,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> { ...@@ -562,6 +620,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx_dims != x.dims()) { if (dx_dims != x.dims()) {
dx->Resize(x.dims()); dx->Resize(x.dims());
} }
// for complex
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(y, y_conj);
} }
framework::DDim dy_dims; framework::DDim dy_dims;
...@@ -570,19 +632,23 @@ class MatMulV2GradKernel : public framework::OpKernel<T> { ...@@ -570,19 +632,23 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dy_dims != y.dims()) { if (dy_dims != y.dims()) {
dy->Resize(y.dims()); dy->Resize(y.dims());
} }
// for complex
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
} }
if (transpose_x && transpose_y) { if (transpose_x && transpose_y) {
CalcInputGrad(ctx, y, true, true, dout, true, false, dx); CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x, true, false, dy); CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy);
} else if (transpose_x) { } else if (transpose_x) {
CalcInputGrad(ctx, y, false, false, dout, true, false, dx); CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x, false, false, dout, false, true, dy); CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy);
} else if (transpose_y) { } else if (transpose_y) {
CalcInputGrad(ctx, dout, false, false, y, false, true, dx); CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x, false, true, dy); CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy);
} else { } else {
CalcInputGrad(ctx, dout, false, false, y, true, false, dx); CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx);
CalcInputGrad(ctx, x, true, true, dout, false, true, dy); CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy);
} }
if (dx) { if (dx) {
...@@ -602,40 +668,44 @@ class MatMulV2GradKernel : public framework::OpKernel<T> { ...@@ -602,40 +668,44 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
VLOG(3) << "It need cost much time to reduce sum for the broadcast and " VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality"; "wastes the memory. So we should avoid the case in reality";
Tensor dx_help, dy_help; Tensor dx_help, dy_help;
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
if (transpose_x) { if (transpose_x) {
if (transpose_y) { if (transpose_y) {
// X'Y': dA = Y'G', dB = G'X' // X'Y': dA = Y'G', dB = G'X'
if (dx) if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, true, true, ctx); &dx_help, true, true, ctx);
if (dy) if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, true, ctx); &dy_help, true, true, ctx);
} else { } else {
// X'Y: dX = YG', dY = XG // X'Y: dX = YG', dY = XG
if (dx) if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, false, true, ctx); &dx_help, false, true, ctx);
if (dy) if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, false, false, ctx); &dy_help, false, false, ctx);
} }
} else { } else {
if (transpose_y) { if (transpose_y) {
// XY': dX = GY, dY = G'X // XY': dX = GY, dY = G'X
if (dx) if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, false, ctx); &dx_help, false, false, ctx);
if (dy) if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, false, ctx); &dy_help, true, false, ctx);
} else { } else {
// XY: dX = GY', dY = X'G // XY: dX = GY', dY = X'G
if (dx) if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, true, ctx); &dx_help, false, true, ctx);
if (dy) if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, true, false, ctx); &dy_help, true, false, ctx);
} }
} }
......
...@@ -101,5 +101,127 @@ class TestDygraph(unittest.TestCase): ...@@ -101,5 +101,127 @@ class TestDygraph(unittest.TestCase):
paddle.dot(x1, y1).numpy(), np.array([[17], [58]]))) paddle.dot(x1, y1).numpy(), np.array([[17], [58]])))
class TestComplexDotOp(OpTest):
def setUp(self):
self.op_type = "dot"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(100).astype(
self.dtype) + 1J * np.random.random(100).astype(self.dtype)
self.y = np.random.random(100).astype(
self.dtype) + 1J * np.random.random(100).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones(1, self.dtype) + 1J * np.ones(1, self.dtype)
self.grad_x = self.grad_out * np.conj(self.y)
self.grad_y = self.grad_out * np.conj(self.x)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
class TestComplexDotOp2D(OpTest):
def setUp(self):
self.op_type = "dot"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(2, 100)).astype(self.dtype) + 1J * np.random.random(
(2, 100)).astype(self.dtype)
self.y = np.random.random(
(2, 100)).astype(self.dtype) + 1J * np.random.random(
(2, 100)).astype(self.dtype)
self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1)
def init_grad_input_output(self):
self.grad_out = np.ones((2, 1), self.dtype) + 1J * np.ones(
(2, 1), self.dtype)
self.grad_x = self._get_grad(self.grad_out, self.y)
self.grad_y = self._get_grad(self.grad_out, self.x)
def _get_grad(self, grad_out, input):
grad = np.empty((0, input.shape[1]))
for i in range(grad_out.shape[0]):
grad = np.append(grad, [grad_out[i] * np.conj(input[i])], axis=0)
return grad
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
...@@ -405,5 +405,126 @@ class TestMatMulV2API(unittest.TestCase): ...@@ -405,5 +405,126 @@ class TestMatMulV2API(unittest.TestCase):
result = paddle.matmul(x, y) result = paddle.matmul(x, y)
class TestComplexMatMulOp(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(10, 10)).astype(self.dtype) + 1J * np.random.random(
(10, 10)).astype(self.dtype)
self.y = np.random.random(
(10, 10)).astype(self.dtype) + 1J * np.random.random(
(10, 10)).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones((10, 10), self.dtype) + 1J * np.ones(
(10, 10), self.dtype)
self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
class TestComplexMatMulOpBroadcast(OpTest):
def setUp(self):
self.op_type = "matmul_v2"
self.init_base_dtype()
self.init_input_output()
self.init_grad_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': -1, 'use_mkldnn': False}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self):
self.x = np.random.random(
(10, 2, 5)).astype(self.dtype) + 1J * np.random.random(
(10, 2, 5)).astype(self.dtype)
self.y = np.random.random(
(5, 20)).astype(self.dtype) + 1J * np.random.random(
(5, 20)).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def init_grad_input_output(self):
self.grad_out = np.ones((10, 2, 20), self.dtype) + 1J * np.ones(
(10, 2, 20), self.dtype)
self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
self.grad_y = np.sum(np.matmul(
np.conj(self.x).transpose(0, 2, 1), self.grad_out),
axis=0)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[self.grad_x, self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.grad_y],
user_defined_grad_outputs=[self.grad_out])
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static()
unittest.main() unittest.main()
...@@ -59,6 +59,7 @@ NEED_TO_FIX_OP_LIST = [ ...@@ -59,6 +59,7 @@ NEED_TO_FIX_OP_LIST = [
'lstmp', 'lstmp',
'margin_rank_loss', 'margin_rank_loss',
'matmul', 'matmul',
'matmul_v2',
'mul', 'mul',
'multiplex', 'multiplex',
'rank_loss', 'rank_loss',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册