未验证 提交 7346edc2 编写于 作者: C chentianyu03 提交者: GitHub

[Cherry-pick] Complex grad for matmul, kron and type promotion (#30304)

* complex gradient matmul  (#29966)

* dot op support complex types

* matmul support complex types

* add test case

* matmul broadcast gradient support complex

* move conjFunctor to complex_functor.h

* change the kron gradient when complex types (#29995)

* type promotion for grad (#30177)

* type promotion for grad

* add type promotion for div op
上级 501b11de
......@@ -17,49 +17,13 @@
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// SFINAE helper: names `void` only when T is platform::complex64 or
// platform::complex128; for any other T the alias is ill-formed, which removes
// the partial specialization that uses it from consideration.
template <typename T>
using EnableComplex = typename std::enable_if<
    !(!std::is_same<T, platform::complex64>::value &&
      !std::is_same<T, platform::complex128>::value)>::type;

// SFINAE helper: the exact complement of EnableComplex. Names `void` only when
// T is neither platform::complex64 nor platform::complex128.
template <typename T>
using DisableComplex = typename std::enable_if<
    !(std::is_same<T, platform::complex64>::value ||
      std::is_same<T, platform::complex128>::value)>::type;
// Element-wise conjugate functor, invoked once per element index.
// The primary template is only declared; one of the two partial
// specializations below is selected through EnableComplex / DisableComplex.
template <typename T, typename Enable = void>
struct ConjFunctor;

// Complex element types: writes (real, -imag) of input[idx] to output[idx].
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
  ConjFunctor(const T* input, int64_t numel, T* output)
      : input_(input), numel_(numel), output_(output) {}

  HOSTDEVICE void operator()(size_t idx) const {
    const T& value = input_[idx];
    // The temporary is fully built before the store, so in == out is fine.
    output_[idx] = T(value.real, -value.imag);
  }

  const T* input_;
  int64_t numel_;  // stored but not read by operator()
  T* output_;
};

// Non-complex element types: plain element copy from input to output.
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
  ConjFunctor(const T* input, int64_t numel, T* output)
      : input_(input), numel_(numel), output_(output) {}

  HOSTDEVICE void operator()(size_t idx) const {
    output_[idx] = input_[idx];
  }

  const T* input_;
  int64_t numel_;  // stored but not read by operator()
  T* output_;
};
template <typename DeviceContext, typename T>
class ConjKernel : public framework::OpKernel<T> {
public:
......@@ -74,7 +38,7 @@ class ConjKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
ConjFunctor<T> functor(x_data, numel, out_data);
math::ConjFunctor<T> functor(x_data, numel, out_data);
for_range(functor);
}
};
......
......@@ -152,9 +152,17 @@ REGISTER_OP_CPU_KERNEL(
dot, ops::DotKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::DotKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -17,12 +17,17 @@
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Registers the CUDA forward kernels of the `dot` op for float, double, int
// and int64_t element types.
REGISTER_OP_CUDA_KERNEL(dot, ops::DotKernel<plat::CUDADeviceContext, float>,
ops::DotKernel<plat::CUDADeviceContext, double>,
ops::DotKernel<plat::CUDADeviceContext, int>,
ops::DotKernel<plat::CUDADeviceContext, int64_t>);
// Registers the CUDA gradient kernels (`dot_grad`) for the same four element
// types.
REGISTER_OP_CUDA_KERNEL(dot_grad,
ops::DotGradKernel<plat::CUDADeviceContext, float>,
ops::DotGradKernel<plat::CUDADeviceContext, double>,
ops::DotGradKernel<plat::CUDADeviceContext, int>,
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>);
// Registers the CUDA forward kernels of the `dot` op for the real element
// types plus complex64 / complex128.
REGISTER_OP_CUDA_KERNEL(
    dot,
    ops::DotKernel<plat::CUDADeviceContext, float>,
    ops::DotKernel<plat::CUDADeviceContext, double>,
    ops::DotKernel<plat::CUDADeviceContext, int>,
    ops::DotKernel<plat::CUDADeviceContext, int64_t>,
    ops::DotKernel<plat::CUDADeviceContext, plat::complex64>,
    ops::DotKernel<plat::CUDADeviceContext, plat::complex128>);

// Registers the CUDA gradient kernels (`dot_grad`) for the same element types.
REGISTER_OP_CUDA_KERNEL(
    dot_grad,
    ops::DotGradKernel<plat::CUDADeviceContext, float>,
    ops::DotGradKernel<plat::CUDADeviceContext, double>,
    ops::DotGradKernel<plat::CUDADeviceContext, int>,
    ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::DotGradKernel<plat::CUDADeviceContext, plat::complex64>,
    ops::DotGradKernel<plat::CUDADeviceContext, plat::complex128>);
......@@ -16,95 +16,233 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using complex64 = platform::complex64;
using complex128 = platform::complex128;
// Shorthand for framework::EigenMatrix whose layout defaults to
// Eigen::RowMajor and whose index type defaults to Eigen::DenseIndex.
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
// Binary functor template taking one value of type T and one of type R.
// Only the call operator's declaration is present here.
// NOTE(review): no definition or use of P is visible in this source — confirm
// whether it is still needed.
template <typename T, typename R>
struct P {
void operator()(T a, R b);
};
// Gradient functor for the dot op. This primary template only declares the
// call operator; the implementations live in the partial specializations
// selected through the trailing SFINAE parameter (math::EnableComplex<T> for
// complex element types, math::DisableComplex<T> for all others).
//
// Call-operator arguments:
//   tensor_x, tensor_y  forward inputs
//   tensor_dout         gradient flowing in from the output
//   tensor_dx, tensor_dy  gradient outputs; either may be null, in which case
//                         that gradient is not computed
//   ctx                 execution context providing the place / device context
template <typename DeviceContext, typename T, typename Enable = void>
struct DotGradFunction {
  void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
                  const Tensor* tensor_dout, Tensor* tensor_dx,
                  Tensor* tensor_dy,
                  const paddle::framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
void DotGradFunction(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = EigenMatrix<T>::From(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_x->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = EigenMatrix<T>::From(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
}
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#ifdef __NVCC__
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_y[i] * data_dout[s];
}
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_x[i] * data_dout[s];
}
}
}
#endif
}
}
};
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
......@@ -165,8 +303,8 @@ class DotGradKernel : public framework::OpKernel<T> {
if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
DotGradFunction<DeviceContext, T>(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
}
};
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
......@@ -203,7 +204,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
......@@ -214,6 +215,19 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
// Chooses the kernel type reported for one input variable.
//  - Expected kernel dtype is complex: report the tensor's own dtype, place
//    and layout, so real and complex inputs keep their distinct types here.
//    NOTE(review): presumably this lets the framework's data-transform step
//    promote real inputs to the complex kernel dtype — confirm against
//    OperatorWithKernel.
//  - Otherwise: report the expected dtype with the tensor's place and layout.
// Marked `override` so the compiler verifies the signature against the base
// class virtual.
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name, const framework::Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const override {
  if (framework::IsComplexType(expected_kernel_type.data_type_)) {
    // Only promote the inputs' types when a complex input is involved.
    return framework::OpKernelType(tensor.type(), tensor.place(),
                                   tensor.layout());
  } else {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
}
};
template <typename DeviceContext, typename T>
......
......@@ -288,6 +288,19 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
// Chooses the kernel type reported for one input variable.
// A non-complex expected dtype is reported as-is (with the tensor's place and
// layout); when the expected dtype is complex, the tensor's own dtype is
// reported instead, so input types are only promoted when a complex input is
// involved.
framework::OpKernelType GetKernelTypeForVar(
    const std::string &var_name, const framework::Tensor &tensor,
    const framework::OpKernelType &expected_kernel_type) const override {
  const auto &expected_dtype = expected_kernel_type.data_type_;
  if (!framework::IsComplexType(expected_dtype)) {
    return framework::OpKernelType(expected_dtype, tensor.place(),
                                   tensor.layout());
  }
  return framework::OpKernelType(tensor.type(), tensor.place(),
                                 tensor.layout());
}
};
class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
......@@ -325,6 +338,19 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
// Chooses the kernel type reported for one input variable.
//  - Expected kernel dtype is complex: report the tensor's own dtype, place
//    and layout, so real and complex inputs keep their distinct types here.
//    NOTE(review): presumably this lets the framework's data-transform step
//    promote real inputs to the complex kernel dtype — confirm against
//    OperatorWithKernel.
//  - Otherwise: report the expected dtype with the tensor's place and layout.
// Marked `override` so the compiler verifies the signature against the base
// class virtual.
framework::OpKernelType GetKernelTypeForVar(
    const std::string &var_name, const framework::Tensor &tensor,
    const framework::OpKernelType &expected_kernel_type) const override {
  if (framework::IsComplexType(expected_kernel_type.data_type_)) {
    // Only promote the inputs' types when a complex input is involved.
    return framework::OpKernelType(tensor.type(), tensor.place(),
                                   tensor.layout());
  } else {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
}
};
class ElementwiseOpDoubleGradWithoutDXDY
......
......@@ -134,6 +134,19 @@ class KronGradOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name),
ctx.GetPlace());
}
// Chooses the kernel type reported for one input variable.
//  - Expected kernel dtype is complex: report the tensor's own dtype, place
//    and layout, so real and complex inputs keep their distinct types here.
//    NOTE(review): presumably this lets the framework's data-transform step
//    promote real inputs to the complex kernel dtype — confirm against
//    OperatorWithKernel.
//  - Otherwise: report the expected dtype with the tensor's place and layout.
// Marked `override` so the compiler verifies the signature against the base
// class virtual.
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name, const framework::Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const override {
  if (framework::IsComplexType(expected_kernel_type.data_type_)) {
    // Only promote the inputs' types when a complex input is involved.
    return framework::OpKernelType(tensor.type(), tensor.place(),
                                   tensor.layout());
  } else {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
}
};
template <typename T>
......
......@@ -26,6 +26,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
// Process an element in the output, used with a parallel-for
template <typename T>
struct KronElemFunctor {
......@@ -172,6 +175,128 @@ struct KronGradElemFunctor {
const int ndims_;
};
// complex64 specialization of the per-element Kronecker-product gradient
// functor. For a flat index `idx` into dout, it recovers the matching flat
// offsets into A and B and writes the products
//   dout[idx] * conj(B[offset_b])  -> dout_a[offset_a * numel_b + offset_b]
//   dout[idx] * conj(A[offset_a])  -> dout_b[offset_b * numel_a + offset_a]
// Either output pointer may be null, in which case that product is skipped.
template <>
struct KronGradElemFunctor<complex64> {
  KronGradElemFunctor(const complex64* dout, const complex64* A,
                      const complex64* B, complex64* dout_a, complex64* dout_b,
                      const int64_t* stride_dout, const int64_t* stride_a,
                      const int64_t* stride_b, const int64_t* shape_b,
                      const int64_t numel_a, const int64_t numel_b,
                      const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    int64_t remaining = idx;
    int64_t offset_a = 0;
    int64_t offset_b = 0;
    for (int dim = 0; dim < ndims_; ++dim) {
      // Coordinate of the output element along this dimension.
      const auto coord = remaining / stride_dout_[dim];
      remaining = remaining % stride_dout_[dim];
      // Quotient by B's extent addresses A; the remainder addresses B.
      offset_a += stride_a_[dim] * (coord / shape_b_[dim]);
      offset_b += stride_b_[dim] * (coord % shape_b_[dim]);
    }

    if (dout_a_) {
      const size_t slot_a = offset_a * numel_b_ + offset_b;
      const complex64 conj_b(B_[offset_b].real, -B_[offset_b].imag);
      dout_a_[slot_a] = dout_[idx] * conj_b;
    }
    if (dout_b_) {
      const size_t slot_b = offset_b * numel_a_ + offset_a;
      const complex64 conj_a(A_[offset_a].real, -A_[offset_a].imag);
      dout_b_[slot_b] = dout_[idx] * conj_a;
    }
  }

