未验证 提交 9ad800eb 编写于 作者: C Chen Weihang 提交者: GitHub

Support type promote for basic math ops (quantum required) (#29265)

* basic impl of type promote

* add comment & another testcase

* fix complex bugs & support python op promote type

* fix failed unittests & polish code

* add unittest for coverage

* change to only promote complex type

* polish code details

* polish several comments
上级 f31e5ada
...@@ -98,5 +98,58 @@ size_t SizeOfType(proto::VarType::Type type) { ...@@ -98,5 +98,58 @@ size_t SizeOfType(proto::VarType::Type type) {
DataTypeToString(type))); DataTypeToString(type)));
} }
// Now only supports promotion of complex type
bool NeedPromoteTypes(const proto::VarType::Type a,
const proto::VarType::Type b) {
return (IsComplexType(a) || IsComplexType(b));
}
int DataTypeNumAlign(const proto::VarType::Type t) {
int cast_type_num = -1;
if (t == proto::VarType::FP32 || t == proto::VarType::FP64) {
cast_type_num = static_cast<int>(t) - 5;
} else if (t == proto::VarType::COMPLEX64 ||
t == proto::VarType::COMPLEX128) {
cast_type_num = static_cast<int>(t) - 21;
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Only supports to align data type include float32, float64, complex64 "
"and complex128, but received data type is `s`.",
DataTypeToString(t)));
}
return cast_type_num;
}
// Now only supports promotion of complex type
proto::VarType::Type PromoteTypesIfComplexExists(
const proto::VarType::Type type_a, const proto::VarType::Type type_b) {
constexpr auto f4 = proto::VarType::FP32; // 5
constexpr auto f8 = proto::VarType::FP64; // 6
constexpr auto c4 = proto::VarType::COMPLEX64; // 23
constexpr auto c8 = proto::VarType::COMPLEX128; // 24
if (!NeedPromoteTypes(type_a, type_b)) {
// NOTE(chenweihang): keep consistent with rule in original op's impl,
// kernel type based on the first input tensor's dtype
return type_a;
}
int type_an = DataTypeNumAlign(type_a);
int type_bn = DataTypeNumAlign(type_b);
// Here is a complete rules table, but some rules are not used.
// It is still written this way because array accessing is still
// more efficient than if-else
static constexpr proto::VarType::Type promote_types_table[4][4] = {
/* f4 f8 c4 c8*/
/* f4 */ {f4, f8, c4, c8},
/* f8 */ {f8, f8, c8, c8},
/* c4 */ {c4, c8, c4, c8},
/* c8 */ {c8, c8, c8, c8},
};
return promote_types_table[type_an][type_bn];
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -141,5 +141,14 @@ inline std::ostream& operator<<(std::ostream& out, ...@@ -141,5 +141,14 @@ inline std::ostream& operator<<(std::ostream& out,
out << DataTypeToString(type); out << DataTypeToString(type);
return out; return out;
} }
extern inline bool IsComplexType(const proto::VarType::Type type) {
return (type == proto::VarType::COMPLEX64 ||
type == proto::VarType::COMPLEX128);
}
extern proto::VarType::Type PromoteTypesIfComplexExists(
const proto::VarType::Type type_a, const proto::VarType::Type type_b);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -1480,6 +1480,66 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType( ...@@ -1480,6 +1480,66 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
return data_type; return data_type;
} }
Tensor* OperatorWithKernel::GetTensorFormInputSafely(
const ExecutionContext& ctx, const std::string& name) const {
// 1. get variable and check
// NOTE: only supports signal input var now
// NOTE: using const_cast is because we don't have method
// can get single mutable var, and here will not change
// the var's data, only use some attribute
Variable* var = const_cast<Variable*>(ctx.InputVar(name));
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The variable %s is not found when promote complex types.", name));
// 2. get tensor and check
Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = var->GetMutable<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input variable type in complex type promotion."));
}
PADDLE_ENFORCE_NOT_NULL(
t,
platform::errors::InvalidArgument(
"The Tensor of variable %s is nullptr when promote complex types."));
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
platform::errors::InvalidArgument(
"The Tensor in the %s Op's Input Variable %s(%s) is "
"not initialized.",
Type(), name, ctx.InputName(name)));
return t;
}
/** NOTE(chenweihang): For safety reasons, we now only
* perform type promotes for binary operations with
* complex type inputs, which is used to support the
* paddle quantum function.
* In other cases, the first input data type is used as
* the kernel data type.
*/
proto::VarType::Type OperatorWithKernel::IndicateOrPromoteVarDataTypes(
const ExecutionContext& ctx, const std::string& name1,
const std::string& name2) const {
// 1. Get tensor
auto* tensor_a = GetTensorFormInputSafely(ctx, name1);
auto* tensor_b = GetTensorFormInputSafely(ctx, name2);
// 2. Get two input types
auto type_a = tensor_a->type();
auto type_b = tensor_b->type();
// 3. Get first input type or promote complex types
auto target_type = PromoteTypesIfComplexExists(type_a, type_b);
return target_type;
}
OpKernelType OperatorWithKernel::GetExpectedKernelType( OpKernelType OperatorWithKernel::GetExpectedKernelType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace()); return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
......
...@@ -504,6 +504,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -504,6 +504,10 @@ class OperatorWithKernel : public OperatorBase {
proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx, proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx,
const std::string& name) const; const std::string& name) const;
proto::VarType::Type IndicateOrPromoteVarDataTypes(
const ExecutionContext& ctx, const std::string& name1,
const std::string& name2) const;
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
// change this to public so that in dygraph mode we can call it to check if we // change this to public so that in dygraph mode we can call it to check if we
...@@ -518,11 +522,6 @@ class OperatorWithKernel : public OperatorBase { ...@@ -518,11 +522,6 @@ class OperatorWithKernel : public OperatorBase {
} }
private: private:
void ParseInputDataType(const ExecutionContext& ctx, const std::string& name,
proto::VarType::Type* type) const;
// indicate kernel DataType by input data. By default all input data must be
// same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final; void RunImpl(const Scope& scope, const platform::Place& place) const final;
void RunImpl(const Scope& scope, const platform::Place& place, void RunImpl(const Scope& scope, const platform::Place& place,
RuntimeContext* runtime_ctx) const; RuntimeContext* runtime_ctx) const;
...@@ -546,6 +545,17 @@ class OperatorWithKernel : public OperatorBase { ...@@ -546,6 +545,17 @@ class OperatorWithKernel : public OperatorBase {
void ChooseKernel(const RuntimeContext& ctx, const Scope& scope, void ChooseKernel(const RuntimeContext& ctx, const Scope& scope,
const platform::Place& place) const; const platform::Place& place) const;
/* Inner assist methods */
// indicate kernel DataType by input data.
// By default all input data must be same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
// used for IndicateDataType
void ParseInputDataType(const ExecutionContext& ctx, const std::string& name,
proto::VarType::Type* type) const;
// used for IndicateOrPromoteVarDataTypes
Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
const std::string& name) const;
protected: protected:
mutable std::unique_ptr<OpKernelType> kernel_type_; mutable std::unique_ptr<OpKernelType> kernel_type_;
mutable std::unique_ptr<OpKernelFunc> kernel_func_; mutable std::unique_ptr<OpKernelFunc> kernel_func_;
......
...@@ -1000,9 +1000,10 @@ std::ostream& print_tensor<paddle::platform::complex64>( ...@@ -1000,9 +1000,10 @@ std::ostream& print_tensor<paddle::platform::complex64>(
os << " - data: ["; os << " - data: [";
if (element_num > 0) { if (element_num > 0) {
os << signed(inspect[0].real) << signed(inspect[0].imag) << "j"; os << signed(inspect[0].real) << "+" << signed(inspect[0].imag) << "j";
for (int j = 1; j < element_num; ++j) { for (int j = 1; j < element_num; ++j) {
os << signed(inspect[j].real) << signed(inspect[j].imag) << "j"; os << " " << signed(inspect[j].real) << "+" << signed(inspect[j].imag)
<< "j";
} }
} }
os << "]"; os << "]";
...@@ -1017,9 +1018,10 @@ std::ostream& print_tensor<paddle::platform::complex128>( ...@@ -1017,9 +1018,10 @@ std::ostream& print_tensor<paddle::platform::complex128>(
os << " - data: ["; os << " - data: [";
if (element_num > 0) { if (element_num > 0) {
os << signed(inspect[0].real) << signed(inspect[0].imag) << "j"; os << signed(inspect[0].real) << "+" << signed(inspect[0].imag) << "j";
for (int j = 1; j < element_num; ++j) { for (int j = 1; j < element_num; ++j) {
os << signed(inspect[j].real) << signed(inspect[j].imag) << "j"; os << " " << signed(inspect[j].real) << "+" << signed(inspect[j].imag)
<< "j";
} }
} }
os << "]"; os << "]";
......
...@@ -96,4 +96,6 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>, ...@@ -96,4 +96,6 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, int64_t>, ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>, ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>); ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::complex64>,
ops::CastOpKernel<CPU, paddle::platform::complex128>);
...@@ -25,4 +25,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,4 +25,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>); paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
...@@ -82,6 +82,17 @@ class CastOpKernel : public framework::OpKernel<InT> { ...@@ -82,6 +82,17 @@ class CastOpKernel : public framework::OpKernel<InT> {
CastFunction<DeviceContext, InT, uint8_t>(context); CastFunction<DeviceContext, InT, uint8_t>(context);
} else if (out_type == paddle::framework::proto::VarType::BOOL) { } else if (out_type == paddle::framework::proto::VarType::BOOL) {
CastFunction<DeviceContext, InT, bool>(context); CastFunction<DeviceContext, InT, bool>(context);
} else if (out_type == paddle::framework::proto::VarType::COMPLEX64) {
CastFunction<DeviceContext, InT, paddle::platform::complex64>(context);
} else if (out_type == paddle::framework::proto::VarType::COMPLEX128) {
CastFunction<DeviceContext, InT, paddle::platform::complex128>(context);
} else {
// NOTE(chenweihang): if else branch do nothing, the output var will
// be non-initialized in dygraph, which will throw error if the
// non-initialized var is used as the next op's input
PADDLE_THROW(platform::errors::Unimplemented(
"Now does not support casting Tensor to `%s` data type.",
framework::DataTypeToString(out_type)));
} }
} }
}; };
......
...@@ -30,7 +30,8 @@ class ElementwiseMulOp : public ElementwiseOp { ...@@ -30,7 +30,8 @@ class ElementwiseMulOp : public ElementwiseOp {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
...@@ -41,6 +42,19 @@ class ElementwiseMulOp : public ElementwiseOp { ...@@ -41,6 +42,19 @@ class ElementwiseMulOp : public ElementwiseOp {
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
if (framework::IsComplexType(expected_kernel_type.data_type_)) {
// only promote inputs’s types when contains complex input
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -105,7 +105,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -105,7 +105,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
...@@ -116,6 +117,19 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -116,6 +117,19 @@ class ElementwiseOp : public framework::OperatorWithKernel {
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (framework::IsComplexType(expected_kernel_type.data_type_)) {
// only promote inputs’s types when contains complex input
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
}; };
class ElementwiseOpInferVarType class ElementwiseOpInferVarType
......
...@@ -655,7 +655,8 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -655,7 +655,8 @@ class MatMulOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
using mkldnn::memory; using mkldnn::memory;
...@@ -667,6 +668,19 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -667,6 +668,19 @@ class MatMulOp : public framework::OperatorWithKernel {
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (framework::IsComplexType(expected_kernel_type.data_type_)) {
// only promote inputs’s types when contains complex input
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
}; };
class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -85,9 +85,22 @@ class MatMulV2Op : public framework::OperatorWithKernel { ...@@ -85,9 +85,22 @@ class MatMulV2Op : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "X"), OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
if (framework::IsComplexType(expected_kernel_type.data_type_)) {
// only promote inputs’s types when contains complex input
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
} }
}; };
......
...@@ -70,10 +70,13 @@ struct PADDLE_ALIGN(16) complex128 { ...@@ -70,10 +70,13 @@ struct PADDLE_ALIGN(16) complex128 {
} }
#endif #endif
HOSTDEVICE complex128(const float& val) { real = static_cast<double>(val); } HOSTDEVICE complex128(const float& val)
HOSTDEVICE complex128(const double& val) { real = val; } : real(static_cast<double>(val)), imag(0) {}
HOSTDEVICE complex128(const int& val) { real = static_cast<double>(val); } HOSTDEVICE complex128(const double& val) : real(val), imag(0) {}
HOSTDEVICE complex128(const int64_t& val) { real = static_cast<double>(val); } HOSTDEVICE complex128(const int& val)
: real(static_cast<double>(val)), imag(0) {}
HOSTDEVICE complex128(const int64_t& val)
: real(static_cast<double>(val)), imag(0) {}
HOSTDEVICE inline explicit operator std::complex<double>() { HOSTDEVICE inline explicit operator std::complex<double>() {
return static_cast<std::complex<double>>(std::complex<double>(real, imag)); return static_cast<std::complex<double>>(std::complex<double>(real, imag));
...@@ -94,51 +97,61 @@ struct PADDLE_ALIGN(16) complex128 { ...@@ -94,51 +97,61 @@ struct PADDLE_ALIGN(16) complex128 {
HOSTDEVICE inline complex128& operator=(int8_t val) { HOSTDEVICE inline complex128& operator=(int8_t val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(uint8_t val) { HOSTDEVICE inline complex128& operator=(uint8_t val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(int16_t val) { HOSTDEVICE inline complex128& operator=(int16_t val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(uint16_t val) { HOSTDEVICE inline complex128& operator=(uint16_t val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(int32_t val) { HOSTDEVICE inline complex128& operator=(int32_t val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(uint32_t val) { HOSTDEVICE inline complex128& operator=(uint32_t val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(int64_t val) { HOSTDEVICE inline complex128& operator=(int64_t val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(uint64_t val) { HOSTDEVICE inline complex128& operator=(uint64_t val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(float val) { HOSTDEVICE inline complex128& operator=(float val) {
real = val; real = val;
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex128& operator=(double val) { HOSTDEVICE inline complex128& operator=(double val) {
real = static_cast<double>(val); real = static_cast<double>(val);
imag = 0;
return *this; return *this;
} }
......
...@@ -70,14 +70,16 @@ struct PADDLE_ALIGN(8) complex64 { ...@@ -70,14 +70,16 @@ struct PADDLE_ALIGN(8) complex64 {
} }
#endif #endif
HOSTDEVICE complex64(const float& val) { real = val; } HOSTDEVICE complex64(const float& val) : real(val), imag(0) {}
HOSTDEVICE complex64(const double& val) { real = static_cast<float>(val); } HOSTDEVICE complex64(const double& val)
HOSTDEVICE complex64(const int& val) { real = static_cast<float>(val); } : real(static_cast<float>(val)), imag(0) {}
HOSTDEVICE complex64(const int64_t& val) { real = static_cast<float>(val); } HOSTDEVICE complex64(const int& val)
HOSTDEVICE complex64(const complex128& val) { : real(static_cast<float>(val)), imag(0) {}
real = static_cast<float>(val.real); HOSTDEVICE complex64(const int64_t& val)
imag = static_cast<float>(val.imag); : real(static_cast<float>(val)), imag(0) {}
} HOSTDEVICE complex64(const complex128& val)
: real(static_cast<float>(val.real)),
imag(static_cast<float>(val.imag)) {}
HOSTDEVICE inline explicit operator std::complex<float>() { HOSTDEVICE inline explicit operator std::complex<float>() {
return static_cast<std::complex<float>>(std::complex<float>(real, imag)); return static_cast<std::complex<float>>(std::complex<float>(real, imag));
...@@ -98,21 +100,25 @@ struct PADDLE_ALIGN(8) complex64 { ...@@ -98,21 +100,25 @@ struct PADDLE_ALIGN(8) complex64 {
HOSTDEVICE inline complex64& operator=(int8_t val) { HOSTDEVICE inline complex64& operator=(int8_t val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex64& operator=(uint8_t val) { HOSTDEVICE inline complex64& operator=(uint8_t val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex64& operator=(int16_t val) { HOSTDEVICE inline complex64& operator=(int16_t val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex64& operator=(uint16_t val) { HOSTDEVICE inline complex64& operator=(uint16_t val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
...@@ -123,26 +129,31 @@ struct PADDLE_ALIGN(8) complex64 { ...@@ -123,26 +129,31 @@ struct PADDLE_ALIGN(8) complex64 {
HOSTDEVICE inline complex64& operator=(uint32_t val) { HOSTDEVICE inline complex64& operator=(uint32_t val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex64& operator=(int64_t val) { HOSTDEVICE inline complex64& operator=(int64_t val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex64& operator=(uint64_t val) { HOSTDEVICE inline complex64& operator=(uint64_t val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex64& operator=(float val) { HOSTDEVICE inline complex64& operator=(float val) {
real = val; real = val;
imag = 0;
return *this; return *this;
} }
HOSTDEVICE inline complex64& operator=(double val) { HOSTDEVICE inline complex64& operator=(double val) {
real = static_cast<float>(val); real = static_cast<float>(val);
imag = 0;
return *this; return *this;
} }
......
...@@ -514,6 +514,9 @@ PYBIND11_MODULE(core_noavx, m) { ...@@ -514,6 +514,9 @@ PYBIND11_MODULE(core_noavx, m) {
m.def("_set_paddle_lib_path", &paddle::platform::dynload::SetPaddleLibPath); m.def("_set_paddle_lib_path", &paddle::platform::dynload::SetPaddleLibPath);
m.def("_promote_types_if_complex_exists",
&paddle::framework::PromoteTypesIfComplexExists);
BindImperative(&m); BindImperative(&m);
py::class_<Tensor>(m, "Tensor", py::buffer_protocol()) py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
......
...@@ -272,6 +272,7 @@ if avx_supported(): ...@@ -272,6 +272,7 @@ if avx_supported():
from .core_avx import _load_dygraph_dict from .core_avx import _load_dygraph_dict
from .core_avx import _create_loaded_parameter from .core_avx import _create_loaded_parameter
from .core_avx import _cuda_synchronize from .core_avx import _cuda_synchronize
from .core_avx import _promote_types_if_complex_exists
if sys.platform != 'win32': if sys.platform != 'win32':
from .core_avx import _set_process_pids from .core_avx import _set_process_pids
from .core_avx import _erase_process_pids from .core_avx import _erase_process_pids
...@@ -317,6 +318,7 @@ if load_noavx: ...@@ -317,6 +318,7 @@ if load_noavx:
from .core_noavx import _load_dygraph_dict from .core_noavx import _load_dygraph_dict
from .core_noavx import _create_loaded_parameter from .core_noavx import _create_loaded_parameter
from .core_noavx import _cuda_synchronize from .core_noavx import _cuda_synchronize
from .core_noavx import _promote_types_if_complex_exists
if sys.platform != 'win32': if sys.platform != 'win32':
from .core_noavx import _set_process_pids from .core_noavx import _set_process_pids
from .core_noavx import _erase_process_pids from .core_noavx import _erase_process_pids
......
...@@ -30,6 +30,27 @@ _supported_int_dtype_ = [ ...@@ -30,6 +30,27 @@ _supported_int_dtype_ = [
core.VarDesc.VarType.INT64, core.VarDesc.VarType.INT64,
] ]
# NOTE(chenweihang): We currently do not fully support the type promotion
# between tensors. Parting support here is because the interoperation of
# real and complex numbers in paddle quantum is very frequent, such as the
# binary operation between `float` and `complex64`, so we must support the
# correct type promotion on the APIs paddle quantum used.
# Now only check in dygraph (paddle quantum based dygraph)
# Full type promotion support will need to be fully verified later.
_supported_promote_complex_types_ = [
'__add__',
'__radd__',
'__sub__',
'__rsub__',
'__mul__',
'__rmul__',
'__div__',
'__truediv__',
'__rdiv__',
'__rtruediv__',
'__matmul__',
]
_already_patch_varbase = False _already_patch_varbase = False
...@@ -197,10 +218,22 @@ def monkey_patch_math_varbase(): ...@@ -197,10 +218,22 @@ def monkey_patch_math_varbase():
# add fill_op # add fill_op
other_var = create_scalar(value=other_var, dtype=lhs_dtype) other_var = create_scalar(value=other_var, dtype=lhs_dtype)
# 3. unify right var type to left var # 3. promote types or unify right var type to left var
rhs_dtype = other_var.dtype rhs_dtype = other_var.dtype
if lhs_dtype != rhs_dtype: if lhs_dtype != rhs_dtype:
other_var = astype(other_var, lhs_dtype) if method_name in _supported_promote_complex_types_:
# only when lhs_dtype or rhs_dtype is complex type,
# the dtype will promote, in other cases, directly
# use lhs_dtype, this is consistent will original rule
promote_dtype = core._promote_types_if_complex_exists(
lhs_dtype, rhs_dtype)
self = self if lhs_dtype == promote_dtype else astype(
self, promote_dtype)
other_var = other_var if rhs_dtype == promote_dtype else astype(
other_var, promote_dtype)
else:
other_var = astype(other_var, lhs_dtype)
if reverse: if reverse:
tmp = self tmp = self
self = other_var self = other_var
...@@ -266,6 +299,8 @@ def monkey_patch_math_varbase(): ...@@ -266,6 +299,8 @@ def monkey_patch_math_varbase():
'elementwise_floordiv', False, None)), 'elementwise_floordiv', False, None)),
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False, ('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
None)), None)),
('__matmul__', _binary_creator_('__matmul__', "matmul_v2", False,
None)),
## for logical compare ## for logical compare
('__eq__', _binary_creator_('__eq__', 'equal', False, None)), ('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)), ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
......
...@@ -17,6 +17,8 @@ from __future__ import print_function ...@@ -17,6 +17,8 @@ from __future__ import print_function
import op_test import op_test
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard from paddle.fluid import compiler, Program, program_guard
...@@ -88,5 +90,18 @@ class TestCastOpError(unittest.TestCase): ...@@ -88,5 +90,18 @@ class TestCastOpError(unittest.TestCase):
self.assertRaises(TypeError, test_dtype_type) self.assertRaises(TypeError, test_dtype_type)
class TestCastOpErrorInDygraph(unittest.TestCase):
def test_non_support_out_dtype(self):
paddle.disable_static()
with self.assertRaises(NotImplementedError):
tensor = paddle.randn([10, 10], 'float32')
core.ops.cast(tensor, 'in_dtype', core.VarDesc.VarType.FP32,
'out_dtype', core.VarDesc.VarType.INT16)
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
import unittest import unittest
import numpy as np import numpy as np
from numpy.random import random as rand from numpy.random import random as rand
from paddle import complex as cpx
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from paddle import complex as cpx
layers = { layers = {
"add": cpx.elementwise_add, "add": cpx.elementwise_add,
...@@ -26,121 +28,135 @@ layers = { ...@@ -26,121 +28,135 @@ layers = {
"div": cpx.elementwise_div, "div": cpx.elementwise_div,
} }
fluid_layers = { paddle_apis = {
"add": fluid.layers.elementwise_add, "add": paddle.add,
"sub": fluid.layers.elementwise_sub, "sub": paddle.subtract,
"mul": fluid.layers.elementwise_mul, "mul": paddle.multiply,
"div": fluid.layers.elementwise_div, "div": paddle.divide,
} }
class TestComplexElementwiseLayers(unittest.TestCase): class TestComplexElementwiseLayers(unittest.TestCase):
def setUp(self): def setUp(self):
self._dtype = "float64" self._dtypes = ["float32", "float64"]
self._places = [fluid.CPUPlace()] self._places = [paddle.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0)) self._places.append(paddle.CUDAPlace(0))
def calc(self, x, y, layer_type, place): def calc(self, x, y, op, place):
with dg.guard(place): with dg.guard(place):
var_x = dg.to_variable(x) var_x = dg.to_variable(x)
var_y = dg.to_variable(y) var_y = dg.to_variable(y)
return layers[layer_type](var_x, var_y).numpy() return layers[op](var_x, var_y).numpy()
def fuild_calc(self, x, y, layer_type, place): def paddle_calc(self, x, y, op, place):
with dg.guard(place): with dg.guard(place):
var_x = fluid.core.VarBase( x_t = paddle.Tensor(
value=x, value=x,
place=fluid.framework._current_expected_place(), place=place,
persistable=False, persistable=False,
zero_copy=None, zero_copy=False,
name='') stop_gradient=True)
var_y = fluid.core.VarBase( y_t = paddle.Tensor(
value=y, value=y,
place=fluid.framework._current_expected_place(), place=place,
persistable=False, persistable=False,
zero_copy=None, zero_copy=False,
name='') stop_gradient=True)
return fluid_layers[layer_type](var_x, var_y).numpy() return paddle_apis[op](x_t, y_t).numpy()
def compare(self, x, y): def assert_check(self, pd_result, np_result, place):
self.assertTrue(
np.allclose(pd_result, np_result),
"\nplace: {}\npaddle diff result:\n {}\nnumpy diff result:\n {}\n".
format(place, pd_result[~np.isclose(pd_result, np_result)],
np_result[~np.isclose(pd_result, np_result)]))
def compare_by_complex_api(self, x, y):
for place in self._places: for place in self._places:
self.assertTrue(np.allclose(self.calc(x, y, "add", place), x + y)) self.assert_check(self.calc(x, y, "add", place), x + y, place)
self.assertTrue(np.allclose(self.calc(x, y, "sub", place), x - y)) self.assert_check(self.calc(x, y, "sub", place), x - y, place)
self.assertTrue(np.allclose(self.calc(x, y, "mul", place), x * y)) self.assert_check(self.calc(x, y, "mul", place), x * y, place)
self.assertTrue(np.allclose(self.calc(x, y, "div", place), x / y)) self.assert_check(self.calc(x, y, "div", place), x / y, place)
def compare_1(self, x, y): def compare_by_basic_api(self, x, y):
for place in self._places: for place in self._places:
self.assertTrue( self.assert_check(
np.allclose(self.fuild_calc(x, y, "add", place), x + y)) self.paddle_calc(x, y, "add", place), x + y, place)
self.assertTrue( self.assert_check(
np.allclose(self.fuild_calc(x, y, "sub", place), x - y)) self.paddle_calc(x, y, "sub", place), x - y, place)
self.assertTrue( self.assert_check(
np.allclose(self.fuild_calc(x, y, "mul", place), x * y)) self.paddle_calc(x, y, "mul", place), x * y, place)
self.assertTrue( self.assert_check(
np.allclose(self.fuild_calc(x, y, "div", place), x / y)) self.paddle_calc(x, y, "div", place), x / y, place)
def compare_op(self, x, y): def compare_op_by_complex_api(self, x, y):
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
var_x = dg.to_variable(x) var_x = dg.to_variable(x)
var_y = dg.to_variable(y) var_y = dg.to_variable(y)
self.assertTrue(var_x + var_y, x + y) self.assert_check((var_x + var_y).numpy(), x + y, place)
self.assertTrue(var_x - var_y, x - y) self.assert_check((var_x - var_y).numpy(), x - y, place)
self.assertTrue(var_x * var_y, x * y) self.assert_check((var_x * var_y).numpy(), x * y, place)
self.assertTrue(var_x / var_y, x / y) self.assert_check((var_x / var_y).numpy(), x / y, place)
def compare_op_1(self, x, y): def compare_op_by_basic_api(self, x, y):
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
var_x = fluid.core.VarBase( x_t = paddle.Tensor(
value=x, value=x,
place=fluid.framework._current_expected_place(), place=place,
persistable=False, persistable=False,
zero_copy=None, zero_copy=False,
name='') stop_gradient=True)
var_y = fluid.core.VarBase( y_t = paddle.Tensor(
value=y, value=y,
place=fluid.framework._current_expected_place(), place=place,
persistable=False, persistable=False,
zero_copy=None, zero_copy=False,
name='') stop_gradient=True)
self.assertTrue(np.allclose((var_x + var_y).numpy(), x + y)) self.assert_check((x_t + y_t).numpy(), x + y, place)
self.assertTrue(np.allclose((var_x - var_y).numpy(), x - y)) self.assert_check((x_t - y_t).numpy(), x - y, place)
self.assertTrue(np.allclose((var_x * var_y).numpy(), x * y)) self.assert_check((x_t * y_t).numpy(), x * y, place)
self.assertTrue(np.allclose((var_x / var_y).numpy(), x / y)) self.assert_check((x_t / y_t).numpy(), x / y, place)
def test_complex_xy(self): def test_complex_xy(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand( for dtype in self._dtypes:
[2, 3, 4, 5]).astype(self._dtype) x = rand([2, 3, 4, 5]).astype(dtype) + 1j * rand(
y = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand( [2, 3, 4, 5]).astype(dtype)
[2, 3, 4, 5]).astype(self._dtype) y = rand([2, 3, 4, 5]).astype(dtype) + 1j * rand(
self.compare(x, y) [2, 3, 4, 5]).astype(dtype)
self.compare_op(x, y)
self.compare_1(x, y) self.compare_by_complex_api(x, y)
self.compare_op_1(x, y) self.compare_op_by_complex_api(x, y)
self.compare_op_by_complex_api(x, y)
self.compare_op_by_basic_api(x, y)
def test_complex_x_real_y(self): def test_complex_x_real_y(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) + 1j * rand( for dtype in self._dtypes:
[2, 3, 4, 5]).astype(self._dtype) x = rand([2, 3, 4, 5]).astype(dtype) + 1j * rand(
y = rand([4, 5]).astype(self._dtype) [2, 3, 4, 5]).astype(dtype)
self.compare(x, y) y = rand([4, 5]).astype(dtype)
self.compare_op(x, y)
self.compare_by_complex_api(x, y)
self.compare_op_by_complex_api(x, y)
# promote types cases
self.compare_by_basic_api(x, y)
self.compare_op_by_basic_api(x, y)
def test_real_x_complex_y(self): def test_real_x_complex_y(self):
x = rand([2, 3, 4, 5]).astype(self._dtype) for dtype in self._dtypes:
y = rand([5]).astype(self._dtype) + 1j * rand([5]).astype(self._dtype) x = rand([2, 3, 4, 5]).astype(dtype)
self.compare(x, y) y = rand([5]).astype(dtype) + 1j * rand([5]).astype(dtype)
self.compare_op(x, y)
self.compare_by_complex_api(x, y)
def test_complex64_xy(self): self.compare_op_by_complex_api(x, y)
x = rand([2, 3, 4, 5]).astype("float32") + 1j * rand(
[2, 3, 4, 5]).astype("float32") # promote types cases
y = rand([2, 3, 4, 5]).astype("float32") + 1j * rand( self.compare_by_basic_api(x, y)
[2, 3, 4, 5]).astype("float32") self.compare_op_by_basic_api(x, y)
self.compare_1(x, y)
self.compare_op_1(x, y)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -21,21 +21,25 @@ import paddle.fluid.dygraph as dg ...@@ -21,21 +21,25 @@ import paddle.fluid.dygraph as dg
class TestComplexMatMulLayer(unittest.TestCase): class TestComplexMatMulLayer(unittest.TestCase):
def setUp(self): def setUp(self):
self._dtypes = ["float32", "float64"]
self._places = [fluid.CPUPlace()] self._places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0)) self._places.append(fluid.CUDAPlace(0))
def compare_by_complex_api(self, x, y): def compare_by_complex_api(self, x, y, np_result):
np_result = np.matmul(x, y)
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
x_var = dg.to_variable(x) x_var = dg.to_variable(x)
y_var = dg.to_variable(y) y_var = dg.to_variable(y)
result = paddle.complex.matmul(x_var, y_var) result = paddle.complex.matmul(x_var, y_var)
self.assertTrue(np.allclose(result.numpy(), np_result)) pd_result = result.numpy()
self.assertTrue(
def compare_by_basic_api(self, x, y): np.allclose(pd_result, np_result),
np_result = np.matmul(x, y) "\nplace: {}\npaddle diff result:\n {}\nnumpy diff result:\n {}\n".
format(place, pd_result[~np.isclose(pd_result, np_result)],
np_result[~np.isclose(pd_result, np_result)]))
def compare_by_basic_api(self, x, y, np_result):
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
x_var = fluid.core.VarBase( x_var = fluid.core.VarBase(
...@@ -51,19 +55,27 @@ class TestComplexMatMulLayer(unittest.TestCase): ...@@ -51,19 +55,27 @@ class TestComplexMatMulLayer(unittest.TestCase):
zero_copy=None, zero_copy=None,
name='') name='')
result = paddle.matmul(x_var, y_var) result = paddle.matmul(x_var, y_var)
self.assertTrue(np.allclose(result.numpy(), np_result)) pd_result = result.numpy()
self.assertTrue(
def compare_op_by_complex_api(self, x, y): np.allclose(pd_result, np_result),
np_result = np.matmul(x, y) "\nplace: {}\npaddle diff result:\n {}\nnumpy diff result:\n {}\n".
format(place, pd_result[~np.isclose(pd_result, np_result)],
np_result[~np.isclose(pd_result, np_result)]))
def compare_op_by_complex_api(self, x, y, np_result):
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
x_var = dg.to_variable(x) x_var = dg.to_variable(x)
y_var = dg.to_variable(y) y_var = dg.to_variable(y)
result = x_var.matmul(y_var) result = x_var.matmul(y_var)
self.assertTrue(np.allclose(result.numpy(), np_result)) pd_result = result.numpy()
self.assertTrue(
def compare_op_by_basic_api(self, x, y): np.allclose(pd_result, np_result),
np_result = np.matmul(x, y) "\nplace: {}\npaddle diff result:\n {}\nnumpy diff result:\n {}\n".
format(place, pd_result[~np.isclose(pd_result, np_result)],
np_result[~np.isclose(pd_result, np_result)]))
def compare_op_by_basic_api(self, x, y, np_result):
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
x_var = fluid.core.VarBase( x_var = fluid.core.VarBase(
...@@ -79,126 +91,89 @@ class TestComplexMatMulLayer(unittest.TestCase): ...@@ -79,126 +91,89 @@ class TestComplexMatMulLayer(unittest.TestCase):
zero_copy=None, zero_copy=None,
name='') name='')
result = x_var.matmul(y_var) result = x_var.matmul(y_var)
self.assertTrue(np.allclose(result.numpy(), np_result)) pd_result = result.numpy()
self.assertTrue(
np.allclose(pd_result, np_result),
"\nplace: {}\npaddle diff result:\n {}\nnumpy diff result:\n {}\n".
format(place, pd_result[~np.isclose(pd_result, np_result)],
np_result[~np.isclose(pd_result, np_result)]))
def test_complex_xy(self): def test_complex_xy(self):
x = np.random.random( for dtype in self._dtypes:
(2, 3, 4, 5)).astype("float32") + 1J * np.random.random( x = np.random.random(
(2, 3, 4, 5)).astype("float32") (2, 3, 4, 5)).astype(dtype) + 1J * np.random.random(
y = np.random.random( (2, 3, 4, 5)).astype(dtype)
(2, 3, 5, 4)).astype("float32") + 1J * np.random.random( y = np.random.random(
(2, 3, 5, 4)).astype("float32") (2, 3, 5, 4)).astype(dtype) + 1J * np.random.random(
self.compare_by_complex_api(x, y) (2, 3, 5, 4)).astype(dtype)
self.compare_op_by_complex_api(x, y)
self.compare_by_basic_api(x, y)
self.compare_op_by_basic_api(x, y)
def test_complex_x(self):
x = np.random.random(
(2, 3, 4, 5)).astype("float32") + 1J * np.random.random(
(2, 3, 4, 5)).astype("float32")
y = np.random.random((2, 3, 5, 4)).astype("float32")
self.compare_by_complex_api(x, y)
self.compare_op_by_complex_api(x, y)
def test_complex_y(self):
x = np.random.random((2, 3, 4, 5)).astype("float32")
y = np.random.random(
(2, 3, 5, 4)).astype("float32") + 1J * np.random.random(
(2, 3, 5, 4)).astype("float32")
self.compare_by_complex_api(x, y)
def test_complex_xy_128(self):
x = np.random.random(
(2, 3, 4, 5)).astype("float64") + 1J * np.random.random(
(2, 3, 4, 5)).astype("float64")
y = np.random.random(
(2, 3, 5, 4)).astype("float64") + 1J * np.random.random(
(2, 3, 5, 4)).astype("float64")
self.compare_by_basic_api(x, y)
self.compare_op_by_basic_api(x, y)
def test_complex_xy_gemv(self): np_result = np.matmul(x, y)
x = np.random.random(
(2, 1, 100)).astype("float32") + 1J * np.random.random(
(2, 1, 100)).astype("float32")
y = np.random.random((100)).astype("float32") + 1J * np.random.random(
(100)).astype("float32")
self.compare_by_basic_api(x, y)
self.compare_op_by_basic_api(x, y)
x = np.random.random(
(2, 1, 100)).astype("float64") + 1J * np.random.random(
(2, 1, 100)).astype("float64")
y = np.random.random((100)).astype("float64") + 1J * np.random.random(
(100)).astype("float64")
self.compare_by_basic_api(x, y)
self.compare_op_by_basic_api(x, y)
def test_complex_xy_gemm_128(self):
x = np.random.random(
(1, 2, 50)).astype("float64") + 1J * np.random.random(
(1, 2, 50)).astype("float64")
y = np.random.random(
(1, 50, 2)).astype("float64") + 1J * np.random.random(
(1, 50, 2)).astype("float64")
self.compare_by_basic_api(x, y)
self.compare_op_by_basic_api(x, y)
class TestComplexMatMulLayerGEMM(unittest.TestCase):
def setUp(self):
self._places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
self._places.append(fluid.CUDAPlace(0))
def compare_by_basic_api(self, x, y): self.compare_by_complex_api(x, y, np_result)
np_result = np.matmul(x, y) self.compare_op_by_complex_api(x, y, np_result)
for place in self._places:
with dg.guard(place):
x_var = fluid.core.VarBase(
value=x,
place=place,
persistable=False,
zero_copy=None,
name='')
y_var = fluid.core.VarBase(
value=y,
place=place,
persistable=False,
zero_copy=None,
name='')
result = paddle.matmul(x_var, y_var)
self.assertTrue(np.allclose(result.numpy(), np_result))
def compare_op_by_basic_api(self, x, y): self.compare_by_basic_api(x, y, np_result)
np_result = np.matmul(x, y) self.compare_op_by_basic_api(x, y, np_result)
for place in self._places:
with dg.guard(place): def test_complex_x_real_y(self):
x_var = fluid.core.VarBase( for dtype in self._dtypes:
value=x, x = np.random.random(
place=place, (2, 3, 4, 5)).astype(dtype) + 1J * np.random.random(
persistable=False, (2, 3, 4, 5)).astype(dtype)
zero_copy=None, y = np.random.random((2, 3, 5, 4)).astype(dtype)
name='')
y_var = fluid.core.VarBase( np_result = np.matmul(x, y)
value=y,
place=place, self.compare_by_complex_api(x, y, np_result)
persistable=False, self.compare_op_by_complex_api(x, y, np_result)
zero_copy=None,
name='') # float -> complex type promotion
result = x_var.matmul(y_var) self.compare_by_basic_api(x, y, np_result)
self.assertTrue(np.allclose(result.numpy(), np_result)) self.compare_op_by_basic_api(x, y, np_result)
def test_complex_xy_gemm_64(self): def test_real_x_complex_y(self):
x = np.random.random( for dtype in self._dtypes:
(1, 2, 50)).astype("float32") + 1J * np.random.random( x = np.random.random((2, 3, 4, 5)).astype(dtype)
(1, 2, 50)).astype("float32") y = np.random.random(
y = np.random.random( (2, 3, 5, 4)).astype(dtype) + 1J * np.random.random(
(1, 50, 2)).astype("float32") + 1J * np.random.random( (2, 3, 5, 4)).astype(dtype)
(1, 50, 2)).astype("float32")
self.compare_by_basic_api(x, y) np_result = np.matmul(x, y)
self.compare_op_by_basic_api(x, y)
self.compare_by_complex_api(x, y, np_result)
# float -> complex type promotion
self.compare_by_basic_api(x, y, np_result)
self.compare_op_by_basic_api(x, y, np_result)
# for coverage
def test_complex_xy_gemv(self):
for dtype in self._dtypes:
x = np.random.random(
(2, 1, 100)).astype(dtype) + 1J * np.random.random(
(2, 1, 100)).astype(dtype)
y = np.random.random((100)).astype(dtype) + 1J * np.random.random(
(100)).astype(dtype)
np_result = np.matmul(x, y)
self.compare_by_basic_api(x, y, np_result)
self.compare_op_by_basic_api(x, y, np_result)
# for coverage
def test_complex_xy_gemm(self):
for dtype in self._dtypes:
x = np.random.random(
(1, 2, 50)).astype(dtype) + 1J * np.random.random(
(1, 2, 50)).astype(dtype)
y = np.random.random(
(1, 50, 2)).astype(dtype) + 1J * np.random.random(
(1, 50, 2)).astype(dtype)
np_result = np.matmul(x, y)
self.compare_by_basic_api(x, y, np_result)
self.compare_op_by_basic_api(x, y, np_result)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -262,6 +262,15 @@ class TestMathOpPatchesVarBase(unittest.TestCase): ...@@ -262,6 +262,15 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
res = a + b res = a + b
self.assertTrue(np.array_equal(res.numpy(), a_np + b_np)) self.assertTrue(np.array_equal(res.numpy(), a_np + b_np))
def test_floordiv_different_dtype(self):
a_np = np.full(self.shape, 10, np.int64)
b_np = np.full(self.shape, 2, np.int32)
with fluid.dygraph.guard():
a = paddle.to_tensor(a_np)
b = paddle.to_tensor(b_np)
res = a // b
self.assertTrue(np.array_equal(res.numpy(), a_np // b_np))
def test_astype(self): def test_astype(self):
a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
......
...@@ -127,41 +127,41 @@ class TestMultiplyError(unittest.TestCase): ...@@ -127,41 +127,41 @@ class TestMultiplyError(unittest.TestCase):
y = paddle.to_tensor(y_data) y = paddle.to_tensor(y_data)
self.assertRaises(ValueError, paddle.multiply, x, y) self.assertRaises(ValueError, paddle.multiply, x, y)
# test dynamic computation graph: dtype must be same # test dynamic computation graph: dtype must be same
x_data = np.random.randn(200).astype(np.int64) x_data = np.random.randn(200).astype(np.int64)
y_data = np.random.randn(200).astype(np.float64) y_data = np.random.randn(200).astype(np.float64)
x = paddle.to_tensor(x_data) x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data) y = paddle.to_tensor(y_data)
self.assertRaises(TypeError, paddle.multiply, x, y) self.assertRaises(ValueError, paddle.multiply, x, y)
# test dynamic computation graph: dtype must be Tensor type # test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.int64) x_data = np.random.randn(200).astype(np.int64)
y_data = np.random.randn(200).astype(np.float64) y_data = np.random.randn(200).astype(np.float64)
y = paddle.to_tensor(y_data) y = paddle.to_tensor(y_data)
self.assertRaises(TypeError, paddle.multiply, x_data, y) self.assertRaises(ValueError, paddle.multiply, x_data, y)
# test dynamic computation graph: dtype must be Tensor type # test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.int64) x_data = np.random.randn(200).astype(np.int64)
y_data = np.random.randn(200).astype(np.float64) y_data = np.random.randn(200).astype(np.float64)
x = paddle.to_tensor(x_data) x = paddle.to_tensor(x_data)
self.assertRaises(TypeError, paddle.multiply, x, y_data) self.assertRaises(ValueError, paddle.multiply, x, y_data)
# test dynamic computation graph: dtype must be Tensor type # test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float32) x_data = np.random.randn(200).astype(np.float32)
y_data = np.random.randn(200).astype(np.float32) y_data = np.random.randn(200).astype(np.float32)
x = paddle.to_tensor(x_data) x = paddle.to_tensor(x_data)
self.assertRaises(TypeError, paddle.multiply, x, y_data) self.assertRaises(ValueError, paddle.multiply, x, y_data)
# test dynamic computation graph: dtype must be Tensor type # test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float32) x_data = np.random.randn(200).astype(np.float32)
y_data = np.random.randn(200).astype(np.float32) y_data = np.random.randn(200).astype(np.float32)
x = paddle.to_tensor(x_data) x = paddle.to_tensor(x_data)
self.assertRaises(TypeError, paddle.multiply, x_data, y) self.assertRaises(ValueError, paddle.multiply, x_data, y)
# test dynamic computation graph: dtype must be Tensor type # test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float32) x_data = np.random.randn(200).astype(np.float32)
y_data = np.random.randn(200).astype(np.float32) y_data = np.random.randn(200).astype(np.float32)
self.assertRaises(TypeError, paddle.multiply, x_data, y_data) self.assertRaises(ValueError, paddle.multiply, x_data, y_data)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -504,19 +504,15 @@ def multiply(x, y, name=None): ...@@ -504,19 +504,15 @@ def multiply(x, y, name=None):
act = None act = None
axis = -1 axis = -1
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
if x.dtype != y.dtype: if x.dtype != y.dtype:
raise TypeError( raise TypeError(
'Input tensors must be same type, but received type of x: %s, type of y: %s ' 'Input tensors must be same type, but received type of x: %s, type of y: %s '
% (x.dtype, y.dtype)) % (x.dtype, y.dtype))
if in_dygraph_mode():
if not isinstance(x, (paddle.Tensor)):
raise TypeError(
'Input x must tensor type, but received type of x: %s'
% (x.dtype))
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals())) return _elementwise_op(LayerHelper(op_type, **locals()))
def maximum(x, y, name=None): def maximum(x, y, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册