From 87852616aaf2517567a68d6b7dd5a61ab3857380 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 18 Mar 2021 16:22:08 +0800 Subject: [PATCH] [CustomOp] Support complex dtype in custom op (#31657) * support custom complex op * fix detail error * add inference support * fix setup windows failed --- cmake/inference_lib.cmake | 6 + paddle/fluid/extension/include/ext_dispatch.h | 65 +++++++++ paddle/fluid/extension/include/ext_dtype.h | 31 ++-- paddle/fluid/extension/src/ext_tensor.cc | 34 +++++ paddle/fluid/framework/CMakeLists.txt | 7 +- paddle/fluid/framework/custom_operator.cc | 35 ++++- paddle/fluid/framework/custom_tensor_test.cc | 22 +++ paddle/fluid/framework/custom_tensor_utils.h | 8 ++ paddle/fluid/inference/CMakeLists.txt | 4 + paddle/fluid/pybind/CMakeLists.txt | 4 + .../fluid/tests/custom_op/CMakeLists.txt | 3 + .../fluid/tests/custom_op/custom_conj_op.cc | 94 ++++++++++++ .../fluid/tests/custom_op/dispatch_test_op.cc | 56 ++++++++ .../fluid/tests/custom_op/test_custom_conj.py | 136 ++++++++++++++++++ .../tests/custom_op/test_dispatch_jit.py | 20 +++ python/setup.py.in | 20 ++- 16 files changed, 530 insertions(+), 15 deletions(-) create mode 100644 python/paddle/fluid/tests/custom_op/custom_conj_op.cc create mode 100644 python/paddle/fluid/tests/custom_op/test_custom_conj.py diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 2cba3d0693..570b37ff11 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -192,6 +192,12 @@ include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io) copy(inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/* DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/) +copy(inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex64.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/) +copy(inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/) # CAPI inference library for only inference set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING diff --git a/paddle/fluid/extension/include/ext_dispatch.h b/paddle/fluid/extension/include/ext_dispatch.h index eed7360464..7b3893e283 100644 --- a/paddle/fluid/extension/include/ext_dispatch.h +++ b/paddle/fluid/extension/include/ext_dispatch.h @@ -68,6 +68,22 @@ namespace paddle { } \ }() +///////// Complex Dispatch Marco /////////// + +#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + ///////// Floating and Integral Dispatch Marco /////////// #define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \ @@ -93,6 +109,55 @@ namespace paddle { } \ }() +///////// Floating and Complex Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) 
\ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + +///////// Floating, Integral and Complex Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + // TODO(chenweihang): Add more Marcos in the future if needed } // namespace paddle diff --git a/paddle/fluid/extension/include/ext_dtype.h b/paddle/fluid/extension/include/ext_dtype.h index 46c4bac236..a1e58fbacd 100644 --- a/paddle/fluid/extension/include/ext_dtype.h +++ b/paddle/fluid/extension/include/ext_dtype.h @@ -16,10 +16,15 @@ limitations under the License. */ #include #include +#include "complex128.h" // NOLINT +#include "complex64.h" // NOLINT #include "ext_exception.h" // NOLINT namespace paddle { +using complex64 = paddle::platform::complex64; +using complex128 = paddle::platform::complex128; + enum class DataType { BOOL, INT8, @@ -29,6 +34,8 @@ enum class DataType { INT64, FLOAT32, FLOAT64, + COMPLEX64, + COMPLEX128, // TODO(JiabinYang) support more data types if needed. 
}; @@ -50,20 +57,26 @@ inline std::string ToString(DataType dtype) { return "float"; case DataType::FLOAT64: return "double"; + case DataType::COMPLEX64: + return "complex64"; + case DataType::COMPLEX128: + return "complex128"; default: PD_THROW("Unsupported paddle enum data type."); } } -#define PD_FOR_EACH_DATA_TYPE(_) \ - _(bool, DataType::BOOL) \ - _(int8_t, DataType::INT8) \ - _(uint8_t, DataType::UINT8) \ - _(int16_t, DataType::INT16) \ - _(int, DataType::INT32) \ - _(int64_t, DataType::INT64) \ - _(float, DataType::FLOAT32) \ - _(double, DataType::FLOAT64) +#define PD_FOR_EACH_DATA_TYPE(_) \ + _(bool, DataType::BOOL) \ + _(int8_t, DataType::INT8) \ + _(uint8_t, DataType::UINT8) \ + _(int16_t, DataType::INT16) \ + _(int, DataType::INT32) \ + _(int64_t, DataType::INT64) \ + _(float, DataType::FLOAT32) \ + _(double, DataType::FLOAT64) \ + _(complex64, DataType::COMPLEX64) \ + _(complex128, DataType::COMPLEX128) template struct DataTypeToCPPType; diff --git a/paddle/fluid/extension/src/ext_tensor.cc b/paddle/fluid/extension/src/ext_tensor.cc index 4434a3bf59..cb37bf180c 100644 --- a/paddle/fluid/extension/src/ext_tensor.cc +++ b/paddle/fluid/extension/src/ext_tensor.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/extension/include/ext_tensor.h" + #include + #include "paddle/fluid/framework/custom_tensor_utils.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/transform.h" @@ -162,6 +166,10 @@ DataType Tensor::type() const { return DataType::FLOAT64; } else if (type == framework::proto::VarType::BOOL) { return DataType::BOOL; + } else if (type == framework::proto::VarType::COMPLEX64) { + return DataType::COMPLEX64; + } else if (type == framework::proto::VarType::COMPLEX128) { + return DataType::COMPLEX128; } // TODO(JiabinYang) Support more dtype here return DataType::FLOAT32; @@ -217,6 +225,10 @@ template PD_DLL_DECL Tensor Tensor::copy_to(const PlaceType &target_place) const; template PD_DLL_DECL Tensor Tensor::copy_to(const PlaceType &target_place) const; +template PD_DLL_DECL Tensor Tensor::copy_to( + const PlaceType &target_place) const; +template PD_DLL_DECL Tensor Tensor::copy_to( + const PlaceType &target_place) const; template PD_DLL_DECL float *Tensor::data() const; template PD_DLL_DECL double *Tensor::data() const; @@ -226,6 +238,10 @@ template PD_DLL_DECL uint8_t *Tensor::data() const; template PD_DLL_DECL int8_t *Tensor::data() const; template PD_DLL_DECL int16_t *Tensor::data() const; template PD_DLL_DECL bool *Tensor::data() const; +template PD_DLL_DECL paddle::platform::complex64 * +Tensor::data() const; +template PD_DLL_DECL paddle::platform::complex128 * +Tensor::data() const; template PD_DLL_DECL float *Tensor::mutable_data(); template PD_DLL_DECL double *Tensor::mutable_data(); @@ -235,6 +251,10 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data(); template PD_DLL_DECL int8_t *Tensor::mutable_data(); template PD_DLL_DECL int16_t *Tensor::mutable_data(); template PD_DLL_DECL bool *Tensor::mutable_data(); +template PD_DLL_DECL paddle::platform::complex64 * +Tensor::mutable_data(); +template PD_DLL_DECL paddle::platform::complex128 * +Tensor::mutable_data(); template PD_DLL_DECL float *Tensor::mutable_data(const PlaceType &place); template PD_DLL_DECL double 
*Tensor::mutable_data( @@ -250,6 +270,10 @@ template PD_DLL_DECL int8_t *Tensor::mutable_data( template PD_DLL_DECL int16_t *Tensor::mutable_data( const PlaceType &place); template PD_DLL_DECL bool *Tensor::mutable_data(const PlaceType &place); +template PD_DLL_DECL paddle::platform::complex64 * +Tensor::mutable_data(const PlaceType &place); +template PD_DLL_DECL paddle::platform::complex128 * +Tensor::mutable_data(const PlaceType &place); std::vector Tensor::shape() const { GET_CASTED_TENSOR @@ -310,6 +334,16 @@ Tensor Tensor::cast(const DataType &target_type) const { framework::VisitDataType( dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); break; + case framework::proto::VarType::COMPLEX64: + framework::VisitDataType( + dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::COMPLEX128: + framework::VisitDataType(dst_type, + CastDataType( + *tensor, rlt_tensor_, ctx)); + break; // TODO(JiabinYang) Support more dtype here default: PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 43bbc06787..1fa4ce9b57 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -346,13 +346,16 @@ message(STATUS "branch: ${PADDLE_BRANCH}") configure_file(commit.h.in commit.h) +# Adapt to custom op mechanism: Include the header files related to the data type +# to avoid exposing the path of the underlying file +include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../extension/include) + cc_library(custom_tensor SRCS ../extension/src/ext_tensor.cc DEPS lod_tensor memory enforce) cc_library(op_meta_info SRCS ../extension/src/ext_op_meta_info.cc DEPS custom_tensor) cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info) cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../extension/include) - set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator) cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES}) diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 0baacd4621..69a9be603e 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -757,10 +757,39 @@ void RegisterOperatorWithMetaInfo( return new CustomOperator(type, inputs, outputs, attrs); }; - // Grad InferShape (gradient's shape is same with forward input default) - grad_info.infer_shape_ = [grad_op_outputs](InferShapeContext* ctx) { + // Grad InferShape + grad_info.infer_shape_ = [grad_op_inputs, + grad_op_outputs](InferShapeContext* ctx) { + // 1. if forward input exists, gradient's shape is same with forward input + // default + // [Suitable for most situations] + // 2. 
if the forward input does not exist, and the grad op contains only one grad input and + // output, + // use the grad input shape as the grad output shape + // [Suitable for the situation where the forward input is not used as + // a backward input] + // TODO(chenweihang): support setting the grad op's infershape func if needed for (auto& out_name : grad_op_outputs) { - ctx->ShareDim(detail::NoGrad(out_name), out_name); + auto fwd_name = detail::NoGrad(out_name); + if (detail::IsDuplicableVar(fwd_name)) { + // A duplicable forward var must be used as a backward input + ctx->ShareDim(fwd_name, out_name); + } else { + if (ctx->HasInput(fwd_name)) { + ctx->ShareDim(fwd_name, out_name); + } else { + PADDLE_ENFORCE_EQ( + grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL, + true, + platform::errors::Unavailable( + "Custom grad operator infershape error. " + "If a custom grad operator contains only one input and " + "only one output, the input shape will be directly set to " + "the output shape. Otherwise, please set the forward input " + "as the grad operator's input.")); + ctx->ShareDim(grad_op_inputs[0], out_name); + } + } } }; diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc index 2e42248f64..7da5658860 100644 --- a/paddle/fluid/framework/custom_tensor_test.cc +++ b/paddle/fluid/framework/custom_tensor_test.cc @@ -109,6 +109,10 @@ void GroupTestCopy() { TestCopyTensor(); VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu"; TestCopyTensor(); + VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); } void GroupTestCast() { @@ -126,6 +130,10 @@ TestCast(paddle::DataType::FLOAT32); VLOG(2) << "float cast"; TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "complex64 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "complex128 cast"; + TestCast(paddle::DataType::FLOAT32); } void GroupTestDtype() { @@ -136,6 +144,8 @@ CHECK(TestDtype() == paddle::DataType::INT16); CHECK(TestDtype() == paddle::DataType::INT8); CHECK(TestDtype() == paddle::DataType::UINT8); + CHECK(TestDtype() == paddle::DataType::COMPLEX64); + CHECK(TestDtype() == paddle::DataType::COMPLEX128); } void GroupTestDtypeConvert() { @@ -162,6 +172,12 @@ paddle::framework::proto::VarType::INT16); CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::COMPLEX64) == + paddle::framework::proto::VarType::COMPLEX64); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::COMPLEX128) == + paddle::framework::proto::VarType::COMPLEX128); // proto -> enum CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( paddle::framework::proto::VarType::FP64) == @@ -185,6 +201,12 @@ paddle::DataType::INT16); CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( paddle::framework::proto::VarType::BOOL) == paddle::DataType::BOOL); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::COMPLEX64) == + paddle::DataType::COMPLEX64); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::COMPLEX128) == + paddle::DataType::COMPLEX128); } TEST(CustomTensor, copyTest) { diff --git
a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h index 919a3a1a49..a252d6aef4 100644 --- a/paddle/fluid/framework/custom_tensor_utils.h +++ b/paddle/fluid/framework/custom_tensor_utils.h @@ -56,6 +56,10 @@ class CustomTensorUtils { return framework::proto::VarType::INT64; case paddle::DataType::INT16: return framework::proto::VarType::INT16; + case paddle::DataType::COMPLEX64: + return framework::proto::VarType::COMPLEX64; + case paddle::DataType::COMPLEX128: + return framework::proto::VarType::COMPLEX128; case paddle::DataType::BOOL: return framework::proto::VarType::BOOL; default: @@ -83,6 +87,10 @@ class CustomTensorUtils { return paddle::DataType::UINT8; case framework::proto::VarType::INT16: return paddle::DataType::INT16; + case framework::proto::VarType::COMPLEX64: + return paddle::DataType::COMPLEX64; + case framework::proto::VarType::COMPLEX128: + return paddle::DataType::COMPLEX128; case framework::proto::VarType::BOOL: return paddle::DataType::BOOL; default: diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 7a8bfc1a8c..93fd85f13c 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -36,6 +36,10 @@ endif() # fluid_modules exclude API-interface of inference/api and inference/capi get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) +# Adapt to custom op mechanism: Include the header files related to the data type +# to avoid exposing the path of the underlying file +include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform) + add_subdirectory(api) # Create static inference library if needed diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 7a63217d67..5452b2160a 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,3 +1,7 @@ +# Adapt to custom op mechanism: Include the header files related to the data type +# to avoid exposing the path of the underlying file +include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform) + set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index 620bff11a2..4ba537930c 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -26,6 +26,9 @@ set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120) py_test(test_custom_concat SRCS test_custom_concat.py) set_tests_properties(test_custom_concat PROPERTIES TIMEOUT 120) +py_test(test_custom_conj SRCS test_custom_conj.py) +set_tests_properties(test_custom_conj PROPERTIES TIMEOUT 120) + py_test(test_check_abi SRCS test_check_abi.py) cc_test(test_check_error SRCS test_check_error.cc DEPS gtest) diff --git a/python/paddle/fluid/tests/custom_op/custom_conj_op.cc b/python/paddle/fluid/tests/custom_op/custom_conj_op.cc new file mode 100644 index 0000000000..4feb887ca0 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/custom_conj_op.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <type_traits> +#include <vector> + +#include "paddle/extension.h" + +#define CHECK_INPUT(x) \ + PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") + +template <typename data_t> +using EnableComplex = typename std::enable_if< + std::is_same<data_t, paddle::complex64>::value || + std::is_same<data_t, paddle::complex128>::value>::type; + +template <typename data_t> +using DisableComplex = typename std::enable_if< + !std::is_same<data_t, paddle::complex64>::value && + !std::is_same<data_t, paddle::complex128>::value>::type; + +template <typename data_t, typename Enable = void> +struct ConjFunctor; + +template <typename data_t> +struct ConjFunctor<data_t, EnableComplex<data_t>> { + ConjFunctor(const data_t* input, int64_t numel, data_t* output) + : input_(input), numel_(numel), output_(output) {} + + void operator()(size_t idx) const { + output_[idx] = data_t(input_[idx].real, -input_[idx].imag); + } + + const data_t* input_; + int64_t numel_; + data_t* output_; +}; + +template <typename data_t> +struct ConjFunctor<data_t, DisableComplex<data_t>> { + ConjFunctor(const data_t* input, int64_t numel, data_t* output) + : input_(input), numel_(numel), output_(output) {} + + void operator()(size_t idx) const { output_[idx] = input_[idx]; } + + const data_t* input_; + int64_t numel_; + data_t* output_; +}; + +template <typename data_t> +void ConjCPUKernel(const data_t* x_data, int64_t numel, data_t* out_data) { + ConjFunctor<data_t> conj(x_data, numel, out_data); + for (int64_t i = 0; i < numel; ++i) { + conj(i); + } +} + +std::vector<paddle::Tensor> ConjFunction(const paddle::Tensor& x) { + CHECK_INPUT(x); + + paddle::Tensor out(x.place()); + out.reshape(x.shape()); + + PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + x.type(), "ConjCPUKernel", ([&] { + ConjCPUKernel<data_t>( + x.data<data_t>(), x.size(), out.mutable_data<data_t>()); + })); + + return {out}; +} + +PD_BUILD_OP(custom_conj) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ConjFunction)); + +PD_BUILD_GRAD_OP(custom_conj) + .Inputs({paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ConjFunction)); diff --git a/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc b/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc index 33ca6ee86f..fbf5442ac0 100644 --- a/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc +++ b/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc @@ -62,3 +62,59 @@ PD_BUILD_OP(dispatch_test_float_and_integer) .Inputs({"X"}) .Outputs({"Out"}) .SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger)); + +std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_COMPLEX_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel<data_t>( + x.data<data_t>(), out.mutable_data<data_t>(), x.size()); + })); + + return {out}; +} + +PD_BUILD_OP(dispatch_test_complex) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(DispatchTestComplex)); + +std::vector<paddle::Tensor> DispatchTestFloatAndComplex( + const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel<data_t>( + x.data<data_t>(), out.mutable_data<data_t>(), x.size()); + })); + + return {out}; +} + +PD_BUILD_OP(dispatch_test_float_and_complex) + .Inputs({"X"}) + .Outputs({"Out"}) +
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex)); + +std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex( + const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel<data_t>( + x.data<data_t>(), out.mutable_data<data_t>(), x.size()); + })); + + return {out}; +} + +PD_BUILD_OP(dispatch_test_float_and_integer_and_complex) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex)); diff --git a/python/paddle/fluid/tests/custom_op/test_custom_conj.py b/python/paddle/fluid/tests/custom_op/test_custom_conj.py new file mode 100644 index 0000000000..3a8f79a06f --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_custom_conj.py @@ -0,0 +1,136 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np + +import paddle +import paddle.static as static
from paddle.utils.cpp_extension import load, get_build_directory +from paddle.utils.cpp_extension.extension_utils import run_cmd +from utils import paddle_includes, extra_cc_args, extra_nvcc_args + +# Because Windows doesn't use Docker, the shared lib already exists in the +# cache dir; it will not be compiled again unless the shared lib is removed.
+file = '{}\\custom_relu_module_jit\\custom_relu_module_jit.pyd'.format( + get_build_directory()) +if os.name == 'nt' and os.path.isfile(file): + cmd = 'del {}'.format(file) + run_cmd(cmd, True) + +custom_ops = load( + name='custom_conj_jit', + sources=['custom_conj_op.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True) + + +def is_complex(dtype): + return dtype == paddle.fluid.core.VarDesc.VarType.COMPLEX64 or \ + dtype == paddle.fluid.core.VarDesc.VarType.COMPLEX128 + + +def to_complex(dtype): + if dtype == "float32": + return np.complex64 + elif dtype == "float64": + return np.complex128 + else: + return dtype + + +def conj_dynamic(func, dtype, np_input): + paddle.set_device("cpu") + x = paddle.to_tensor(np_input) + out = func(x) + out.stop_gradient = False + sum_out = paddle.sum(out) + if is_complex(sum_out.dtype): + sum_out.real().backward() + else: + sum_out.backward() + return out.numpy(), x.grad + + +def conj_static(func, shape, dtype, np_input): + paddle.enable_static() + paddle.set_device("cpu") + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + x = static.data(name="x", shape=shape, dtype=dtype) + x.stop_gradient = False + out = func(x) + sum_out = paddle.sum(out) + static.append_backward(sum_out) + + exe = static.Executor() + exe.run(static.default_startup_program()) + + out_v, x_grad_v = exe.run(static.default_main_program(), + feed={"x": np_input}, + fetch_list=[out.name, x.name + "@GRAD"]) + paddle.disable_static() + return out_v, x_grad_v + + +class TestCustomConjJit(unittest.TestCase): + def setUp(self): + self.dtypes = ['float32', 'float64'] + self.shape = [2, 20, 2, 3] + + def check_output(self, out, pd_out, name): + self.assertTrue( + np.array_equal(out, pd_out), + "custom op {}: {},\n paddle api {}: {}".format(name, out, name, + pd_out)) + + def run_dynamic(self, dtype, np_input): + out, x_grad = conj_dynamic(custom_ops.custom_conj, dtype, np_input) + pd_out, pd_x_grad = conj_dynamic(paddle.conj, dtype, np_input) + + self.check_output(out, pd_out, "out") + self.check_output(x_grad, pd_x_grad, "x's grad") + + def run_static(self, dtype, np_input): + out, x_grad = conj_static(custom_ops.custom_conj, self.shape, dtype, + np_input) + pd_out, pd_x_grad = conj_static(paddle.conj, self.shape, dtype, + np_input) + + self.check_output(out, pd_out, "out") + self.check_output(x_grad, pd_x_grad, "x's grad") + + def test_dynamic(self): + for dtype in self.dtypes: + np_input = np.random.random(self.shape).astype(dtype) + self.run_dynamic(dtype, np_input) + + def test_static(self): + for dtype in self.dtypes: + np_input = np.random.random(self.shape).astype(dtype) + self.run_static(dtype, np_input) + + # complex only used in dynamic mode now + def test_complex_dynamic(self): + for dtype in self.dtypes: + np_input = np.random.random(self.shape).astype( + dtype) + 1j * np.random.random(self.shape).astype(dtype) + self.run_dynamic(to_complex(dtype), np_input) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py index 6cdbc61620..bc36372c6a 100644 --- a/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py @@ -55,6 +55,11 @@ class TestJitDispatch(unittest.TestCase): for dtype in dtypes: 
self.run_dispatch_test(dispatch_op.dispatch_test_integer, dtype) + def test_dispatch_complex(self): + dtypes = ["complex64", "complex128"] + for dtype in dtypes: + self.run_dispatch_test(dispatch_op.dispatch_test_complex, dtype) + def test_dispatch_float_and_integer(self): dtypes = [ "float32", "float64", "int32", "int64", "int8", "uint8", "int16" @@ -63,6 +68,21 @@ class TestJitDispatch(unittest.TestCase): self.run_dispatch_test(dispatch_op.dispatch_test_float_and_integer, dtype) + def test_dispatch_float_and_complex(self): + dtypes = ["float32", "float64", "complex64", "complex128"] + for dtype in dtypes: + self.run_dispatch_test(dispatch_op.dispatch_test_float_and_complex, + dtype) + + def test_dispatch_float_and_integer_and_complex(self): + dtypes = [ + "float32", "float64", "int32", "int64", "int8", "uint8", "int16", + "complex64", "complex128" + ] + for dtype in dtypes: + self.run_dispatch_test( + dispatch_op.dispatch_test_float_and_integer_and_complex, dtype) + if __name__ == '__main__': unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 0e214c5c65..0afc3956a0 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -451,12 +451,30 @@ class InstallHeaders(Command): ('install_headers', 'install_dir'), ('force', 'force')) + def copy_data_type_headers(self, header): + if os.name == 'nt': + data_type_headers = ['platform\\complex64.h', 'platform\\complex128.h'] + else: + data_type_headers = ['platform/complex64.h', 'platform/complex128.h'] + for dtype_header in data_type_headers: + if dtype_header in header: + if os.name == 'nt': + install_dir = os.path.join(self.install_dir, "paddle\\fluid\\extension\\include") + else: + install_dir = os.path.join(self.install_dir, "paddle/fluid/extension/include") + if not os.path.exists(install_dir): + self.mkpath(install_dir) + return self.copy_file(header, install_dir) + def mkdir_and_copy_file(self, header): if 'pb.h' in header: install_dir = re.sub('${PADDLE_BINARY_DIR}/', '', header) elif 'third_party' not in header: - # framework + # paddle headers install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header) + # For Paddle data type headers, we also need to copy them to `extension/include`, + # which is used by the new custom operator + self.copy_data_type_headers(header) else: # third_party install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header) -- GitLab
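For reference, a minimal usage sketch of the new complex dispatch macro from a user's point of view follows. It is illustrative only and not part of the patch: the op name (custom_negate_imag) and kernel name are hypothetical, while PD_DISPATCH_COMPLEX_TYPES, the paddle::complex64/paddle::complex128 aliases, and the Tensor API calls are the ones introduced or exercised above.

// Hypothetical example, not part of this patch: a CPU custom op that negates
// the imaginary part, dispatched over the complex dtypes added by this change.
#include <vector>

#include "paddle/extension.h"

template <typename data_t>
void negate_imag_cpu_kernel(const data_t* x, int64_t numel, data_t* out) {
  for (int64_t i = 0; i < numel; ++i) {
    // complex64/complex128 expose public `real`/`imag` members,
    // as used by custom_conj_op.cc above.
    out[i] = data_t(x[i].real, -x[i].imag);
  }
}

std::vector<paddle::Tensor> NegateImag(const paddle::Tensor& x) {
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
  paddle::Tensor out(x.place());
  out.reshape(x.shape());

  // The macro switches on the runtime dtype and runs the lambda with `data_t`
  // bound to paddle::complex64 or paddle::complex128; any other dtype throws
  // via PD_THROW, as defined in ext_dispatch.h above.
  PD_DISPATCH_COMPLEX_TYPES(
      x.type(), "negate_imag_cpu_kernel", ([&] {
        negate_imag_cpu_kernel<data_t>(
            x.data<data_t>(), x.size(), out.mutable_data<data_t>());
      }));
  return {out};
}

PD_BUILD_OP(custom_negate_imag)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(NegateImag));

Built with paddle.utils.cpp_extension.load (as in test_custom_conj.py above), such an op would accept complex64/complex128 inputs in dynamic graph mode.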