 private:
  const complex64* dout_;
  const complex64* A_;
  const complex64* B_;
  complex64* dout_a_;
  complex64* dout_b_;
  const int64_t* stride_dout_;
  const int64_t* stride_a_;
  const int64_t* stride_b_;
  const int64_t* shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};
// complex128 specialization of the per-element Kronecker-product gradient
// functor. For a flat index `idx` into dout, it recovers the matching flat
// offsets into A and B and writes the products
//   dout[idx] * conj(B[offset_b])  -> dout_a[offset_a * numel_b + offset_b]
//   dout[idx] * conj(A[offset_a])  -> dout_b[offset_b * numel_a + offset_a]
// Either output pointer may be null, in which case that product is skipped.
template <>
struct KronGradElemFunctor<complex128> {
  KronGradElemFunctor(const complex128* dout, const complex128* A,
                      const complex128* B, complex128* dout_a,
                      complex128* dout_b, const int64_t* stride_dout,
                      const int64_t* stride_a, const int64_t* stride_b,
                      const int64_t* shape_b, const int64_t numel_a,
                      const int64_t numel_b, const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    int64_t remaining = idx;
    int64_t offset_a = 0;
    int64_t offset_b = 0;
    for (int dim = 0; dim < ndims_; ++dim) {
      // Coordinate of the output element along this dimension.
      const auto coord = remaining / stride_dout_[dim];
      remaining = remaining % stride_dout_[dim];
      // Quotient by B's extent addresses A; the remainder addresses B.
      offset_a += stride_a_[dim] * (coord / shape_b_[dim]);
      offset_b += stride_b_[dim] * (coord % shape_b_[dim]);
    }

    if (dout_a_) {
      const size_t slot_a = offset_a * numel_b_ + offset_b;
      const complex128 conj_b(B_[offset_b].real, -B_[offset_b].imag);
      dout_a_[slot_a] = dout_[idx] * conj_b;
    }
    if (dout_b_) {
      const size_t slot_b = offset_b * numel_a_ + offset_a;
      const complex128 conj_a(A_[offset_a].real, -A_[offset_a].imag);
      dout_b_[slot_b] = dout_[idx] * conj_a;
    }
  }

