Unverified commit 8f45d142, authored by chentianyu03, committed via GitHub

add complex64 and complex128 types; add +-*/@ and slice operators for c… (#29199)

* add complex64 and complex128 types; add +-*/@ and slice operators for complex types

* add test cases for complex elementwise, matmul and getitem unittest

* add test cases for complex types

* add test cases for complex matmul unittest
Parent: cc9c6196
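For orientation, a minimal sketch (not part of the patch) of the scalar arithmetic the new types provide. It assumes only what the kernels in this diff already rely on: the (real, imag) constructor, the public real/imag members, and the +, -, *, / operators of paddle::platform::complex64 and complex128.

#include <iostream>
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"

int main() {
  paddle::platform::complex64 a(1.0f, 2.0f);
  paddle::platform::complex64 b(3.0f, -1.0f);
  paddle::platform::complex64 s = a + b;  // the elementwise_add kernels use operator+
  paddle::platform::complex64 p = a * b;  // the elementwise_mul kernels use operator*
  std::cout << s.real << "+" << s.imag << "j" << std::endl;  // 4+1j
  std::cout << p.real << "+" << p.imag << "j" << std::endl;  // 5+5j
  return 0;
}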
......@@ -18,6 +18,8 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
......@@ -25,6 +27,8 @@ namespace paddle {
namespace platform {
struct bfloat16;
struct float16;
struct complex64;
struct complex128;
} // namespace platform
} // namespace paddle
......@@ -45,23 +49,27 @@ struct DataTypeTrait<void> {
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
#define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
// For the use of thrust, as index-type elements can be only integers.
#define _ForEachDataTypeTiny_(callback) \
......
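The X-macro above threads a callback over every (C++ type, proto enum) pair, so the two new helper lines make every _ForEachDataType_ consumer see complex64/complex128 automatically. A hedged sketch of such a consumer (the callback and function names below are hypothetical, not from the patch):

#include <iostream>
#include "paddle/fluid/framework/data_type.h"

// Callback invoked once per supported dtype with (cpp_type, proto enum value).
#define PrintDataTypePair(cpp_type, proto_type) \
  std::cout << #cpp_type << " -> " << static_cast<int>(proto_type) << "\n";

// Prints one line per dtype, now including COMPLEX64 (23) and COMPLEX128 (24).
void DumpSupportedDataTypes() { _ForEachDataType_(PrintDataTypePair); }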
......@@ -169,6 +169,10 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex64 : omp_out += \
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex128 : omp_out += \
omp_in)
#endif
template <typename T>
......@@ -222,6 +226,37 @@ void CheckNanInf<paddle::platform::bfloat16>(
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
template <>
void CheckNanInf<paddle::platform::complex64>(
const paddle::platform::complex64* value, const size_t numel, int print_num,
const std::string& op_type, const std::string& var_name) {
paddle::platform::complex64 sum(0.0, 0.0);
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += (value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
template <>
void CheckNanInf<paddle::platform::complex128>(
const paddle::platform::complex128* value, const size_t numel,
int print_num, const std::string& op_type, const std::string& var_name) {
paddle::platform::complex128 sum(0.0, 0.0);
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += (value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
#endif
template <>
......
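The two specializations above exploit the fact that v - v is zero for every finite value, while a NaN or Inf in either component turns the whole reduced sum into NaN. A standalone illustration with std::complex (the real code applies the same idea to platform::complex64/complex128):

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  std::complex<float> finite(1.0f, 2.0f);
  std::complex<float> bad(INFINITY, 0.0f);
  // Inf - Inf yields NaN, so the accumulated sum flags the bad element.
  std::complex<float> sum = (finite - finite) + (bad - bad);
  std::printf("has nan/inf: %d\n",
              std::isnan(sum.real()) || std::isnan(sum.imag()));  // prints 1
  return 0;
}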
......@@ -116,6 +116,8 @@ message VarType {
UINT8 = 20;
INT8 = 21;
BF16 = 22;
COMPLEX64 = 23;
COMPLEX128 = 24;
// Other types that may need additional descriptions
LOD_TENSOR = 7;
......
......@@ -22,6 +22,8 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......@@ -990,6 +992,40 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
return os;
}
template <>
std::ostream& print_tensor<paddle::platform::complex64>(
std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<paddle::platform::complex64>();
auto element_num = tensor.numel();
os << " - data: [";
if (element_num > 0) {
os << inspect[0].real << "+" << inspect[0].imag << "j";
for (int j = 1; j < element_num; ++j) {
os << " " << inspect[j].real << "+" << inspect[j].imag << "j";
}
}
os << "]";
return os;
}
template <>
std::ostream& print_tensor<paddle::platform::complex128>(
std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<paddle::platform::complex128>();
auto element_num = tensor.numel();
os << " - data: [";
if (element_num > 0) {
os << inspect[0].real << "+" << inspect[0].imag << "j";
for (int j = 1; j < element_num; ++j) {
os << " " << inspect[j].real << "+" << inspect[j].imag << "j";
}
}
os << "]";
return os;
}
std::ostream& operator<<(std::ostream& os, const Tensor& t) {
os << " - place: " << t.place() << "\n";
os << " - shape: [" << t.dims() << "]\n";
......
......@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace framework {
......@@ -128,13 +130,21 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad_grad,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
......@@ -144,7 +154,11 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
// A specialized elementwise_add operator, used in gradient accumulation with
// inplace addto.
......@@ -159,4 +173,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -95,26 +97,35 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad_grad,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::float16>);
plat::complex64>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
......@@ -16,6 +16,8 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace operators {
......@@ -130,13 +132,21 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad_grad,
......@@ -147,4 +157,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -102,7 +104,11 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::float16>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
......@@ -110,8 +116,11 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::float16>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad_grad,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
......@@ -123,4 +132,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
int>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
......@@ -16,6 +16,8 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace operators {
......@@ -130,13 +132,21 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
......@@ -146,4 +156,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -100,19 +102,26 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::float16>);
plat::complex64>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex128>);
......@@ -17,6 +17,8 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace framework {
......@@ -125,13 +127,21 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad_grad,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
......@@ -141,4 +151,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -99,7 +101,11 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::float16>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
......@@ -107,8 +113,11 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::float16>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad_grad,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
......@@ -118,4 +127,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
int>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
......@@ -16,6 +16,7 @@
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/gpu_info.h"
DECLARE_bool(enable_cublas_tensor_op_math);
......@@ -258,6 +259,180 @@ struct CUBlas<platform::float16> {
}
};
template <>
struct CUBlas<platform::complex64> {
using complex64 = platform::complex64;
static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m,
int n, const complex64 *alpha, const complex64 *A, int lda,
const complex64 *B, int ldb, const complex64 *beta,
complex64 *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemv(
handle, transa, m, n, reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex *>(A), lda,
reinterpret_cast<const cuFloatComplex *>(B), ldb,
reinterpret_cast<const cuFloatComplex *>(beta),
reinterpret_cast<cuFloatComplex *>(C), ldc));
}
static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const complex64 *alpha, const complex64 *A,
int lda, long long int strideA, // NOLINT
const complex64 *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const complex64 *beta, complex64 *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemmStridedBatched(
handle, transa, transb, m, n, k,
reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex *>(A), lda, strideA,
reinterpret_cast<const cuFloatComplex *>(B), ldb, strideB,
reinterpret_cast<const cuFloatComplex *>(beta),
reinterpret_cast<cuFloatComplex *>(C), ldc, strideC, batchCount));
#else
PADDLE_THROW(platform::errors::Unimplemented(
"CgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
}
static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const complex64 *alpha, const complex64 *A, int lda,
const complex64 *B, int ldb, const complex64 *beta,
complex64 *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex *>(A), lda,
reinterpret_cast<const cuFloatComplex *>(B), ldb,
reinterpret_cast<const cuFloatComplex *>(beta),
reinterpret_cast<cuFloatComplex *>(C), ldc));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
cublasOperation_t transa, cublasOperation_t transb, int m,
int n, int k, const void *alpha, const void *A,
cudaDataType_t Atype, int lda, const void *B,
cudaDataType_t Btype, int ldb, const void *beta, void *C,
cudaDataType_t Ctype, int ldc,
cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
bool use_tensor_op_math = dev_ctx->tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, computeType, algo));
});
#else
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}
};
template <>
struct CUBlas<platform::complex128> {
using complex128 = platform::complex128;
static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m,
int n, const complex128 *alpha, const complex128 *A, int lda,
const complex128 *B, int ldb, const complex128 *beta,
complex128 *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemv(
handle, transa, m, n, reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex *>(A), lda,
reinterpret_cast<const cuDoubleComplex *>(B), ldb,
reinterpret_cast<const cuDoubleComplex *>(beta),
reinterpret_cast<cuDoubleComplex *>(C), ldc));
}
static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const complex128 *alpha, const complex128 *A,
int lda, long long int strideA, // NOLINT
const complex128 *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const complex128 *beta, complex128 *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemmStridedBatched(
handle, transa, transb, m, n, k,
reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex *>(A), lda, strideA,
reinterpret_cast<const cuDoubleComplex *>(B), ldb, strideB,
reinterpret_cast<const cuDoubleComplex *>(beta),
reinterpret_cast<cuDoubleComplex *>(C), ldc, strideC, batchCount));
#else
PADDLE_THROW(platform::errors::Unimplemented(
"ZgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
}
static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const complex128 *alpha, const complex128 *A, int lda,
const complex128 *B, int ldb, const complex128 *beta,
complex128 *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex *>(A), lda,
reinterpret_cast<const cuDoubleComplex *>(B), ldb,
reinterpret_cast<const cuDoubleComplex *>(beta),
reinterpret_cast<cuDoubleComplex *>(C), ldc));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
cublasOperation_t transa, cublasOperation_t transb, int m,
int n, int k, const void *alpha, const void *A,
cudaDataType_t Atype, int lda, const void *B,
cudaDataType_t Btype, int ldb, const void *beta, void *C,
cudaDataType_t Ctype, int ldc,
cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
bool use_tensor_op_math = dev_ctx->tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, computeType, algo));
});
#else
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}
};
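// A minimal compile-time guard (not part of this patch) for the layout
// assumption behind the reinterpret_casts in the two specializations above:
// platform::complex64 / platform::complex128 must be binary-compatible with
// cuFloatComplex / cuDoubleComplex, i.e. two packed floats / doubles.
static_assert(sizeof(platform::complex64) == sizeof(cuFloatComplex),
              "complex64 must have the same size as cuFloatComplex");
static_assert(sizeof(platform::complex128) == sizeof(cuDoubleComplex),
              "complex128 must have the same size as cuDoubleComplex");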
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
......@@ -338,6 +513,103 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
#endif // CUDA_VERSION >= 8000
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::complex64 alpha, const platform::complex64 *A,
const platform::complex64 *B, platform::complex64 beta,
platform::complex64 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 53,
platform::errors::InvalidArgument(
"cublas complex64 gemm requires GPU compute capability >= 53, "
"but received %d",
context_.GetComputeCapability()));
thrust::complex<float> c_alpha =
thrust::complex<float>(alpha.real, alpha.imag);
thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);
#if CUDA_VERSION >= 8000
// Use cublasGemmEx (via GEMM_EX above) for the complex64 GEMM: inputs,
// outputs and computation are all CUDA_C_32F, and the call goes through
// TensorCoreCublasCallIfAvailable so tensor op math can be enabled on GPUs
// that support it.
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::complex64>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_32F, ldb, A,
CUDA_C_32F, lda, &c_beta, C, CUDA_C_32F, N, CUDA_C_32F);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to cublasCgemm.
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<platform::complex64>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, N);
});
#endif // CUDA_VERSION >= 8000
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::complex128 alpha, const platform::complex128 *A,
const platform::complex128 *B, platform::complex128 beta,
platform::complex128 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 53,
platform::errors::InvalidArgument(
"cublas complex128 gemm requires GPU compute capability >= 53, "
"but received %d",
context_.GetComputeCapability()));
thrust::complex<double> c_alpha =
thrust::complex<double>(alpha.real, alpha.imag);
thrust::complex<double> c_beta =
thrust::complex<double>(beta.real, beta.imag);
#if CUDA_VERSION >= 8000
// Use cublasGemmEx (via GEMM_EX above) for the complex128 GEMM: inputs,
// outputs and computation are all CUDA_C_64F, and the call goes through
// TensorCoreCublasCallIfAvailable so tensor op math can be enabled on GPUs
// that support it.
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::complex128>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_64F, ldb, A,
CUDA_C_64F, lda, &c_beta, C, CUDA_C_64F, N, CUDA_C_64F);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to cublasZgemm.
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<platform::complex128>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&alpha, B, ldb, A, lda, &beta, C, N);
});
#endif // CUDA_VERSION >= 8000
}
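// A brief worked note on the argument order in the two complex GEMM
// specializations above (a reading of the existing convention, not new
// logic): cuBLAS is column-major while these buffers are row-major. Since
// (A * B)^T = B^T * A^T, the row-major product C[MxN] = A[MxK] * B[KxN] is
// obtained by handing cuBLAS the operands in swapped order (B first, then A)
// with M and N exchanged and, in the no-transpose case, leading dimensions
// ldb = N, lda = K and ldc = N; the column-major matrix cuBLAS writes back is
// then exactly the row-major C.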
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
......
......@@ -12,11 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_MKLML
#include <mkl.h>
#endif
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace operators {
......@@ -287,6 +293,246 @@ struct CBlas<double> {
}
};
template <>
struct CBlas<platform::complex64> {
template <typename... ARGS>
static void VCOPY(ARGS... args) {
platform::dynload::cblas_ccopy(args...);
}
// The libmklml_intel.so shipped with Paddle does not export the vcAdd, vcSub,
// vcMul and vcDiv APIs unless it is rebuilt from source, so the vector math
// calls are replaced with plain element-wise loops below.
/*
template <typename... ARGS>
static void VADD(ARGS... args) {
platform::dynload::vcAdd(args...);
}
template <typename... ARGS>
static void VSUB(ARGS... args) {
platform::dynload::vcSub(args...);
}
template <typename... ARGS>
static void VMUL(ARGS... args) {
platform::dynload::vcMul(args...);
}
template <typename... ARGS>
static void VDIV(ARGS... args) {
platform::dynload::vcDiv(args...);
}
*/
template <typename... ARGS>
static void VADD(int n, const paddle::platform::complex64 *a,
const paddle::platform::complex64 *b,
paddle::platform::complex64 *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] + b[i];
}
}
template <typename... ARGS>
static void VSUB(int n, const paddle::platform::complex64 *a,
const paddle::platform::complex64 *b,
paddle::platform::complex64 *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] - b[i];
}
}
template <typename... ARGS>
static void VMUL(int n, const paddle::platform::complex64 *a,
const paddle::platform::complex64 *b,
paddle::platform::complex64 *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] * b[i];
}
}
template <typename... ARGS>
static void VDIV(int n, const paddle::platform::complex64 *a,
const paddle::platform::complex64 *b,
paddle::platform::complex64 *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] / b[i];
}
}
template <typename... ARGS>
static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N,
paddle::platform::complex64 alpha,
const paddle::platform::complex64 *A, int lda,
const paddle::platform::complex64 *X, int incx,
paddle::platform::complex64 beta,
paddle::platform::complex64 *Y, int incy) {
const void *a_ = (const void *)(A);
const void *x_ = (const void *)(X);
void *y_ = static_cast<void *>(Y);
platform::dynload::cblas_cgemv(layout, trans, M, N, &alpha, a_, lda, x_,
incx, &beta, y_, incy);
}
template <typename... ARGS>
static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a,
CBLAS_TRANSPOSE trans_b, int M, int N, int K,
paddle::platform::complex64 alpha,
const paddle::platform::complex64 *A, int lda,
const paddle::platform::complex64 *B, int ldb,
paddle::platform::complex64 beta,
paddle::platform::complex64 *C, int ldc) {
const void *a_ = (const void *)(A);
const void *b_ = (const void *)(B);
void *c_ = static_cast<void *>(C);
platform::dynload::cblas_cgemm(layout, trans_a, trans_b, M, N, K, &alpha,
a_, lda, b_, ldb, &beta, c_, ldc);
}
template <typename... ARGS>
static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
paddle::platform::complex64 *alpha,
const paddle::platform::complex64 **A, const int *lda,
const paddle::platform::complex64 **B, const int *ldb,
paddle::platform::complex64 *beta,
paddle::platform::complex64 **C, const int *ldc,
int group_count, int *group_size) {
const void **A_void = (const void **)(&(*A));
const void **B_void = (const void **)(&(*B));
void **C_void = reinterpret_cast<void **>(C);
platform::dynload::cblas_cgemm_batch(layout, trans_a, trans_b, M, N, K,
alpha, A_void, lda, B_void, ldb, beta,
C_void, ldc, group_count, group_size);
}
template <typename... ARGS>
static void GEMM_EX(ARGS... args) {
platform::dynload::cblas_cgemm_batch(args...);
}
};
template <>
struct CBlas<platform::complex128> {
template <typename... ARGS>
static void VCOPY(ARGS... args) {
platform::dynload::cblas_zcopy(args...);
}
// The libmklml_intel.so shipped with Paddle does not export the vzAdd, vzSub,
// vzMul and vzDiv APIs unless it is rebuilt from source, so the vector math
// calls are replaced with plain element-wise loops below.
/*
template <typename... ARGS>
static void VADD(ARGS... args) {
platform::dynload::vzAdd(args...);
}
template <typename... ARGS>
static void VSUB(ARGS... args) {
platform::dynload::vzSub(args...);
}
template <typename... ARGS>
static void VMUL(ARGS... args) {
platform::dynload::vzMul(args...);
}
template <typename... ARGS>
static void VDIV(ARGS... args) {
platform::dynload::vzDiv(args...);
}
*/
template <typename... ARGS>
static void VADD(int n, const paddle::platform::complex128 *a,
const paddle::platform::complex128 *b,
paddle::platform::complex128 *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] + b[i];
}
}
template <typename... ARGS>
static void VSUB(int n, const paddle::platform::complex128 *a,
const paddle::platform::complex128 *b,
paddle::platform::complex128 *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] - b[i];
}
}
template <typename... ARGS>
static void VMUL(int n, const paddle::platform::complex128 *a,
const paddle::platform::complex128 *b,
paddle::platform::complex128 *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] * b[i];
}
}
template <typename... ARGS>
static void VDIV(int n, const paddle::platform::complex128 *a,
const paddle::platform::complex128 *b,
paddle::platform::complex128 *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] / b[i];
}
}
template <typename... ARGS>
static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N,
paddle::platform::complex128 alpha,
const paddle::platform::complex128 *A, int lda,
const paddle::platform::complex128 *X, int incx,
paddle::platform::complex128 beta,
paddle::platform::complex128 *Y, int incy) {
const void *a_ = (const void *)(A);
const void *x_ = (const void *)(X);
void *y_ = static_cast<void *>(Y);
platform::dynload::cblas_zgemv(layout, trans, M, N, &alpha, a_, lda, x_,
incx, &beta, y_, incy);
}
template <typename... ARGS>
static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a,
CBLAS_TRANSPOSE trans_b, int M, int N, int K,
paddle::platform::complex128 alpha,
const paddle::platform::complex128 *A, int lda,
const paddle::platform::complex128 *B, int ldb,
paddle::platform::complex128 beta,
paddle::platform::complex128 *C, int ldc) {
const void *a_ = (const void *)(A);
const void *b_ = (const void *)(B);
void *c_ = static_cast<void *>(C);
platform::dynload::cblas_zgemm(layout, trans_a, trans_b, M, N, K, &alpha,
a_, lda, b_, ldb, &beta, c_, ldc);
}
template <typename... ARGS>
static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
paddle::platform::complex128 *alpha,
const paddle::platform::complex128 **A, const int *lda,
const paddle::platform::complex128 **B, const int *ldb,
paddle::platform::complex128 *beta,
paddle::platform::complex128 **C, const int *ldc,
int group_count, int *group_size) {
const void **A_void = (const void **)(&(*A));
const void **B_void = (const void **)(&(*B));
void **C_void = reinterpret_cast<void **>(C);
platform::dynload::cblas_zgemm_batch(layout, trans_a, trans_b, M, N, K,
alpha, A_void, lda, B_void, ldb, beta,
C_void, ldc, group_count, group_size);
}
template <typename... ARGS>
static void GEMM_EX(ARGS... args) {
platform::dynload::cblas_zgemm_batch(args...);
}
};
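// Hedged usage sketch of the loop-based VADD fallback defined above; the
// helper name below is illustrative and not part of the patch.
inline void AddTwoComplex64Pairs() {
  using paddle::platform::complex64;
  complex64 x[2] = {complex64(1.0f, 1.0f), complex64(2.0f, -2.0f)};
  complex64 y[2] = {complex64(0.5f, 0.0f), complex64(0.0f, 0.5f)};
  complex64 z[2];
  CBlas<complex64>::VADD(2, x, y, z);  // z[i] = x[i] + y[i]
}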
#else
template <>
......@@ -344,6 +590,93 @@ struct CBlas<double> {
cblas_dtrsm(args...);
}
};
template <>
struct CBlas<platform::complex64> {
template <typename... ARGS>
static void VCOPY(ARGS... args) {
cblas_ccopy(args...);
}
template <typename... ARGS>
static void VADD(ARGS... args) {
vcAdd(args...);
}
template <typename... ARGS>
static void AXPY(int n, const paddle::platform::complex64 alpha,
const paddle::platform::complex64 *X, const int incX,
paddle::platform::complex64 *Y, const int incY) {
cblas_caxpy(n, &alpha, X, incX, Y, incY);
}
template <typename... ARGS>
static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
const int M, const int N,
const paddle::platform::complex64 alpha,
const paddle::platform::complex64 *A, const int lda,
const paddle::platform::complex64 *X, const int incX,
const paddle::platform::complex64 beta,
paddle::platform::complex64 *Y, const int incY) {
cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
}
template <typename... ARGS>
static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const paddle::platform::complex64 alpha,
const paddle::platform::complex64 *A, const int lda,
const paddle::platform::complex64 *B, const int ldb,
const paddle::platform::complex64 beta,
paddle::platform::complex64 *C, const int ldc) {
cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
C, ldc);
}
};
template <>
struct CBlas<platform::complex128> {
template <typename... ARGS>
static void VCOPY(ARGS... args) {
cblas_zcopy(args...);
}
template <typename... ARGS>
static void VADD(ARGS... args) {
vzAdd(args...);
}
template <typename... ARGS>
static void AXPY(int n, const paddle::platform::complex128 alpha,
const paddle::platform::complex128 *X, const int incX,
paddle::platform::complex128 *Y, const int incY) {
cblas_zaxpy(n, &alpha, X, incX, Y, incY);
}
template <typename... ARGS>
static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
const int M, const int N,
const paddle::platform::complex128 alpha,
const paddle::platform::complex128 *A, const int lda,
const paddle::platform::complex128 *X, const int incX,
const paddle::platform::complex128 beta,
paddle::platform::complex128 *Y, const int incY) {
cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
}
template <typename... ARGS>
static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const paddle::platform::complex128 alpha,
const paddle::platform::complex128 *A, const int lda,
const paddle::platform::complex128 *B, const int ldb,
const paddle::platform::complex128 beta,
paddle::platform::complex128 *C, const int ldc) {
cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
C, ldc);
}
};
#endif
template <>
......@@ -517,10 +850,10 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
CBlas<T>::VADD(n, x, y, z);
#else
if (x == z) {
this->template AXPY<T>(n, 1., y, z);
this->template AXPY<T>(n, (T)(1.), y, z);
} else {
this->template VCOPY<T>(n, y, z);
this->template AXPY<T>(n, 1., x, z);
this->template AXPY<T>(n, (T)(1.), x, z);
}
#endif
}
......
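A short note on the (T)(1.) change above: AXPY<T> takes its scaling factor as a T, and with T now covering platform::complex64/complex128 the bare double literal presumably no longer converts cleanly at the call site, so it is cast to the element type first. Minimal illustration, assuming only the single-argument conversion the cast itself relies on:

#include "paddle/fluid/platform/complex64.h"
using paddle::platform::complex64;

// What (T)(1.) produces when T = complex64: the typed unit scalar 1 + 0j.
const complex64 kAxpyAlpha = static_cast<complex64>(1.);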
......@@ -65,14 +65,16 @@ class SplitFunctor {
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16)
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16); \
macro(::paddle::platform::complex64); \
macro(::paddle::platform::complex128)
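FOR_ALL_TYPES drives the explicit instantiation of the concat/split functors, so extending it is what makes those functors available for complex tensors. A hedged sketch of the typical consumer pattern in the corresponding .cc/.cu files (the instantiation macro below is illustrative):

#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CPUDeviceContext, type>;  \
  template class SplitFunctor<platform::CPUDeviceContext, type>;

FOR_ALL_TYPES(DEFINE_FUNCTOR);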
......@@ -44,6 +44,8 @@ template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
template struct SetConstant<platform::CPUDeviceContext, platform::complex64>;
template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
#ifdef PADDLE_WITH_XPU
template struct SetConstant<platform::XPUDeviceContext, platform::float16>;
......@@ -54,19 +56,23 @@ template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>;
#endif
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::complex64, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::complex128, \
RANK>;
DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);
......@@ -117,6 +123,8 @@ DEFINE_CPU_TRANS_NORMAL(bool);
DEFINE_CPU_TRANS_NORMAL(int16_t);
DEFINE_CPU_TRANS_NORMAL(uint8_t);
DEFINE_CPU_TRANS_NORMAL(int8_t);
DEFINE_CPU_TRANS_NORMAL(platform::complex64);
DEFINE_CPU_TRANS_NORMAL(platform::complex128);
struct TensorSetConstantCPU {
TensorSetConstantCPU(framework::Tensor* tensor, float value)
......
......@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -27,6 +29,8 @@ namespace math {
using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
......@@ -34,15 +38,19 @@ template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
template struct SetConstant<platform::CUDADeviceContext, platform::complex64>;
template struct SetConstant<platform::CUDADeviceContext, platform::complex128>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, bfloat16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, bfloat16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, complex64, RANK>; \
template struct Transpose<platform::CUDADeviceContext, complex128, RANK>;
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
......@@ -132,6 +140,8 @@ DEFINE_GPU_TRANS_NORMAL(bool);
DEFINE_GPU_TRANS_NORMAL(int16_t);
DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int8_t);
DEFINE_GPU_TRANS_NORMAL(complex64);
DEFINE_GPU_TRANS_NORMAL(complex128);
struct TensorSetConstantGPU {
TensorSetConstantGPU(const platform::DeviceContext& context,
......
......@@ -168,9 +168,17 @@ REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad);
REGISTER_OP_CPU_KERNEL(
matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, double>);
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
matmul_v2_grad,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, double>);
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -20,9 +20,13 @@ namespace plf = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
matmul_v2, ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, double>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::float16>);
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex64>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex128>);
REGISTER_OP_CUDA_KERNEL(
matmul_v2_grad, ops::MatMulV2GradKernel<plf::CUDADeviceContext, float>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>);
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex64>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex128>);
......@@ -424,10 +424,18 @@ REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>,
ops::SliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::SliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -23,7 +23,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>);
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::complex64>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
slice_grad,
......@@ -31,4 +33,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::complex64>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
plat::complex128>);
......@@ -327,11 +327,19 @@ REGISTER_OP_CPU_KERNEL(
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>);
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
......@@ -20,11 +22,19 @@ REGISTER_OP_CUDA_KERNEL(
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>);
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
strided_slice_grad,
    ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <limits>
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif
#ifdef PADDLE_WITH_CUDA
#include <cuComplex.h>
#include <thrust/complex.h>
#endif // PADDLE_WITH_CUDA
#include <cstring>
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace Eigen {
template <typename T>
struct NumTraits;
} // namespace Eigen
namespace paddle {
namespace platform {
struct PADDLE_ALIGN(16) complex128 {
public:
double real;
double imag;
complex128() = default;
complex128(const complex128& o) = default;
complex128& operator=(const complex128& o) = default;
complex128(complex128&& o) = default;
complex128& operator=(complex128&& o) = default;
~complex128() = default;
HOSTDEVICE complex128(double real, double imag) : real(real), imag(imag) {}
#if defined(PADDLE_WITH_CUDA)
HOSTDEVICE inline explicit complex128(const thrust::complex<double>& c) {
real = c.real();
imag = c.imag();
}
HOSTDEVICE inline explicit operator thrust::complex<double>() const {
return thrust::complex<double>(real, imag);
}
HOSTDEVICE inline explicit operator cuDoubleComplex() const {
return make_cuDoubleComplex(real, imag);
}
#endif
  // Constructing from a real scalar sets the imaginary part to zero.
  HOSTDEVICE complex128(const float& val)
      : real(static_cast<double>(val)), imag(0.0) {}
  HOSTDEVICE complex128(const double& val) : real(val), imag(0.0) {}
  HOSTDEVICE complex128(const int& val)
      : real(static_cast<double>(val)), imag(0.0) {}
  HOSTDEVICE complex128(const int64_t& val)
      : real(static_cast<double>(val)), imag(0.0) {}
HOSTDEVICE inline explicit operator std::complex<double>() {
return static_cast<std::complex<double>>(std::complex<double>(real, imag));
}
  template <class T>
  HOSTDEVICE inline explicit complex128(const T& val)
      : real(static_cast<double>(val)), imag(0.0) {}
HOSTDEVICE complex128(const std::complex<double> val)
: real(val.real()), imag(val.imag()) {}
HOSTDEVICE inline complex128& operator=(bool b) {
real = b ? 1 : 0;
imag = 0;
return *this;
}
HOSTDEVICE inline complex128& operator=(int8_t val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline complex128& operator=(uint8_t val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline complex128& operator=(int16_t val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline complex128& operator=(uint16_t val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline complex128& operator=(int32_t val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline complex128& operator=(uint32_t val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline complex128& operator=(int64_t val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline complex128& operator=(uint64_t val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline complex128& operator=(float val) {
real = val;
return *this;
}
HOSTDEVICE inline complex128& operator=(double val) {
real = static_cast<double>(val);
return *this;
}
HOSTDEVICE inline operator float() const {
return static_cast<float>(this->real);
}
HOSTDEVICE inline explicit operator bool() const {
return static_cast<bool>(this->real) || static_cast<bool>(this->imag);
}
HOSTDEVICE inline explicit operator int8_t() const {
return static_cast<int8_t>(this->real);
}
HOSTDEVICE inline explicit operator uint8_t() const {
return static_cast<uint8_t>(this->real);
}
HOSTDEVICE inline explicit operator int16_t() const {
return static_cast<int16_t>(this->real);
}
HOSTDEVICE inline explicit operator uint16_t() const {
return static_cast<uint16_t>(this->real);
}
HOSTDEVICE inline explicit operator int32_t() const {
return static_cast<int32_t>(this->real);
}
HOSTDEVICE inline explicit operator uint32_t() const {
return static_cast<uint32_t>(this->real);
}
HOSTDEVICE inline explicit operator int64_t() const {
return static_cast<int64_t>(this->real);
}
HOSTDEVICE inline explicit operator uint64_t() const {
return static_cast<uint64_t>(this->real);
}
HOSTDEVICE inline explicit operator double() const {
return static_cast<double>(this->real);
}
};
HOSTDEVICE inline complex128 operator+(const complex128& a,
const complex128& b) {
#if defined(__CUDA_ARCH__)
return complex128(thrust::complex<double>(a.real, a.imag) +
thrust::complex<double>(b.real, b.imag));
#else
return complex128(a.real + b.real, a.imag + b.imag);
#endif
}
HOSTDEVICE inline complex128 operator-(const complex128& a,
const complex128& b) {
#if defined(__CUDA_ARCH__)
return complex128(thrust::complex<double>(a.real, a.imag) -
thrust::complex<double>(b.real, b.imag));
#else
return complex128(a.real - b.real, a.imag - b.imag);
#endif
}
HOSTDEVICE inline complex128 operator*(const complex128& a,
const complex128& b) {
#if defined(__CUDA_ARCH__)
return complex128(thrust::complex<double>(a.real, a.imag) *
thrust::complex<double>(b.real, b.imag));
#else
return complex128(a.real * b.real - a.imag * b.imag,
a.imag * b.real + b.imag * a.real);
#endif
}
HOSTDEVICE inline complex128 operator/(const complex128& a,
const complex128& b) {
#if defined(__CUDA_ARCH__)
return complex128(thrust::complex<double>(a.real, a.imag) /
thrust::complex<double>(b.real, b.imag));
#else
double denominator = b.real * b.real + b.imag * b.imag;
return complex128((a.real * b.real + a.imag * b.imag) / denominator,
(a.imag * b.real - a.real * b.imag) / denominator);
#endif
}
HOSTDEVICE inline complex128 operator-(const complex128& a) {
#if defined(__CUDA_ARCH__)
return complex128(-thrust::complex<double>(a.real, a.imag));
#else
complex128 res;
res.real = -a.real;
res.imag = -a.imag;
return res;
#endif
}
HOSTDEVICE inline complex128& operator+=(complex128& a, // NOLINT
const complex128& b) {
#if defined(__CUDA_ARCH__)
a = complex128(thrust::complex<double>(a.real, a.imag) +=
thrust::complex<double>(b.real, b.imag));
return a;
#else
a.real += b.real;
a.imag += b.imag;
return a;
#endif
}
HOSTDEVICE inline complex128& operator-=(complex128& a, // NOLINT
const complex128& b) {
#if defined(__CUDA_ARCH__)
a = complex128(thrust::complex<double>(a.real, a.imag) -=
thrust::complex<double>(b.real, b.imag));
return a;
#else
a.real -= b.real;
a.imag -= b.imag;
return a;
#endif
}
HOSTDEVICE inline complex128& operator*=(complex128& a, // NOLINT
const complex128& b) {
#if defined(__CUDA_ARCH__)
a = complex128(thrust::complex<double>(a.real, a.imag) *=
thrust::complex<double>(b.real, b.imag));
return a;
#else
  // Compute into temporaries so the updated real part is not reused below.
  double real = a.real * b.real - a.imag * b.imag;
  double imag = a.imag * b.real + a.real * b.imag;
  a.real = real;
  a.imag = imag;
return a;
#endif
}
HOSTDEVICE inline complex128& operator/=(complex128& a, // NOLINT
const complex128& b) {
#if defined(__CUDA_ARCH__)
a = complex128(thrust::complex<double>(a.real, a.imag) /=
thrust::complex<double>(b.real, b.imag));
return a;
#else
double denominator = b.real * b.real + b.imag * b.imag;
  double real = (a.real * b.real + a.imag * b.imag) / denominator;
  double imag = (a.imag * b.real - a.real * b.imag) / denominator;
  a.real = real;
  a.imag = imag;
return a;
#endif
}
HOSTDEVICE inline complex128 raw_uint16_to_complex128(uint16_t a) {
complex128 res;
res.real = a;
return res;
}
HOSTDEVICE inline bool operator==(const complex128& a, const complex128& b) {
return a.real == b.real && a.imag == b.imag;
}
HOSTDEVICE inline bool operator!=(const complex128& a, const complex128& b) {
return a.real != b.real || a.imag != b.imag;
}
HOSTDEVICE inline bool operator<(const complex128& a, const complex128& b) {
return static_cast<double>(a.real) < static_cast<double>(b.real);
}
HOSTDEVICE inline bool operator<=(const complex128& a, const complex128& b) {
return static_cast<double>(a.real) <= static_cast<double>(b.real);
}
HOSTDEVICE inline bool operator>(const complex128& a, const complex128& b) {
return static_cast<double>(a.real) > static_cast<double>(b.real);
}
HOSTDEVICE inline bool operator>=(const complex128& a, const complex128& b) {
return static_cast<double>(a.real) >= static_cast<double>(b.real);
}
HOSTDEVICE inline bool(isnan)(const complex128& a) {
#if defined(__CUDA_ARCH__)
return __isnan(a.real) || __isnan(a.imag);
#else
return std::isnan(a.real) || std::isnan(a.imag);
#endif
}
HOSTDEVICE inline bool(isinf)(const complex128& a) {
#if defined(__CUDA_ARCH__)
return __isinf(a.real) || __isinf(a.imag);
#else
return std::isinf(a.real) || std::isinf(a.imag);
#endif
}
HOSTDEVICE inline bool(isfinite)(const complex128& a) {
return !((isnan)(a)) && !((isinf)(a));
}
HOSTDEVICE inline double(abs)(const complex128& a) {
#if defined(__CUDA_ARCH__)
return thrust::abs(thrust::complex<double>(a.real, a.imag));
#else
return std::abs(std::complex<double>(a));
#endif
}
HOSTDEVICE inline complex128(pow)(const complex128& a, const complex128& b) {
#if defined(__CUDA_ARCH__)
return complex128(thrust::pow(thrust::complex<double>(a.real, a.imag),
thrust::complex<double>(b.real, b.imag)));
#else
  return std::pow(std::complex<double>(a), std::complex<double>(b));
#endif
}
HOSTDEVICE inline complex128(sqrt)(const complex128& a) {
#if defined(__CUDA_ARCH__)
return complex128(thrust::sqrt(thrust::complex<double>(a.real, a.imag)));
#else
return std::sqrt(std::complex<double>(a));
#endif
}
HOSTDEVICE inline complex128(tanh)(const complex128& a) {
#if defined(__CUDA_ARCH__)
return complex128(thrust::tanh(thrust::complex<double>(a.real, a.imag)));
#else
return std::tanh(std::complex<double>(a));
#endif
}
HOSTDEVICE inline complex128(log)(const complex128& a) {
#if defined(__CUDA_ARCH__)
return complex128(thrust::log(thrust::complex<double>(a.real, a.imag)));
#else
return complex128(std::log(std::complex<double>(a)));
#endif
}
inline std::ostream& operator<<(std::ostream& os, const complex128& a) {
os << "real:" << a.real << " imag:" << a.imag;
return os;
}
} // namespace platform
} // namespace paddle
namespace std {
template <>
struct is_pod<paddle::platform::complex128> {
static const bool value =
is_trivial<paddle::platform::complex128>::value &&
is_standard_layout<paddle::platform::complex128>::value;
};
template <>
struct is_floating_point<paddle::platform::complex128>
: std::integral_constant<
bool, std::is_same<paddle::platform::complex128,
typename std::remove_cv<
paddle::platform::complex128>::type>::value> {
};
template <>
struct is_signed<paddle::platform::complex128> {
static const bool value = false;
};
template <>
struct is_unsigned<paddle::platform::complex128> {
static const bool value = false;
};
inline bool isnan(const paddle::platform::complex128& a) {
return paddle::platform::isnan(a);
}
inline bool isinf(const paddle::platform::complex128& a) {
return paddle::platform::isinf(a);
}
template <>
struct numeric_limits<paddle::platform::complex128> {
static const bool is_specialized = false;
static const bool is_signed = false;
static const bool is_integer = false;
static const bool is_exact = false;
static const bool has_infinity = false;
static const bool has_quiet_NaN = false;
static const bool has_signaling_NaN = false;
static const float_denorm_style has_denorm = denorm_absent;
static const bool has_denorm_loss = false;
static const std::float_round_style round_style = std::round_toward_zero;
static const bool is_iec559 = false;
static const bool is_bounded = false;
static const bool is_modulo = false;
static const int digits = 0;
static const int digits10 = 0;
static const int max_digits10 = 0;
static const int radix = 0;
static const int min_exponent = 0;
static const int min_exponent10 = 0;
static const int max_exponent = 0;
static const int max_exponent10 = 0;
static const bool traps = false;
static const bool tinyness_before = false;
static paddle::platform::complex128(min)() {
return paddle::platform::complex128(0.0, 0.0);
}
static paddle::platform::complex128 lowest() {
return paddle::platform::complex128(0.0, 0.0);
}
static paddle::platform::complex128(max)() {
return paddle::platform::complex128(0.0, 0.0);
}
static paddle::platform::complex128 epsilon() {
return paddle::platform::complex128(0.0, 0.0);
}
static paddle::platform::complex128 round_error() {
return paddle::platform::complex128(0.0, 0.0);
}
static paddle::platform::complex128 infinity() {
return paddle::platform::complex128(0.0, 0.0);
}
static paddle::platform::complex128 quiet_NaN() {
return paddle::platform::complex128(0.0, 0.0);
}
static paddle::platform::complex128 signaling_NaN() {
return paddle::platform::complex128(0.0, 0.0);
}
static paddle::platform::complex128 denorm_min() {
return paddle::platform::complex128(0.0, 0.0);
}
};
} // namespace std
namespace Eigen {
using complex128 = paddle::platform::complex128;
template <>
struct NumTraits<complex128> : GenericNumTraits<std::complex<double>> {
typedef double Real;
typedef typename NumTraits<double>::Literal Literal;
enum {
IsComplex = 1,
RequireInitialization = NumTraits<double>::RequireInitialization,
ReadCost = 2 * NumTraits<double>::ReadCost,
AddCost = 2 * NumTraits<Real>::AddCost,
MulCost = 4 * NumTraits<Real>::MulCost + 2 * NumTraits<Real>::AddCost
};
EIGEN_DEVICE_FUNC
static inline Real epsilon() { return NumTraits<Real>::epsilon(); }
EIGEN_DEVICE_FUNC
static inline Real dummy_precision() {
return NumTraits<Real>::dummy_precision();
}
EIGEN_DEVICE_FUNC
static inline int digits10() { return NumTraits<Real>::digits10(); }
};
namespace numext {
template <>
HOSTDEVICE inline bool(isnan)(const complex128& a) {
return (paddle::platform::isnan)(a);
}
template <>
HOSTDEVICE inline bool(isinf)(const complex128& a) {
return (paddle::platform::isinf)(a);
}
template <>
HOSTDEVICE inline bool(isfinite)(const complex128& a) {
return (paddle::platform::isfinite)(a);
}
template <>
HOSTDEVICE inline complex128 exp(const complex128& a) {
  double com = ::exp(a.real);
  double res_real = com * ::cos(a.imag);
  double res_imag = com * ::sin(a.imag);
return complex128(res_real, res_imag);
}
template <>
HOSTDEVICE inline complex128 log(const complex128& a) {
return paddle::platform::log(a);
}
template <>
HOSTDEVICE inline complex128 tanh(const complex128& a) {
return paddle::platform::tanh(a);
}
template <>
HOSTDEVICE inline complex128 sqrt(const complex128& a) {
return paddle::platform::sqrt(a);
}
template <>
HOSTDEVICE inline complex128 ceil(const complex128& a) {
  return complex128(::ceil(a.real), ::ceil(a.imag));
}
template <>
HOSTDEVICE inline complex128 floor(const complex128& a) {
  return complex128(::floor(a.real), ::floor(a.imag));
}
template <>
HOSTDEVICE inline complex128 round(const complex128& a) {
  return complex128(::round(a.real), ::round(a.imag));
}
template <>
HOSTDEVICE inline complex128 pow(const complex128& a, const complex128& b) {
return paddle::platform::pow(a, b);
}
template <>
HOSTDEVICE inline double abs(const complex128& a) {
return paddle::platform::abs(a);
}
} // namespace numext
} // namespace Eigen
#define MKL_Complex16 paddle::platform::complex128
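For a quick sanity check of the arithmetic defined in this header, the following standalone sketch (an illustration only, not part of the patch; it assumes Paddle's source tree and its third-party Eigen headers are on the compiler include path) exercises the CPU branches of operator*, operator/ and abs() for complex128 and compares the quotient against std::complex<double>:

// Illustration only: host-side check of complex128 arithmetic.
#include <cassert>
#include <cmath>
#include <complex>
#include "paddle/fluid/platform/complex128.h"

int main() {
  using paddle::platform::complex128;
  complex128 a(1.0, 2.0), b(3.0, -4.0);

  complex128 p = a * b;  // (1 + 2i)(3 - 4i) = 11 + 2i
  assert(p.real == 11.0 && p.imag == 2.0);

  complex128 q = a / b;  // denominator = 3*3 + (-4)*(-4) = 25
  std::complex<double> ref =
      std::complex<double>(1.0, 2.0) / std::complex<double>(3.0, -4.0);
  assert(std::abs(q.real - ref.real()) < 1e-12);
  assert(std::abs(q.imag - ref.imag()) < 1e-12);

  assert(std::abs(paddle::platform::abs(b) - 5.0) < 1e-12);  // |3 - 4i| = 5
  return 0;
}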
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <limits>
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif
#ifdef PADDLE_WITH_CUDA
#include <cuComplex.h>
#include <thrust/complex.h>
#endif // PADDLE_WITH_CUDA
#include <cstring>
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace Eigen {
template <typename T>
struct NumTraits;
} // namespace Eigen
namespace paddle {
namespace platform {
struct PADDLE_ALIGN(8) complex64 {
public:
float real;
float imag;
complex64() = default;
complex64(const complex64& o) = default;
complex64& operator=(const complex64& o) = default;
complex64(complex64&& o) = default;
complex64& operator=(complex64&& o) = default;
~complex64() = default;
HOSTDEVICE complex64(float real, float imag) : real(real), imag(imag) {}
#if defined(PADDLE_WITH_CUDA)
HOSTDEVICE inline explicit complex64(const thrust::complex<float>& c) {
real = c.real();
imag = c.imag();
}
HOSTDEVICE inline explicit operator thrust::complex<float>() const {
return thrust::complex<float>(real, imag);
}
HOSTDEVICE inline explicit operator cuFloatComplex() const {
return make_cuFloatComplex(real, imag);
}
#endif
  // Constructing from a real scalar sets the imaginary part to zero.
  HOSTDEVICE complex64(const float& val) : real(val), imag(0.0f) {}
  HOSTDEVICE complex64(const double& val)
      : real(static_cast<float>(val)), imag(0.0f) {}
  HOSTDEVICE complex64(const int& val)
      : real(static_cast<float>(val)), imag(0.0f) {}
  HOSTDEVICE complex64(const int64_t& val)
      : real(static_cast<float>(val)), imag(0.0f) {}
HOSTDEVICE complex64(const complex128& val) {
real = static_cast<float>(val.real);
imag = static_cast<float>(val.imag);
}
HOSTDEVICE inline explicit operator std::complex<float>() {
return static_cast<std::complex<float>>(std::complex<float>(real, imag));
}
  template <class T>
  HOSTDEVICE inline explicit complex64(const T& val)
      : real(static_cast<float>(val)), imag(0.0f) {}
HOSTDEVICE complex64(const std::complex<float> val)
: real(val.real()), imag(val.imag()) {}
HOSTDEVICE inline complex64& operator=(bool b) {
real = b ? 1 : 0;
imag = 0;
return *this;
}
HOSTDEVICE inline complex64& operator=(int8_t val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline complex64& operator=(uint8_t val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline complex64& operator=(int16_t val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline complex64& operator=(uint16_t val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline complex64& operator=(int32_t val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline complex64& operator=(uint32_t val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline complex64& operator=(int64_t val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline complex64& operator=(uint64_t val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline complex64& operator=(float val) {
real = val;
return *this;
}
HOSTDEVICE inline complex64& operator=(double val) {
real = static_cast<float>(val);
return *this;
}
HOSTDEVICE inline operator float() const { return this->real; }
HOSTDEVICE inline explicit operator bool() const {
return static_cast<bool>(this->real) || static_cast<bool>(this->imag);
}
HOSTDEVICE inline explicit operator int8_t() const {
return static_cast<int8_t>(this->real);
}
HOSTDEVICE inline explicit operator uint8_t() const {
return static_cast<uint8_t>(this->real);
}
HOSTDEVICE inline explicit operator int16_t() const {
return static_cast<int16_t>(this->real);
}
HOSTDEVICE inline explicit operator uint16_t() const {
return static_cast<uint16_t>(this->real);
}
HOSTDEVICE inline explicit operator int32_t() const {
return static_cast<int32_t>(this->real);
}
HOSTDEVICE inline explicit operator uint32_t() const {
return static_cast<uint32_t>(this->real);
}
HOSTDEVICE inline explicit operator int64_t() const {
return static_cast<int64_t>(this->real);
}
HOSTDEVICE inline explicit operator uint64_t() const {
return static_cast<uint64_t>(this->real);
}
HOSTDEVICE inline explicit operator double() const {
return static_cast<double>(this->real);
}
HOSTDEVICE inline operator complex128() const {
return complex128(static_cast<double>(this->real),
static_cast<double>(this->imag));
}
};
HOSTDEVICE inline complex64 operator+(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__)
return complex64(thrust::complex<float>(a.real, a.imag) +
thrust::complex<float>(b.real, b.imag));
#else
return complex64(a.real + b.real, a.imag + b.imag);
#endif
}
HOSTDEVICE inline complex64 operator-(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__)
return complex64(thrust::complex<float>(a.real, a.imag) -
thrust::complex<float>(b.real, b.imag));
#else
return complex64(a.real - b.real, a.imag - b.imag);
#endif
}
HOSTDEVICE inline complex64 operator*(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__)
return complex64(thrust::complex<float>(a.real, a.imag) *
thrust::complex<float>(b.real, b.imag));
#else
return complex64(a.real * b.real - a.imag * b.imag,
a.imag * b.real + b.imag * a.real);
#endif
}
HOSTDEVICE inline complex64 operator/(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__)
return complex64(thrust::complex<float>(a.real, a.imag) /
thrust::complex<float>(b.real, b.imag));
#else
float denominator = b.real * b.real + b.imag * b.imag;
return complex64((a.real * b.real + a.imag * b.imag) / denominator,
(a.imag * b.real - a.real * b.imag) / denominator);
#endif
}
HOSTDEVICE inline complex64 operator-(const complex64& a) {
#if defined(__CUDA_ARCH__)
return complex64(-thrust::complex<float>(a.real, a.imag));
#else
complex64 res;
res.real = -a.real;
res.imag = -a.imag;
return res;
#endif
}
HOSTDEVICE inline complex64& operator+=(complex64& a, // NOLINT
const complex64& b) {
#if defined(__CUDA_ARCH__)
a = complex64(thrust::complex<float>(a.real, a.imag) +=
thrust::complex<float>(b.real, b.imag));
return a;
#else
a.real += b.real;
a.imag += b.imag;
return a;
#endif
}
HOSTDEVICE inline complex64& operator-=(complex64& a, // NOLINT
const complex64& b) {
#if defined(__CUDA_ARCH__)
a = complex64(thrust::complex<float>(a.real, a.imag) -=
thrust::complex<float>(b.real, b.imag));
return a;
#else
a.real -= b.real;
a.imag -= b.imag;
return a;
#endif
}
HOSTDEVICE inline complex64& operator*=(complex64& a, // NOLINT
const complex64& b) {
#if defined(__CUDA_ARCH__)
a = complex64(thrust::complex<float>(a.real, a.imag) *=
thrust::complex<float>(b.real, b.imag));
return a;
#else
  // Compute into temporaries so the updated real part is not reused below.
  float real = a.real * b.real - a.imag * b.imag;
  float imag = a.imag * b.real + a.real * b.imag;
  a.real = real;
  a.imag = imag;
return a;
#endif
}
HOSTDEVICE inline complex64& operator/=(complex64& a, // NOLINT
const complex64& b) {
#if defined(__CUDA_ARCH__)
a = complex64(thrust::complex<float>(a.real, a.imag) /=
thrust::complex<float>(b.real, b.imag));
return a;
#else
float denominator = b.real * b.real + b.imag * b.imag;
  float real = (a.real * b.real + a.imag * b.imag) / denominator;
  float imag = (a.imag * b.real - a.real * b.imag) / denominator;
  a.real = real;
  a.imag = imag;
return a;
#endif
}
HOSTDEVICE inline complex64 raw_uint16_to_complex64(uint16_t a) {
complex64 res;
res.real = a;
return res;
}
HOSTDEVICE inline bool operator==(const complex64& a, const complex64& b) {
return a.real == b.real && a.imag == b.imag;
}
HOSTDEVICE inline bool operator!=(const complex64& a, const complex64& b) {
return a.real != b.real || a.imag != b.imag;
}
HOSTDEVICE inline bool operator<(const complex64& a, const complex64& b) {
return static_cast<float>(a.real) < static_cast<float>(b.real);
}
HOSTDEVICE inline bool operator<=(const complex64& a, const complex64& b) {
return static_cast<float>(a.real) <= static_cast<float>(b.real);
}
HOSTDEVICE inline bool operator>(const complex64& a, const complex64& b) {
return static_cast<float>(a.real) > static_cast<float>(b.real);
}
HOSTDEVICE inline bool operator>=(const complex64& a, const complex64& b) {
return static_cast<float>(a.real) >= static_cast<float>(b.real);
}
HOSTDEVICE inline bool(isnan)(const complex64& a) {
#if defined(__CUDA_ARCH__)
return __isnanf(a.real) || __isnanf(a.imag);
#else
return std::isnan(a.real) || std::isnan(a.imag);
#endif
}
HOSTDEVICE inline bool(isinf)(const complex64& a) {
#if defined(__CUDA_ARCH__)
return __isinff(a.real) || __isinff(a.imag);
#else
return std::isinf(a.real) || std::isinf(a.imag);
#endif
}
HOSTDEVICE inline bool(isfinite)(const complex64& a) {
return !((isnan)(a)) && !((isinf)(a));
}
HOSTDEVICE inline float(abs)(const complex64& a) {
#if defined(__CUDA_ARCH__)
  return thrust::abs(thrust::complex<float>(a.real, a.imag));
#else
return std::abs(std::complex<float>(a));
#endif
}
HOSTDEVICE inline complex64(pow)(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__)
return complex64(thrust::pow(thrust::complex<float>(a.real, a.imag),
thrust::complex<float>(b.real, b.imag)));
#else
return std::pow(std::complex<float>(a), std::complex<float>(b));
#endif
}
HOSTDEVICE inline complex64(sqrt)(const complex64& a) {
#if defined(__CUDA_ARCH__)
return complex64(thrust::sqrt(thrust::complex<float>(a.real, a.imag)));
#else
return std::sqrt(std::complex<float>(a));
#endif
}
HOSTDEVICE inline complex64(tanh)(const complex64& a) {
#if defined(__CUDA_ARCH__)
return complex64(thrust::tanh(thrust::complex<float>(a.real, a.imag)));
#else
return std::tanh(std::complex<float>(a));
#endif
}
HOSTDEVICE inline complex64(log)(const complex64& a) {
#if defined(__CUDA_ARCH__)
return complex64(thrust::log(thrust::complex<float>(a.real, a.imag)));
#else
return std::log(std::complex<float>(a));
#endif
}
inline std::ostream& operator<<(std::ostream& os, const complex64& a) {
os << "real:" << a.real << " imag:" << a.imag;
return os;
}
} // namespace platform
} // namespace paddle
namespace std {
template <>
struct is_pod<paddle::platform::complex64> {
static const bool value =
is_trivial<paddle::platform::complex64>::value &&
is_standard_layout<paddle::platform::complex64>::value;
};
template <>
struct is_floating_point<paddle::platform::complex64>
: std::integral_constant<
bool, std::is_same<paddle::platform::complex64,
typename std::remove_cv<
paddle::platform::complex64>::type>::value> {};
template <>
struct is_signed<paddle::platform::complex64> {
static const bool value = false;
};
template <>
struct is_unsigned<paddle::platform::complex64> {
static const bool value = false;
};
inline bool isnan(const paddle::platform::complex64& a) {
return paddle::platform::isnan(a);
}
inline bool isinf(const paddle::platform::complex64& a) {
return paddle::platform::isinf(a);
}
template <>
struct numeric_limits<paddle::platform::complex64> {
static const bool is_specialized = false;
static const bool is_signed = false;
static const bool is_integer = false;
static const bool is_exact = false;
static const bool has_infinity = false;
static const bool has_quiet_NaN = false;
static const bool has_signaling_NaN = false;
static const float_denorm_style has_denorm = denorm_absent;
static const bool has_denorm_loss = false;
static const std::float_round_style round_style = std::round_toward_zero;
static const bool is_iec559 = false;
static const bool is_bounded = false;
static const bool is_modulo = false;
static const int digits = 0;
static const int digits10 = 0;
static const int max_digits10 = 0;
static const int radix = 0;
static const int min_exponent = 0;
static const int min_exponent10 = 0;
static const int max_exponent = 0;
static const int max_exponent10 = 0;
static const bool traps = false;
static const bool tinyness_before = false;
static paddle::platform::complex64(min)() {
return paddle::platform::complex64(0.0, 0.0);
}
static paddle::platform::complex64 lowest() {
return paddle::platform::complex64(0.0, 0.0);
}
static paddle::platform::complex64(max)() {
return paddle::platform::complex64(0.0, 0.0);
}
static paddle::platform::complex64 epsilon() {
return paddle::platform::complex64(0.0, 0.0);
}
static paddle::platform::complex64 round_error() {
return paddle::platform::complex64(0.0, 0.0);
}
static paddle::platform::complex64 infinity() {
return paddle::platform::complex64(0.0, 0.0);
}
static paddle::platform::complex64 quiet_NaN() {
return paddle::platform::complex64(0.0, 0.0);
}
static paddle::platform::complex64 signaling_NaN() {
return paddle::platform::complex64(0.0, 0.0);
}
static paddle::platform::complex64 denorm_min() {
return paddle::platform::complex64(0.0, 0.0);
}
};
} // namespace std
namespace Eigen {
using complex64 = paddle::platform::complex64;
template <>
struct NumTraits<complex64> : GenericNumTraits<std::complex<float>> {
typedef float Real;
typedef typename NumTraits<float>::Literal Literal;
enum {
IsComplex = 1,
RequireInitialization = NumTraits<float>::RequireInitialization,
ReadCost = 2 * NumTraits<float>::ReadCost,
AddCost = 2 * NumTraits<Real>::AddCost,
MulCost = 4 * NumTraits<Real>::MulCost + 2 * NumTraits<Real>::AddCost
};
EIGEN_DEVICE_FUNC
static inline Real epsilon() { return NumTraits<Real>::epsilon(); }
EIGEN_DEVICE_FUNC
static inline Real dummy_precision() {
return NumTraits<Real>::dummy_precision();
}
EIGEN_DEVICE_FUNC
static inline int digits10() { return NumTraits<Real>::digits10(); }
};
namespace numext {
template <>
HOSTDEVICE inline bool(isnan)(const complex64& a) {
return (paddle::platform::isnan)(a);
}
template <>
HOSTDEVICE inline bool(isinf)(const complex64& a) {
return (paddle::platform::isinf)(a);
}
template <>
HOSTDEVICE inline bool(isfinite)(const complex64& a) {
return (paddle::platform::isfinite)(a);
}
template <>
HOSTDEVICE inline complex64 exp(const complex64& a) {
float com = ::expf(a.real);
float res_real = com * ::cosf(a.imag);
float res_imag = com * ::sinf(a.imag);
return complex64(res_real, res_imag);
}
template <>
HOSTDEVICE inline complex64 log(const complex64& a) {
return paddle::platform::log(a);
}
template <>
HOSTDEVICE inline complex64 tanh(const complex64& a) {
return paddle::platform::tanh(a);
}
template <>
HOSTDEVICE inline complex64 sqrt(const complex64& a) {
return paddle::platform::sqrt(a);
}
template <>
HOSTDEVICE inline complex64 ceil(const complex64& a) {
return complex64(::ceilf(a.real), ::ceilf(a.imag));
}
template <>
HOSTDEVICE inline complex64 floor(const complex64& a) {
  return complex64(::floorf(a.real), ::floorf(a.imag));
}
template <>
HOSTDEVICE inline complex64 round(const complex64& a) {
return complex64(::roundf(a.real), ::roundf(a.imag));
}
template <>
HOSTDEVICE inline complex64 pow(const complex64& a, const complex64& b) {
return paddle::platform::pow(a, b);
}
template <>
HOSTDEVICE inline float abs(const complex64& a) {
return paddle::platform::abs(a);
}
} // namespace numext
} // namespace Eigen
#define MKL_Complex8 paddle::platform::complex64
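A similar standalone sketch (again an illustration only, with the same include-path assumption) shows the conversions this header defines: complex64 widens to complex128 implicitly, narrows back through the converting constructor, and converts to float by keeping only the real part:

// Illustration only: conversion behaviour of complex64.
#include <cassert>
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"

int main() {
  using paddle::platform::complex128;
  using paddle::platform::complex64;

  complex64 c(1.5f, -2.5f);

  complex128 w = c;  // implicit widening via operator complex128()
  assert(w.real == 1.5 && w.imag == -2.5);

  complex64 n(w);    // narrowing via complex64(const complex128&)
  assert(n.real == 1.5f && n.imag == -2.5f);

  float r = c;       // operator float() drops the imaginary part
  assert(r == 1.5f);
  return 0;
}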
......@@ -18,6 +18,8 @@ limitations under the License. */
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -104,11 +106,54 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
return float16(__shfl_down_sync(mask, static_cast<half>(val),
static_cast<unsigned>(delta), width));
}
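// The complex64/complex128 specializations below decompose each shuffle into
// two scalar __shfl_*_sync calls: the real and imaginary parts are moved
// separately with the same mask and lane selection, then recombined, so a
// complex value is exchanged across lanes exactly.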
template <>
__forceinline__ __device__ paddle::platform::complex64 CudaShuffleDownSync(
unsigned mask, paddle::platform::complex64 val, int delta, int width) {
float real = static_cast<float>(__shfl_down_sync(
mask, static_cast<float>(val.real), static_cast<unsigned>(delta), width));
float imag = static_cast<float>(__shfl_down_sync(
mask, static_cast<float>(val.imag), static_cast<unsigned>(delta), width));
return paddle::platform::complex64(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex128 CudaShuffleDownSync(
unsigned mask, paddle::platform::complex128 val, int delta, int width) {
double real = static_cast<double>(
__shfl_down_sync(mask, static_cast<double>(val.real),
static_cast<unsigned>(delta), width));
double imag = static_cast<double>(
__shfl_down_sync(mask, static_cast<double>(val.imag),
static_cast<unsigned>(delta), width));
return paddle::platform::complex128(real, imag);
}
template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
float16 val, int width) {
return float16(__shfl_xor_sync(mask, static_cast<half>(val), width));
}
template <>
__forceinline__ __device__ paddle::platform::complex64 CudaShuffleXorSync(
unsigned mask, paddle::platform::complex64 val, int width) {
float real = static_cast<float>(
__shfl_xor_sync(mask, static_cast<float>(val.real), width));
float imag = static_cast<float>(
__shfl_xor_sync(mask, static_cast<float>(val.imag), width));
return paddle::platform::complex64(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex128 CudaShuffleXorSync(
unsigned mask, paddle::platform::complex128 val, int width) {
double real = static_cast<double>(
__shfl_xor_sync(mask, static_cast<double>(val.real), width));
double imag = static_cast<double>(
__shfl_xor_sync(mask, static_cast<double>(val.imag), width));
return paddle::platform::complex128(real, imag);
}
#endif
template <typename T>
......
......@@ -61,8 +61,12 @@ extern void *cublas_dso_handle;
__macro(cublasDcopy_v2); \
__macro(cublasSgemv_v2); \
__macro(cublasDgemv_v2); \
__macro(cublasCgemv_v2); \
__macro(cublasZgemv_v2); \
__macro(cublasSgemm_v2); \
__macro(cublasDgemm_v2); \
__macro(cublasCgemm_v2); \
__macro(cublasZgemm_v2); \
__macro(cublasHgemm); \
__macro(cublasSgemmEx); \
__macro(cublasSgeam); \
......
......@@ -51,12 +51,20 @@ extern void* mklml_dso_handle;
#define MKLML_ROUTINE_EACH(__macro) \
__macro(cblas_sgemm); \
__macro(cblas_dgemm); \
__macro(cblas_cgemm); \
__macro(cblas_zgemm); \
__macro(cblas_saxpy); \
__macro(cblas_daxpy); \
__macro(cblas_caxpy); \
__macro(cblas_zaxpy); \
__macro(cblas_scopy); \
__macro(cblas_dcopy); \
__macro(cblas_ccopy); \
__macro(cblas_zcopy); \
__macro(cblas_sgemv); \
__macro(cblas_dgemv); \
__macro(cblas_cgemv); \
__macro(cblas_zgemv); \
__macro(cblas_strsm); \
__macro(cblas_dtrsm); \
__macro(cblas_sgemm_alloc); \
......@@ -69,6 +77,8 @@ extern void* mklml_dso_handle;
__macro(cblas_dgemm_free); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm_batch); \
__macro(cblas_cgemm_batch); \
__macro(cblas_zgemm_batch); \
__macro(cblas_sdot); \
__macro(cblas_ddot); \
__macro(cblas_sasum); \
......
......@@ -185,6 +185,8 @@ void BindVarDsec(pybind11::module *m) {
.value("FP32", pd::proto::VarType::FP32)
.value("FP64", pd::proto::VarType::FP64)
.value("BF16", pd::proto::VarType::BF16)
.value("COMPLEX64", pd::proto::VarType::COMPLEX64)
.value("COMPLEX128", pd::proto::VarType::COMPLEX128)
.value("LOD_TENSOR", pd::proto::VarType::LOD_TENSOR)
.value("SELECTED_ROWS", pd::proto::VarType::SELECTED_ROWS)
.value("FEED_MINIBATCH", pd::proto::VarType::FEED_MINIBATCH)
......
......@@ -42,6 +42,8 @@ namespace detail {
// print np.dtype(np.float16).num # 23
constexpr int NPY_FLOAT16_ = 23;
constexpr int NPY_UINT16_ = 4;
constexpr int NPY_COMPLEX64 = 14;
constexpr int NPY_COMPLEX128 = 15;
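// 14 and 15 are numpy's NPY_CFLOAT and NPY_CDOUBLE enum values:
// print np.dtype(np.complex64).num # 14
// print np.dtype(np.complex128).num # 15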
// Note: Since float16 is not a builtin type in C++, we register
// paddle::platform::float16 as numpy.float16.
......@@ -78,6 +80,44 @@ struct npy_format_descriptor<paddle::platform::bfloat16> {
static constexpr auto name = _("bfloat16");
};
// we register paddle::platform::complex64 as numpy.complex64.
template <>
struct npy_format_descriptor<paddle::platform::complex64> {
static py::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX64);
return reinterpret_borrow<py::dtype>(ptr);
}
static std::string format() {
// Note: "F" represents complex64.
// Details at:
// https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx
// for k, v in np.sctypeDict.iteritems():
// print '{0:14s} : {1:40s}'.format(str(k), v)
return "F";
}
static constexpr auto name = _("complext64");
};
// we register paddle::platform::complex128 as numpy.complex128.
template <>
struct npy_format_descriptor<paddle::platform::complex128> {
static py::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX128);
return reinterpret_borrow<py::dtype>(ptr);
}
static std::string format() {
// Note: "D" represents complex128.
// Details at:
// https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx
// for k, v in np.sctypeDict.iteritems():
// print '{0:14s} : {1:40s}'.format(str(k), v)
return "D";
}
static constexpr auto name = _("complext128");
};
} // namespace detail
} // namespace pybind11
......@@ -124,6 +164,8 @@ struct ValidDTypeToPyArrayChecker {
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex64);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex128);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
......@@ -142,6 +184,10 @@ inline std::string TensorDTypeToPyDTypeStr(
} else if (std::is_same<T, platform::bfloat16>::value) { \
/* NumPy character code of uint16 due to no support for bfloat16 */ \
return "H"; \
} else if (std::is_same<T, platform::complex64>::value) { \
return "F"; \
} else if (std::is_same<T, platform::complex128>::value) { \
return "D"; \
} else { \
constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
PADDLE_ENFORCE_EQ( \
......@@ -284,6 +330,12 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
} else if (py::isinstance<py::array_t<paddle::platform::float16>>(array)) {
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<paddle::platform::complex64>>(array)) {
SetTensorFromPyArrayT<paddle::platform::complex64, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<paddle::platform::complex128>>(array)) {
SetTensorFromPyArrayT<paddle::platform::complex128, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<uint16_t>>(array)) {
// since there is still no support for bfloat16 in NumPy,
// uint16 is used for casting bfloat16
......@@ -504,6 +556,10 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self,
return _sliceAndConcat<paddle::platform::float16>(self, obj, dim);
case framework::proto::VarType::BF16:
return _sliceAndConcat<paddle::platform::bfloat16>(self, obj, dim);
case framework::proto::VarType::COMPLEX64:
return _sliceAndConcat<paddle::platform::complex64>(self, obj, dim);
case framework::proto::VarType::COMPLEX128:
return _sliceAndConcat<paddle::platform::complex128>(self, obj, dim);
case framework::proto::VarType::FP32:
return _sliceAndConcat<float>(self, obj, dim);
case framework::proto::VarType::FP64:
......
......@@ -47,6 +47,10 @@ def convert_dtype(dtype):
return 'int64'
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
elif dtype == core.VarDesc.VarType.COMPLEX64:
return 'complex64'
elif dtype == core.VarDesc.VarType.COMPLEX128:
return 'complex128'
elif isinstance(dtype, type):
if dtype in [
np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
......
......@@ -643,6 +643,10 @@ def convert_np_dtype_to_dtype_(np_dtype):
return core.VarDesc.VarType.UINT8
elif dtype == np.int8:
return core.VarDesc.VarType.INT8
elif dtype == np.complex64:
return core.VarDesc.VarType.COMPLEX64
elif dtype == np.complex128:
return core.VarDesc.VarType.COMPLEX128
else:
raise ValueError("Not supported numpy dtype %s" % dtype)
......
......@@ -26,6 +26,13 @@ layers = {
"div": cpx.elementwise_div,
}
fluid_layers = {
"add": fluid.layers.elementwise_add,
"sub": fluid.layers.elementwise_sub,
"mul": fluid.layers.elementwise_mul,
"div": fluid.layers.elementwise_div,
}
class TestComplexElementwiseLayers(unittest.TestCase):
def setUp(self):
......@@ -40,6 +47,22 @@ class TestComplexElementwiseLayers(unittest.TestCase):
var_y = dg.to_variable(y)
return layers[layer_type](var_x, var_y).numpy()
    def fluid_calc(self, x, y, layer_type, place):
with dg.guard(place):
var_x = fluid.core.VarBase(
value=x,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
var_y = fluid.core.VarBase(
value=y,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
return fluid_layers[layer_type](var_x, var_y).numpy()
def compare(self, x, y):
for place in self._places:
self.assertTrue(np.allclose(self.calc(x, y, "add", place), x + y))
......@@ -47,6 +70,17 @@ class TestComplexElementwiseLayers(unittest.TestCase):
self.assertTrue(np.allclose(self.calc(x, y, "mul", place), x * y))
self.assertTrue(np.allclose(self.calc(x, y, "div", place), x / y))
def compare_1(self, x, y):
for place in self._places:
            self.assertTrue(
                np.allclose(self.fluid_calc(x, y, "add", place), x + y))
            self.assertTrue(
                np.allclose(self.fluid_calc(x, y, "sub", place), x - y))
            self.assertTrue(
                np.allclose(self.fluid_calc(x, y, "mul", place), x * y))
            self.assertTrue(
                np.allclose(self.fluid_calc(x, y, "div", place), x / y))
def compare_op(self, x, y):
for place in self._places:
with dg.guard(place):
......@@ -57,6 +91,26 @@ class TestComplexElementwiseLayers(unittest.TestCase):
self.assertTrue(var_x * var_y, x * y)
self.assertTrue(var_x / var_y, x / y)
def compare_op_1(self, x, y):
for place in self._places:
with dg.guard(place):
var_x = fluid.core.VarBase(
value=x,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
var_y = fluid.core.VarBase(
value=y,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
self.assertTrue(np.allclose((var_x + var_y).numpy(), x + y))
self.assertTrue(np.allclose((var_x - var_y).numpy(), x - y))
self.assertTrue(np.allclose((var_x * var_y).numpy(), x * y))
self.assertTrue(np.allclose((var_x / var_y).numpy(), x / y))
def test_complex_xy(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand(
[2, 3, 4, 5]).astype(self._dtype)
......@@ -64,6 +118,8 @@ class TestComplexElementwiseLayers(unittest.TestCase):
[2, 3, 4, 5]).astype(self._dtype)
self.compare(x, y)
self.compare_op(x, y)
self.compare_1(x, y)
self.compare_op_1(x, y)
def test_complex_x_real_y(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand(
......@@ -78,6 +134,14 @@ class TestComplexElementwiseLayers(unittest.TestCase):
self.compare(x, y)
self.compare_op(x, y)
def test_complex64_xy(self):
x = rand([2, 3, 4, 5]).astype("float32") + 1j * rand(
[2, 3, 4, 5]).astype("float32")
y = rand([2, 3, 4, 5]).astype("float32") + 1j * rand(
[2, 3, 4, 5]).astype("float32")
self.compare_1(x, y)
self.compare_op_1(x, y)
if __name__ == '__main__':
unittest.main()
......@@ -36,6 +36,18 @@ class TestComplexGetitemLayer(unittest.TestCase):
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x_np,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
x_var_slice = x_var[0]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case2(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1]
......@@ -47,6 +59,18 @@ class TestComplexGetitemLayer(unittest.TestCase):
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x_np,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
x_var_slice = x_var[0][1]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case3(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1][2]
......@@ -58,6 +82,18 @@ class TestComplexGetitemLayer(unittest.TestCase):
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x_np,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
x_var_slice = x_var[0][1][2]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case4(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1][0:3]
......@@ -69,6 +105,18 @@ class TestComplexGetitemLayer(unittest.TestCase):
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x_np,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
x_var_slice = x_var[0][1][0:3]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case5(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1][0:4:2]
......@@ -80,6 +128,18 @@ class TestComplexGetitemLayer(unittest.TestCase):
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x_np,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
x_var_slice = x_var[0][1][0:4:2]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
def test_case6(self):
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
x_np_slice = x_np[0][1:3][0:4:2]
......@@ -90,6 +150,17 @@ class TestComplexGetitemLayer(unittest.TestCase):
x_var_slice = x_var[0][1:3][0:4:2]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x_np,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
x_var_slice = x_var[0][1:3][0:4:2]
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
if __name__ == '__main__':
......
......@@ -34,6 +34,25 @@ class TestComplexMatMulLayer(unittest.TestCase):
np_result = np.matmul(x, y)
self.assertTrue(np.allclose(result.numpy(), np_result))
def compare_1(self, x, y):
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
y_var = fluid.core.VarBase(
value=y,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
result = paddle.matmul(x_var, y_var)
np_result = np.matmul(x, y)
self.assertTrue(np.allclose(result.numpy(), np_result))
def compare_op(self, x, y):
for place in self._places:
with dg.guard(place):
......@@ -43,6 +62,25 @@ class TestComplexMatMulLayer(unittest.TestCase):
np_result = np.matmul(x, y)
self.assertTrue(np.allclose(result.numpy(), np_result))
def compare_op_1(self, x, y):
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
y_var = fluid.core.VarBase(
value=y,
place=fluid.framework._current_expected_place(),
persistable=False,
zero_copy=None,
name='')
result = x_var.matmul(y_var)
np_result = np.matmul(x, y)
self.assertTrue(np.allclose(result.numpy(), np_result))
def test_complex_xy(self):
x = np.random.random(
(2, 3, 4, 5)).astype("float32") + 1J * np.random.random(
......@@ -52,6 +90,8 @@ class TestComplexMatMulLayer(unittest.TestCase):
(2, 3, 5, 4)).astype("float32")
self.compare(x, y)
self.compare_op(x, y)
self.compare_1(x, y)
self.compare_op_1(x, y)
def test_complex_x(self):
x = np.random.random(
......@@ -68,6 +108,52 @@ class TestComplexMatMulLayer(unittest.TestCase):
(2, 3, 5, 4)).astype("float32")
self.compare(x, y)
def test_complex128_xy(self):
x = np.random.random(
(2, 3, 4, 5)).astype("float64") + 1J * np.random.random(
(2, 3, 4, 5)).astype("float64")
y = np.random.random(
(2, 3, 5, 4)).astype("float64") + 1J * np.random.random(
(2, 3, 5, 4)).astype("float64")
self.compare_1(x, y)
self.compare_op_1(x, y)
def test_complex_xy_gemv(self):
x = np.random.random(
(2, 1, 100)).astype("float32") + 1J * np.random.random(
(2, 1, 100)).astype("float32")
y = np.random.random((100)).astype("float32") + 1J * np.random.random(
(100)).astype("float32")
self.compare_1(x, y)
self.compare_op_1(x, y)
x = np.random.random(
(2, 1, 100)).astype("float64") + 1J * np.random.random(
(2, 1, 100)).astype("float64")
y = np.random.random((100)).astype("float64") + 1J * np.random.random(
(100)).astype("float64")
self.compare_1(x, y)
self.compare_op_1(x, y)
def test_complex_xy_gemm(self):
x = np.random.random(
(1, 2, 50)).astype("float32") + 1J * np.random.random(
(1, 2, 50)).astype("float32")
y = np.random.random(
(1, 50, 2)).astype("float32") + 1J * np.random.random(
(1, 50, 2)).astype("float32")
self.compare_1(x, y)
self.compare_op_1(x, y)
x = np.random.random(
(1, 2, 50)).astype("float64") + 1J * np.random.random(
(1, 2, 50)).astype("float64")
y = np.random.random(
(1, 50, 2)).astype("float64") + 1J * np.random.random(
(1, 50, 2)).astype("float64")
self.compare_1(x, y)
self.compare_op_1(x, y)
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,9 @@ import unittest
import numpy as np
import paddle
import paddle.fluid.dygraph as dg
import paddle.fluid.core as core
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.data_feeder import convert_dtype
class TestComplexVariable(unittest.TestCase):
......@@ -43,6 +46,20 @@ class TestComplexVariable(unittest.TestCase):
self._dtype = "complex128"
self.compare()
def test_convert_np_dtype_to_dtype(self):
self.assertEqual(
convert_np_dtype_to_dtype_(np.complex64),
core.VarDesc.VarType.COMPLEX64)
        self.assertEqual(
            convert_np_dtype_to_dtype_(np.complex128),
            core.VarDesc.VarType.COMPLEX128)
def test_convert_dtype(self):
self.assertEqual(
convert_dtype(core.VarDesc.VarType.COMPLEX64), "complex64")
self.assertEqual(
convert_dtype(core.VarDesc.VarType.COMPLEX128), "complex128")
if __name__ == '__main__':
unittest.main()