 private:
  const complex128* dout_;
  const complex128* A_;
  const complex128* B_;
  complex128* dout_a_;
  complex128* dout_b_;
  const int64_t* stride_dout_;
  const int64_t* stride_a_;
  const int64_t* stride_b_;
  const int64_t* shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
......
......@@ -135,6 +135,43 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
int64_t numel_;
};
// SFINAE helper: names `void` only when T is platform::complex64 or
// platform::complex128; for any other T the alias is ill-formed, which removes
// the partial specialization that uses it from consideration.
template <typename T>
using EnableComplex = typename std::enable_if<
    !(!std::is_same<T, platform::complex64>::value &&
      !std::is_same<T, platform::complex128>::value)>::type;

// SFINAE helper: the exact complement of EnableComplex. Names `void` only when
// T is neither platform::complex64 nor platform::complex128.
template <typename T>
using DisableComplex = typename std::enable_if<
    !(std::is_same<T, platform::complex64>::value ||
      std::is_same<T, platform::complex128>::value)>::type;
// Element-wise conjugate functor, invoked once per element index.
// The primary template is only declared; one of the two partial
// specializations below is selected through EnableComplex / DisableComplex.
template <typename T, typename Enable = void>
struct ConjFunctor;

// Complex element types: writes (real, -imag) of input[idx] to output[idx].
template <typename T>
struct ConjFunctor<T, EnableComplex<T>> {
  ConjFunctor(const T* input, int64_t numel, T* output)
      : input_(input), numel_(numel), output_(output) {}

  HOSTDEVICE void operator()(size_t idx) const {
    const T& value = input_[idx];
    // The temporary is fully built before the store, so in == out is fine.
    output_[idx] = T(value.real, -value.imag);
  }

  const T* input_;
  int64_t numel_;  // stored but not read by operator()
  T* output_;
};

// Non-complex element types: plain element copy from input to output.
template <typename T>
struct ConjFunctor<T, DisableComplex<T>> {
  ConjFunctor(const T* input, int64_t numel, T* output)
      : input_(input), numel_(numel), output_(output) {}

  HOSTDEVICE void operator()(size_t idx) const {
    output_[idx] = input_[idx];
  }

  const T* input_;
  int64_t numel_;  // stored but not read by operator()
  T* output_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -150,6 +150,27 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
context->SetOutputDim(y_grad_name, y_dims);
}
}
// Selects the kernel from the data type of the output gradient variable
// (the gradient-named counterpart of "Out") and the context's place.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  const auto dout_dtype = OperatorWithKernel::IndicateVarDataType(
      ctx, framework::GradVarName("Out"));
  return framework::OpKernelType(dout_dtype, ctx.GetPlace());
}
// Chooses the kernel type reported for one input variable.
//  - Expected kernel dtype is complex: report the tensor's own dtype, place
//    and layout, so real and complex inputs keep their distinct types here.
//    NOTE(review): presumably this lets the framework's data-transform step
//    promote real inputs to the complex kernel dtype — confirm against
//    OperatorWithKernel.
//  - Otherwise: report the expected dtype with the tensor's place and layout.
// Marked `override` so the compiler verifies the signature against the base
// class virtual.
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name, const framework::Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const override {
  if (framework::IsComplexType(expected_kernel_type.data_type_)) {
    // Only promote the inputs' types when a complex input is involved.
    return framework::OpKernelType(tensor.type(), tensor.place(),
                                   tensor.layout());
  } else {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
}
};
template <typename T>
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/dot_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#ifdef __NVCC__
......@@ -439,6 +440,61 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
// Produces the "conjugate" of `src` in `dst` for the generic (non-specialized)
// element type: no element is rewritten; `dst` takes src's dims and layout and
// then shares src's data through Tensor::ShareDataWith.
template <typename DeviceContext, typename T>
struct ConjHelper {
  explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}

  HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
    dst.Resize(src.dims());
    dst.set_layout(src.layout());
    dst.ShareDataWith(src);
  }

  // Held for parity with the complex specializations; unused by this version.
  const framework::ExecutionContext& ctx_;
};
// complex64 specialization: resizes `dst` to src's dims, allocates its buffer
// on the context's place and fills it with the element-wise conjugate of
// `src` by running math::ConjFunctor over every element via ForRange.
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex64> {
  explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}

  HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
    using Complex = paddle::platform::complex64;

    dst.Resize(src.dims());
    auto* in = src.data<Complex>();
    auto* out = dst.mutable_data<Complex>(
        ctx_.GetPlace(), size_t(src.numel() * sizeof(Complex)));

    auto& dev_ctx = ctx_.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, src.numel());
    math::ConjFunctor<Complex> conj(in, src.numel(), out);
    for_range(conj);
  }

  const framework::ExecutionContext& ctx_;
};
// complex128 specialization: resizes `dst` to src's dims, allocates its buffer
// on the context's place and fills it with the element-wise conjugate of
// `src` by running math::ConjFunctor over every element via ForRange.
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex128> {
  explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}

  HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
    using Complex = paddle::platform::complex128;

    dst.Resize(src.dims());
    auto* in = src.data<Complex>();
    auto* out = dst.mutable_data<Complex>(
        ctx_.GetPlace(), size_t(src.numel() * sizeof(Complex)));

    auto& dev_ctx = ctx_.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, src.numel());
    math::ConjFunctor<Complex> conj(in, src.numel(), out);
    for_range(conj);
  }

  const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
......@@ -490,6 +546,8 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
auto x = *ctx.Input<framework::Tensor>("X");
auto y = *ctx.Input<framework::Tensor>("Y");
auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor y_conj(y.type());
framework::Tensor x_conj(y.type());
// get dims
std::vector<std::int64_t> x_dims = vectorize(x.dims());
......@@ -508,7 +566,7 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
if (dout.numel() == 1) {
DotGradFunction<DeviceContext, T>(&x, &y, &dout, dx, dy, ctx);
DotGradFunction<DeviceContext, T>()(&x, &y, &dout, dx, dy, ctx);
return;
}
}
......@@ -533,6 +591,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
// for complex
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(y, y_conj);
}
framework::DDim dy_dims;
......@@ -541,19 +603,23 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
// for complex
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
}
if (transpose_x && transpose_y) {
CalcInputGrad(ctx, y, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x, true, false, dy);
CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy);
} else if (transpose_x) {
CalcInputGrad(ctx, y, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x, false, false, dout, false, true, dy);
CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy);
} else if (transpose_y) {
CalcInputGrad(ctx, dout, false, false, y, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x, false, true, dy);
CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy);
} else {
CalcInputGrad(ctx, dout, false, false, y, true, false, dx);
CalcInputGrad(ctx, x, true, true, dout, false, true, dy);
CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx);
CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy);
}
if (dx) {
......@@ -573,40 +639,44 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
Tensor dx_help, dy_help;
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
if (transpose_x) {
if (transpose_y) {
// X'Y': dA = Y'G', dB = G'X'
if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, true, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, true, ctx);
} else {
// X'Y: dX = YG', dY = XG
if (dx)
MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, false, false, ctx);
}
} else {
if (transpose_y) {
// XY': dX = GY, dY = G'X
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, false, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, false, ctx);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, true, false, ctx);
}
}
......
......@@ -101,5 +101,127 @@ class TestDygraph(unittest.TestCase):
paddle.dot(x1, y1).numpy(), np.array([[17], [58]])))
class TestComplexDotOp(OpTest):
    """Checks the `dot` op output and gradients for 1-D complex inputs."""

    def setUp(self):
        self.op_type = "dot"
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        # dtype of the real and imaginary parts; the inputs become complex.
        self.dtype = np.float64

    def init_input_output(self):
        # Draw order (x real, x imag, y real, y imag) is kept so the random
        # stream is consumed exactly as before.
        x_real = np.random.random(100).astype(self.dtype)
        x_imag = np.random.random(100).astype(self.dtype)
        self.x = x_real + 1J * x_imag

        y_real = np.random.random(100).astype(self.dtype)
        y_imag = np.random.random(100).astype(self.dtype)
        self.y = y_real + 1J * y_imag

        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        # Upstream gradient 1 + 1j; expected grads use the conjugated inputs.
        self.grad_out = np.ones(1, self.dtype) + 1J * np.ones(1, self.dtype)
        self.grad_x = self.grad_out * np.conj(self.y)
        self.grad_y = self.grad_out * np.conj(self.x)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])
class TestComplexDotOp2D(OpTest):
    """Checks the `dot` op output and gradients for (2, 100) complex inputs."""

    def setUp(self):
        self.op_type = "dot"
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        # dtype of the real and imaginary parts; the inputs become complex.
        self.dtype = np.float64

    def init_input_output(self):
        shape = (2, 100)
        # Draw order (x real, x imag, y real, y imag) is kept so the random
        # stream is consumed exactly as before.
        x_real = np.random.random(shape).astype(self.dtype)
        x_imag = np.random.random(shape).astype(self.dtype)
        self.x = x_real + 1J * x_imag

        y_real = np.random.random(shape).astype(self.dtype)
        y_imag = np.random.random(shape).astype(self.dtype)
        self.y = y_real + 1J * y_imag

        # Row-wise dot products: the diagonal of x @ y.T as a column vector.
        self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1)

    def init_grad_input_output(self):
        self.grad_out = np.ones((2, 1), self.dtype) + 1J * np.ones(
            (2, 1), self.dtype)
        self.grad_x = self._get_grad(self.grad_out, self.y)
        self.grad_y = self._get_grad(self.grad_out, self.x)

    def _get_grad(self, grad_out, input):
        """Scale each conjugated row of `input` by that row's `grad_out`."""
        rows = [
            grad_out[i] * np.conj(input[i]) for i in range(grad_out.shape[0])
        ]
        return np.array(rows)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])
# Script entry point: switch paddle to static-graph mode, then run the tests.
if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
......@@ -320,6 +320,21 @@ class TestComplexElementwiseDivOp(OpTest):
user_defined_grad_outputs=[self.grad_out])
class TestRealComplexElementwiseDivOp(TestComplexElementwiseDivOp):
    """Division test variant with a real-valued X and a complex Y.

    Only the data hooks are overridden: X is real, so its expected gradient
    is reduced to the real part, while Y keeps a complex gradient.
    """

    def init_input_output(self):
        """Draws a real x, a complex y and the reference quotient."""
        shape = (2, 3, 4, 5)
        self.x = np.random.random(shape).astype(self.dtype)
        y_real = np.random.random(shape).astype(self.dtype)
        y_imag = np.random.random(shape).astype(self.dtype)
        self.y = y_real + 1J * y_imag
        self.out = self.x / self.y

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        ones = np.ones((2, 3, 4, 5), self.dtype)
        self.grad_out = ones + 1J * ones
        # x is real, so only the real part of its gradient is expected.
        self.grad_x = np.real(self.grad_out / np.conj(self.y))
        self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y)
# Script entry point: switch paddle to static-graph mode, then run the tests.
if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
......@@ -304,6 +304,21 @@ class TestComplexElementwiseMulOp(OpTest):
user_defined_grad_outputs=[self.grad_out])
class TestRealComplexElementwiseMulOp(TestComplexElementwiseMulOp):
    """Multiplication test variant with a real-valued X and a complex Y.

    Only the data hooks are overridden: X is real, so its expected gradient
    is reduced to the real part, while Y keeps a complex gradient.
    """

    def init_input_output(self):
        """Draws a real x, a complex y and the reference product."""
        shape = (2, 3, 4, 5)
        self.x = np.random.random(shape).astype(self.dtype)
        y_real = np.random.random(shape).astype(self.dtype)
        y_imag = np.random.random(shape).astype(self.dtype)
        self.y = y_real + 1J * y_imag
        self.out = self.x * self.y

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        ones = np.ones((2, 3, 4, 5), self.dtype)
        self.grad_out = ones + 1J * ones
        # x is real, so only the real part of its gradient is expected.
        self.grad_x = np.real(self.grad_out * np.conj(self.y))
        self.grad_y = self.grad_out * np.conj(self.x)
# Script entry point: switch paddle to static-graph mode, then run the tests.
if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
......@@ -15,6 +15,7 @@
from __future__ import print_function
import unittest
import numpy as np
import paddle
from op_test import OpTest, skip_check_grad_ci
......@@ -164,5 +165,78 @@ class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
}
class TestComplexElementwiseSubOp(OpTest):
    """Exercises ``elementwise_sub`` with complex inputs of shape (2, 3, 4, 5).

    The reference gradients handed to ``check_grad`` are ``grad_out`` for X
    and ``-grad_out`` for Y.
    """

    def setUp(self):
        """Builds the operator description consumed by the OpTest helpers."""
        self.op_type = "elementwise_sub"
        # Go through the hook instead of hard-coding the dtype here, so that
        # a subclass overriding init_base_dtype actually takes effect.
        self.init_base_dtype()
        self.shape = (2, 3, 4, 5)
        # The init_* hooks create x, y, out and the gradient arrays that the
        # dictionaries below read, so they have to run first.
        self.init_input_output()
        self.init_grad_input_output()
        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        """Selects the real dtype of the real and imaginary components."""
        self.dtype = np.float64

    def init_input_output(self):
        """Draws the random complex inputs and the reference difference."""
        self.x = np.random.random(self.shape).astype(
            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
        self.y = np.random.random(self.shape).astype(
            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
        self.out = self.x - self.y

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
            self.shape, self.dtype)
        self.grad_x = self.grad_out
        self.grad_y = -self.grad_out

    def test_check_output(self):
        """Runs the OpTest forward-output check."""
        self.check_output()

    def test_check_grad_normal(self):
        """Gradient check for both X and Y."""
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        """Gradient check for Y only, with X placed in ``no_grad_set``."""
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        """Gradient check for X only, with Y placed in ``no_grad_set``."""
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])
class TestRealComplexElementwiseSubOp(TestComplexElementwiseSubOp):
    """Subtraction test variant with a real-valued X and a complex Y.

    Only the data hooks are overridden: X is real, so its expected gradient
    is the real part of ``grad_out``; Y keeps the complex ``-grad_out``.
    """

    def init_input_output(self):
        """Draws a real x, a complex y and the reference difference."""
        self.x = np.random.random(self.shape).astype(self.dtype)
        y_real = np.random.random(self.shape).astype(self.dtype)
        y_imag = np.random.random(self.shape).astype(self.dtype)
        self.y = y_real + 1J * y_imag
        self.out = self.x - self.y

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        ones = np.ones(self.shape, self.dtype)
        self.grad_out = ones + 1J * ones
        # x is real, so only the real part of its gradient is expected.
        self.grad_x = np.real(self.grad_out)
        self.grad_y = -self.grad_out
# Script entry point: switch paddle to static-graph mode, then run the tests.
if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
......@@ -102,5 +102,104 @@ class TestKronLayer(unittest.TestCase):
np.testing.assert_allclose(c, np.kron(a, b))
class TestComplexKronOp(OpTest):
    """Exercises the ``kron`` operator with complex 2-D inputs.

    X has shape (10, 10) and Y has shape (3, 35), so the output has shape
    (30, 350). The reference gradients handed to ``check_grad`` are computed
    in NumPy by ``get_grad_x_by_numpy`` and ``get_grad_y_by_numpy``.
    """

    def setUp(self):
        """Builds the operator description consumed by the OpTest helpers."""
        self.op_type = "kron"
        self.x_shape = np.array([10, 10])
        self.y_shape = np.array([3, 35])
        # Element-wise product of the shapes gives the Kronecker output shape.
        self.out_shape = self.x_shape * self.y_shape
        # The init_* hooks create x, y, out and the gradient arrays that the
        # dictionaries below read, so they have to run first.
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()
        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        """Selects the real dtype of the real and imaginary components."""
        self.dtype = np.float64

    def init_input_output(self):
        """Draws the random complex inputs and the NumPy reference output."""
        self.x = np.random.random(self.x_shape).astype(
            self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
        self.y = np.random.random(self.y_shape).astype(
            self.dtype) + 1J * np.random.random(self.y_shape).astype(self.dtype)
        self.out = np.kron(self.x, self.y)

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        self.grad_out = np.ones(self.out_shape, self.dtype) + 1J * np.ones(
            self.out_shape, self.dtype)
        self.grad_x = self.get_grad_x_by_numpy()
        self.grad_y = self.get_grad_y_by_numpy()

    def get_grad_x_by_numpy(self):
        """Returns the reference gradient of X as a complex128 array.

        ``out[x_i * y_rows + i, x_j * y_cols + j] = x[x_i, x_j] * y[i, j]``,
        so the gradient of ``x[x_i, x_j]`` is the sum over the matching
        (y_rows, y_cols) block of ``grad_out`` times ``conj(y)``.
        """
        y_rows, y_cols = self.y_shape
        y_conj = np.conj(self.y)
        # np.complex128 replaces the np.complex alias removed in NumPy 1.24.
        grad_x = np.zeros(self.x_shape, np.complex128)
        for x_i in range(self.x_shape[0]):
            for x_j in range(self.x_shape[1]):
                block = self.grad_out[x_i * y_rows:(x_i + 1) * y_rows,
                                      x_j * y_cols:(x_j + 1) * y_cols]
                grad_x[x_i, x_j] = np.sum(block * y_conj)
        return grad_x

    def get_grad_y_by_numpy(self):
        """Returns the reference gradient of Y as a complex128 array.

        ``y[y_i, y_j]`` contributes to every output element at row
        ``x_i * y_rows + y_i`` and column ``x_j * y_cols + y_j``; those are
        exactly the strided slice ``grad_out[y_i::y_rows, y_j::y_cols]``,
        which lines up element-for-element with ``x``.
        """
        y_rows, y_cols = self.y_shape
        x_conj = np.conj(self.x)
        # np.complex128 replaces the np.complex alias removed in NumPy 1.24.
        grad_y = np.zeros(self.y_shape, np.complex128)
        for y_i in range(y_rows):
            for y_j in range(y_cols):
                strided = self.grad_out[y_i::y_rows, y_j::y_cols]
                grad_y[y_i, y_j] = np.sum(strided * x_conj)
        return grad_y

    def test_check_output(self):
        """Runs the OpTest forward-output check."""
        self.check_output()

    def test_check_grad_normal(self):
        """Gradient check for both X and Y."""
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=[self.grad_x, self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        """Gradient check for Y only, with X placed in ``no_grad_set``."""
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        """Gradient check for X only, with Y placed in ``no_grad_set``."""
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])
class TestKronOpTypePromotion(TestComplexKronOp):
    """Kron test variant with a real-valued X and a complex Y.

    Only the data hooks are overridden: X is real, so its expected gradient
    is reduced to the real part, while Y keeps a complex gradient.
    """

    def init_input_output(self):
        """Draws a real x, a complex y and the NumPy reference output."""
        self.x = np.random.random(self.x_shape).astype(self.dtype)
        y_real = np.random.random(self.y_shape).astype(self.dtype)
        y_imag = np.random.random(self.y_shape).astype(self.dtype)
        self.y = y_real + 1J * y_imag
        self.out = np.kron(self.x, self.y)

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        ones = np.ones(self.out_shape, self.dtype)
        self.grad_out = ones + 1J * ones
        # x is real, so only the real part of its gradient is expected.
        complex_grad_x = self.get_grad_x_by_numpy()
        self.grad_x = complex_grad_x.real
        self.grad_y = self.get_grad_y_by_numpy()
# Script entry point: switch paddle to static-graph mode, then run the tests.
if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
......@@ -405,5 +405,141 @@ class TestMatMulV2API(unittest.TestCase):
result = paddle.matmul(x, y)
class TestComplexMatMulOp(OpTest):
    """Exercises ``matmul_v2`` with square (10, 10) complex inputs.

    The reference gradients handed to ``check_grad`` are
    ``grad_out @ conj(y).T`` for X and ``conj(x).T @ grad_out`` for Y.
    """

    def setUp(self):
        """Builds the operator description consumed by the OpTest helpers."""
        self.op_type = "matmul_v2"
        # The init_* hooks create x, y, out and the gradient arrays that the
        # dictionaries below read, so they have to run first.
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        x_input = OpTest.np_dtype_to_fluid_dtype(self.x)
        y_input = OpTest.np_dtype_to_fluid_dtype(self.y)
        self.inputs = {'X': x_input, 'Y': y_input}
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        """Selects the real dtype of the real and imaginary components."""
        self.dtype = np.float64

    def init_input_output(self):
        """Draws the random complex matrices and the NumPy reference output."""
        shape = (10, 10)
        x_real = np.random.random(shape).astype(self.dtype)
        x_imag = np.random.random(shape).astype(self.dtype)
        y_real = np.random.random(shape).astype(self.dtype)
        y_imag = np.random.random(shape).astype(self.dtype)
        self.x = x_real + 1J * x_imag
        self.y = y_real + 1J * y_imag
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        ones = np.ones((10, 10), self.dtype)
        self.grad_out = ones + 1J * ones
        # Conjugate transposes of the operands.
        y_conj_t = np.conj(self.y).T
        x_conj_t = np.conj(self.x).T
        self.grad_x = np.matmul(self.grad_out, y_conj_t)
        self.grad_y = np.matmul(x_conj_t, self.grad_out)

    def test_check_output(self):
        """Runs the OpTest forward-output check."""
        self.check_output()

    def test_check_grad_normal(self):
        """Gradient check for both X and Y."""
        expected_grads = [self.grad_x, self.grad_y]
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=expected_grads,
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        """Gradient check for Y only, with X placed in ``no_grad_set``."""
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        """Gradient check for X only, with Y placed in ``no_grad_set``."""
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])
class TestComplexMatMulOpBroadcast(OpTest):
    """Exercises ``matmul_v2`` with a batched X and an unbatched Y.

    X has shape (10, 2, 5) and Y has shape (5, 20), giving a (10, 2, 20)
    output. Y is shared by every batch entry, so its expected gradient is
    summed over the batch axis.
    """

    def setUp(self):
        """Builds the operator description consumed by the OpTest helpers."""
        self.op_type = "matmul_v2"
        # The init_* hooks create x, y, out and the gradient arrays that the
        # dictionaries below read, so they have to run first.
        self.init_base_dtype()
        self.init_input_output()
        self.init_grad_input_output()

        x_input = OpTest.np_dtype_to_fluid_dtype(self.x)
        y_input = OpTest.np_dtype_to_fluid_dtype(self.y)
        self.inputs = {'X': x_input, 'Y': y_input}
        self.attrs = {'axis': -1, 'use_mkldnn': False}
        self.outputs = {'Out': self.out}

    def init_base_dtype(self):
        """Selects the real dtype of the real and imaginary components."""
        self.dtype = np.float64

    def init_input_output(self):
        """Draws the random complex inputs and the NumPy reference output."""
        x_shape = (10, 2, 5)
        y_shape = (5, 20)
        x_real = np.random.random(x_shape).astype(self.dtype)
        x_imag = np.random.random(x_shape).astype(self.dtype)
        y_real = np.random.random(y_shape).astype(self.dtype)
        y_imag = np.random.random(y_shape).astype(self.dtype)
        self.x = x_real + 1J * x_imag
        self.y = y_real + 1J * y_imag
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        ones = np.ones((10, 2, 20), self.dtype)
        self.grad_out = ones + 1J * ones
        self.grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        # Per-batch conj(x)^T @ grad_out, then reduce over the batch axis
        # because the same y is used for every batch entry.
        x_conj_t = np.conj(self.x).transpose(0, 2, 1)
        per_batch_grad_y = np.matmul(x_conj_t, self.grad_out)
        self.grad_y = np.sum(per_batch_grad_y, axis=0)

    def test_check_output(self):
        """Runs the OpTest forward-output check."""
        self.check_output()

    def test_check_grad_normal(self):
        """Gradient check for both X and Y."""
        expected_grads = [self.grad_x, self.grad_y]
        self.check_grad(
            ['X', 'Y'],
            'Out',
            user_defined_grads=expected_grads,
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_x(self):
        """Gradient check for Y only, with X placed in ``no_grad_set``."""
        self.check_grad(
            ['Y'],
            'Out',
            no_grad_set={'X'},
            user_defined_grads=[self.grad_y],
            user_defined_grad_outputs=[self.grad_out])

    def test_check_grad_ingore_y(self):
        """Gradient check for X only, with Y placed in ``no_grad_set``."""
        self.check_grad(
            ['X'],
            'Out',
            no_grad_set={'Y'},
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])
class TestMatMulTypePromotion(TestComplexMatMulOp):
    """Matmul test variant with a real-valued X and a complex Y.

    Only the data hooks are overridden: X is real, so its expected gradient
    is reduced to the real part, while Y keeps a complex gradient.
    """

    def init_input_output(self):
        """Draws a real x, a complex y and the NumPy reference output."""
        shape = (10, 10)
        self.x = np.random.random(shape).astype(self.dtype)
        y_real = np.random.random(shape).astype(self.dtype)
        y_imag = np.random.random(shape).astype(self.dtype)
        self.y = y_real + 1J * y_imag
        self.out = np.dot(self.x, self.y)

    def init_grad_input_output(self):
        """Sets the upstream gradient (1 + 1j) and the expected input grads."""
        ones = np.ones((10, 10), self.dtype)
        self.grad_out = ones + 1J * ones
        # x is real, so only the real part of its gradient is expected.
        complex_grad_x = np.matmul(self.grad_out, np.conj(self.y).T)
        self.grad_x = complex_grad_x.real
        self.grad_y = np.matmul(np.conj(self.x).T, self.grad_out)
# Script entry point: switch paddle to static-graph mode, then run the tests.
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
......@@ -59,6 +59,7 @@ NEED_TO_FIX_OP_LIST = [
'lstmp',
'margin_rank_loss',
'matmul',
'matmul_v2',
'mul',
'multiplex',
'rank_loss',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册