未验证 提交 87852616 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp] Support complex dtype in custom op (#31657)

* support custom complex op

* fix detail error

* add inference support

* fix setup windows failed
上级 fe241fd0
......@@ -192,6 +192,12 @@ include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/*
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
# Ship the complex64/complex128 headers next to the extension headers in the
# inference install tree (same destination directory as the copy() above).
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex64.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
# CAPI inference library for only inference
set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING
......
......@@ -68,6 +68,22 @@ namespace paddle {
} \
}()
///////// Complex Dispatch Macro ///////////

// Expands to an immediately-invoked `[&]` lambda that switches on TYPE and
// hands __VA_ARGS__ to PD_PRIVATE_CASE_TYPE for these dtypes:
//   ::paddle::DataType::COMPLEX64  -> ::paddle::complex64
//   ::paddle::DataType::COMPLEX128 -> ::paddle::complex128
// Any other dtype reaches `default:` and raises through PD_THROW with a
// message built from the stringified NAME and ::paddle::ToString(dtype).
// NOTE(review): PD_PRIVATE_CASE_TYPE is defined outside this excerpt;
// presumably it exposes the mapped C++ type to the callable -- confirm there.
#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating and Integral Dispatch Macro ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
......@@ -93,6 +109,55 @@ namespace paddle {
} \
}()
///////// Floating and Complex Dispatch Macro ///////////

// Expands to an immediately-invoked `[&]` lambda that switches on TYPE and
// hands __VA_ARGS__ to PD_PRIVATE_CASE_TYPE for these dtypes:
//   ::paddle::DataType::FLOAT32    -> float
//   ::paddle::DataType::FLOAT64    -> double
//   ::paddle::DataType::COMPLEX64  -> ::paddle::complex64
//   ::paddle::DataType::COMPLEX128 -> ::paddle::complex128
// Any other dtype reaches `default:` and raises through PD_THROW with a
// message built from the stringified NAME and ::paddle::ToString(dtype).
// NOTE(review): PD_PRIVATE_CASE_TYPE is defined outside this excerpt;
// presumably it exposes the mapped C++ type to the callable -- confirm there.
#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating, Integral and Complex Dispatch Macro ///////////

// Expands to an immediately-invoked `[&]` lambda that switches on TYPE and
// hands __VA_ARGS__ to PD_PRIVATE_CASE_TYPE for these dtypes:
//   FLOAT32 -> float        FLOAT64 -> double
//   INT32   -> int          INT64   -> int64_t
//   INT8    -> int8_t       UINT8   -> uint8_t
//   INT16   -> int16_t
//   COMPLEX64 -> ::paddle::complex64   COMPLEX128 -> ::paddle::complex128
// DataType::BOOL is not listed, so it falls into `default:` together with
// any other dtype and raises through PD_THROW with a message built from the
// stringified NAME and ::paddle::ToString(dtype).
// NOTE(review): PD_PRIVATE_CASE_TYPE is defined outside this excerpt;
// presumably it exposes the mapped C++ type to the callable -- confirm there.
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
// TODO(chenweihang): Add more macros in the future if needed
} // namespace paddle
......@@ -16,10 +16,15 @@ limitations under the License. */
#include <cstdint>
#include <string>
#include "complex128.h" // NOLINT
#include "complex64.h" // NOLINT
#include "ext_exception.h" // NOLINT
namespace paddle {
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
enum class DataType {
BOOL,
INT8,
......@@ -29,6 +34,8 @@ enum class DataType {
INT64,
FLOAT32,
FLOAT64,
COMPLEX64,
COMPLEX128,
// TODO(JiabinYang) support more data types if needed.
};
......@@ -50,20 +57,26 @@ inline std::string ToString(DataType dtype) {
return "float";
case DataType::FLOAT64:
return "double";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
default:
PD_THROW("Unsupported paddle enum data type.");
}
}
#define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(int, DataType::INT32) \
_(int64_t, DataType::INT64) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64)
// X-macro over the supported (C++ type, DataType enumerator) pairs: the
// caller passes a two-argument macro as `_`, and it is expanded once per
// pair, in the order listed below (ending with the two complex types).
#define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(int, DataType::INT32) \
_(int64_t, DataType::INT64) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64) \
_(complex64, DataType::COMPLEX64) \
_(complex128, DataType::COMPLEX128)
template <paddle::DataType T>
struct DataTypeToCPPType;
......
......@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/extension/include/ext_tensor.h"
#include <utility>
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/transform.h"
......@@ -162,6 +166,10 @@ DataType Tensor::type() const {
return DataType::FLOAT64;
} else if (type == framework::proto::VarType::BOOL) {
return DataType::BOOL;
} else if (type == framework::proto::VarType::COMPLEX64) {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
}
// TODO(JiabinYang) Support more dtype here
return DataType::FLOAT32;
......@@ -217,6 +225,10 @@ template PD_DLL_DECL Tensor
Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<bool>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place) const;
template PD_DLL_DECL float *Tensor::data<float>() const;
template PD_DLL_DECL double *Tensor::data<double>() const;
......@@ -226,6 +238,10 @@ template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL bool *Tensor::data<bool>() const;
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template PD_DLL_DECL float *Tensor::mutable_data<float>();
template PD_DLL_DECL double *Tensor::mutable_data<double>();
......@@ -235,6 +251,10 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place);
template PD_DLL_DECL double *Tensor::mutable_data<double>(
......@@ -250,6 +270,10 @@ template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
std::vector<int64_t> Tensor::shape() const {
GET_CASTED_TENSOR
......@@ -310,6 +334,16 @@ Tensor Tensor::cast(const DataType &target_type) const {
framework::VisitDataType(
dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX64:
framework::VisitDataType(
dst_type,
CastDataType<paddle::platform::complex64>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::COMPLEX128:
framework::VisitDataType(dst_type,
CastDataType<paddle::platform::complex128>(
*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang) Support more dtype here
default:
PADDLE_THROW(platform::errors::Unimplemented(
......
......@@ -346,13 +346,16 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file(commit.h.in commit.h)
# Adapt to custom op mechanism: Include the header files related to the data type
# to avoid exposing the path of the underlying file
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../extension/include)
cc_library(custom_tensor SRCS ../extension/src/ext_tensor.cc DEPS lod_tensor memory enforce)
cc_library(op_meta_info SRCS ../extension/src/ext_op_meta_info.cc DEPS custom_tensor)
cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info)
cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../extension/include)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator)
cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})
......
......@@ -757,10 +757,39 @@ void RegisterOperatorWithMetaInfo(
return new CustomOperator(type, inputs, outputs, attrs);
};
// Grad InferShape (gradient's shape is same with forward input default)
grad_info.infer_shape_ = [grad_op_outputs](InferShapeContext* ctx) {
// Grad InferShape
grad_info.infer_shape_ = [grad_op_inputs,
grad_op_outputs](InferShapeContext* ctx) {
// 1. if forward input exists, gradient's shape is same with forward input
// default
// [Suitable for most situations]
// 2. if forward input not exists, and only contains one grad input and
// output,
// use grad input shape as grad output shape
// [Suitable for the situation that forward input is not used as
// backward input]
// TODO(chenweihang): support set grad op infershape func if needed
for (auto& out_name : grad_op_outputs) {
ctx->ShareDim(detail::NoGrad(out_name), out_name);
auto fwd_name = detail::NoGrad(out_name);
if (detail::IsDuplicableVar(fwd_name)) {
// Duplicable forward var must as backward input
ctx->ShareDim(fwd_name, out_name);
} else {
if (ctx->HasInput(fwd_name)) {
ctx->ShareDim(fwd_name, out_name);
} else {
PADDLE_ENFORCE_EQ(
grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL,
true,
platform::errors::Unavailable(
"Custom grad operator infershape error. "
"If a custom grad operator contains only one input and "
"only one output, the input shape will be directly set to "
"the output shape. Otherwise, Please set the forward input "
"as the grad operator's input."));
ctx->ShareDim(grad_op_inputs[0], out_name);
}
}
}
};
......
......@@ -109,6 +109,10 @@ void GroupTestCopy() {
TestCopyTensor<int8_t>();
VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<uint8_t>();
VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::complex64>();
VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
TestCopyTensor<paddle::complex128>();
}
void GroupTestCast() {
......@@ -126,6 +130,10 @@ void GroupTestCast() {
TestCast<uint8_t>(paddle::DataType::FLOAT32);
VLOG(2) << "float cast";
TestCast<float>(paddle::DataType::FLOAT32);
VLOG(2) << "complex64 cast";
TestCast<paddle::complex64>(paddle::DataType::FLOAT32);
VLOG(2) << "complex128 cast";
TestCast<paddle::complex128>(paddle::DataType::FLOAT32);
}
void GroupTestDtype() {
......@@ -136,6 +144,8 @@ void GroupTestDtype() {
CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
CHECK(TestDtype<int8_t>() == paddle::DataType::INT8);
CHECK(TestDtype<uint8_t>() == paddle::DataType::UINT8);
CHECK(TestDtype<paddle::complex64>() == paddle::DataType::COMPLEX64);
CHECK(TestDtype<paddle::complex128>() == paddle::DataType::COMPLEX128);
}
void GroupTestDtypeConvert() {
......@@ -162,6 +172,12 @@ void GroupTestDtypeConvert() {
paddle::framework::proto::VarType::INT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX64) ==
paddle::framework::proto::VarType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
paddle::DataType::COMPLEX128) ==
paddle::framework::proto::VarType::COMPLEX128);
// proto -> enum
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::FP64) ==
......@@ -185,6 +201,12 @@ void GroupTestDtypeConvert() {
paddle::DataType::INT16);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::BOOL) == paddle::DataType::BOOL);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX64) ==
paddle::DataType::COMPLEX64);
CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
paddle::framework::proto::VarType::COMPLEX128) ==
paddle::DataType::COMPLEX128);
}
TEST(CustomTensor, copyTest) {
......
......@@ -56,6 +56,10 @@ class CustomTensorUtils {
return framework::proto::VarType::INT64;
case paddle::DataType::INT16:
return framework::proto::VarType::INT16;
case paddle::DataType::COMPLEX64:
return framework::proto::VarType::COMPLEX64;
case paddle::DataType::COMPLEX128:
return framework::proto::VarType::COMPLEX128;
case paddle::DataType::BOOL:
return framework::proto::VarType::BOOL;
default:
......@@ -83,6 +87,10 @@ class CustomTensorUtils {
return paddle::DataType::UINT8;
case framework::proto::VarType::INT16:
return paddle::DataType::INT16;
case framework::proto::VarType::COMPLEX64:
return paddle::DataType::COMPLEX64;
case framework::proto::VarType::COMPLEX128:
return paddle::DataType::COMPLEX128;
case framework::proto::VarType::BOOL:
return paddle::DataType::BOOL;
default:
......
......@@ -36,6 +36,10 @@ endif()
# fluid_modules exclude API-interface of inference/api and inference/capi
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
# Adapt to custom op mechanism: Include the header files related to the data type
# to avoid exposing the path of the underlying file
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
add_subdirectory(api)
# Create static inference library if needed
......
# Adapt to custom op mechanism: Include the header files related to the data type
# to avoid exposing the path of the underlying file
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
......
......@@ -26,6 +26,9 @@ set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)
py_test(test_custom_concat SRCS test_custom_concat.py)
set_tests_properties(test_custom_concat PROPERTIES TIMEOUT 120)
py_test(test_custom_conj SRCS test_custom_conj.py)
set_tests_properties(test_custom_conj PROPERTIES TIMEOUT 120)
py_test(test_check_abi SRCS test_check_abi.py)
cc_test(test_check_error SRCS test_check_error.cc DEPS gtest)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
// Checks via PD_CHECK that tensor `x` reports paddle::PlaceType::kCPU; the
// failure message is the stringified argument followed by
// " must be a CPU Tensor.".
#define CHECK_INPUT(x) \
PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
// SFINAE helper: names `void` when data_t is paddle::complex64 or
// paddle::complex128, and is ill-formed (substitution failure) otherwise.
template <typename data_t>
using EnableComplex = typename std::enable_if<
    (std::is_same<data_t, paddle::complex128>::value ||
     std::is_same<data_t, paddle::complex64>::value),
    void>::type;
// SFINAE helper, complement of the complex case: names `void` when data_t is
// neither paddle::complex64 nor paddle::complex128, and is ill-formed
// (substitution failure) for those two types.
template <typename data_t>
using DisableComplex = typename std::enable_if<
    !(std::is_same<data_t, paddle::complex64>::value ||
      std::is_same<data_t, paddle::complex128>::value),
    void>::type;
// Primary template, declared only. The second parameter defaults to `void`
// so that the partial specializations below can be selected through the
// EnableComplex / DisableComplex helpers.
template <typename data_t, typename Enable = void>
struct ConjFunctor;
// Specialization for the complex element types: writes the element-wise
// complex conjugate of the input buffer into the output buffer.
template <typename data_t>
struct ConjFunctor<data_t, EnableComplex<data_t>> {
  // Stores the raw pointers and the element count; nothing is copied.
  ConjFunctor(const data_t* src, int64_t count, data_t* dst)
      : input_(src), numel_(count), output_(dst) {}

  // output_[pos] = (real, -imag) of input_[pos]. No bounds check is done.
  void operator()(size_t pos) const {
    const data_t& value = input_[pos];
    output_[pos] = data_t(value.real, -value.imag);
  }

  const data_t* input_;
  int64_t numel_;  // stored but not read by operator()
  data_t* output_;
};
// Specialization for every non-complex element type: the "conjugate" of such
// a value is the value itself, so elements are copied through unchanged.
template <typename data_t>
struct ConjFunctor<data_t, DisableComplex<data_t>> {
  // Stores the raw pointers and the element count; nothing is copied.
  ConjFunctor(const data_t* src, int64_t count, data_t* dst)
      : input_(src), numel_(count), output_(dst) {}

  // output_[pos] = input_[pos]. No bounds check is done.
  void operator()(size_t pos) const {
    const data_t& value = input_[pos];
    output_[pos] = value;
  }

  const data_t* input_;
  int64_t numel_;  // stored but not read by operator()
  data_t* output_;
};
template <typename data_t>
void ConjCPUKernel(const data_t* x_data, int64_t numel, data_t* out_data) {
ConjFunctor<data_t> conj(x_data, numel, out_data);
for (int64_t i = 0; i < numel; ++i) {
conj(i);
}
}
// Kernel entry point for `custom_conj`: checks that `x` is a CPU tensor,
// creates an output tensor on x's place with x's shape, and fills it through
// ConjCPUKernel for the dtype selected by
// PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES. Returns the output as a
// one-element vector.
std::vector<paddle::Tensor> ConjFunction(const paddle::Tensor& x) {
  CHECK_INPUT(x);

  paddle::Tensor out(x.place());
  out.reshape(x.shape());

  PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      x.type(), "ConjCPUKernel", ([&] {
        ConjCPUKernel<data_t>(
            x.data<data_t>(), x.size(), out.mutable_data<data_t>());
      }));

  return {out};
}
// Forward op registration: `custom_conj` takes input "X", produces output
// "Out", and runs ConjFunction as its kernel.
PD_BUILD_OP(custom_conj)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ConjFunction));
// Backward op registration: the only input is the gradient of "Out" and the
// only output is the gradient of "X" (the forward input itself is not an
// input here). The same ConjFunction kernel is reused for the backward pass.
PD_BUILD_GRAD_OP(custom_conj)
.Inputs({paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ConjFunction));
......@@ -62,3 +62,59 @@ PD_BUILD_OP(dispatch_test_float_and_integer)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger));
// Test kernel for PD_DISPATCH_COMPLEX_TYPES: creates a CPU tensor with x's
// shape and fills it by calling assign_cpu_kernel for the dispatched dtype.
// Returns the result as a one-element vector.
std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) {
  paddle::Tensor out(paddle::PlaceType::kCPU);
  out.reshape(x.shape());

  PD_DISPATCH_COMPLEX_TYPES(
      x.type(), "assign_cpu_kernel", ([&] {
        assign_cpu_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(), x.size());
      }));

  return {out};
}
// Registers op `dispatch_test_complex` (input "X", output "Out") with
// DispatchTestComplex as its kernel.
PD_BUILD_OP(dispatch_test_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestComplex));
// Test kernel for PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES: creates a CPU
// tensor with x's shape and fills it by calling assign_cpu_kernel for the
// dispatched dtype. Returns the result as a one-element vector.
std::vector<paddle::Tensor> DispatchTestFloatAndComplex(
    const paddle::Tensor& x) {
  paddle::Tensor out(paddle::PlaceType::kCPU);
  out.reshape(x.shape());

  PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      x.type(), "assign_cpu_kernel", ([&] {
        assign_cpu_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(), x.size());
      }));

  return {out};
}
// Registers op `dispatch_test_float_and_complex` (input "X", output "Out")
// with DispatchTestFloatAndComplex as its kernel.
PD_BUILD_OP(dispatch_test_float_and_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex));
// Test kernel for PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES:
// creates a CPU tensor with x's shape and fills it by calling
// assign_cpu_kernel for the dispatched dtype. Returns the result as a
// one-element vector.
std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex(
    const paddle::Tensor& x) {
  paddle::Tensor out(paddle::PlaceType::kCPU);
  out.reshape(x.shape());

  PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(
      x.type(), "assign_cpu_kernel", ([&] {
        assign_cpu_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(), x.size());
      }));

  return {out};
}
// Registers op `dispatch_test_float_and_integer_and_complex` (input "X",
// output "Out") with DispatchTestFloatAndIntegerAndComplex as its kernel.
PD_BUILD_OP(dispatch_test_float_and_integer_and_complex)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex));
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
import paddle.static as static
from paddle.utils.cpp_extension import load, get_build_directory
from paddle.utils.cpp_extension.extension_utils import run_cmd
from utils import paddle_includes, extra_cc_args, extra_nvcc_args
# Windows CI does not run inside docker, so a shared library left over from a
# previous run may still sit in the build cache; it would not be recompiled
# unless it is removed first.
# The path must name the module built by the `load(name='custom_conj_jit', ...)`
# call in this file (it previously pointed at the unrelated
# `custom_relu_module_jit` library, so the stale conj library was never removed).
# NOTE(review): assumes the cache layout is <build_dir>\<name>\<name>.pyd, as
# the original path implied -- confirm against cpp_extension.
file = '{}\\custom_conj_jit\\custom_conj_jit.pyd'.format(
    get_build_directory())
if os.name == 'nt' and os.path.isfile(file):
    cmd = 'del {}'.format(file)
    run_cmd(cmd, True)
# JIT-build `custom_conj_op.cc` into an extension named `custom_conj_jit`;
# the returned module exposes the op as `custom_ops.custom_conj` (used by the
# tests below).
custom_ops = load(
name='custom_conj_jit',
sources=['custom_conj_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cc flags
extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags
verbose=True)
def is_complex(dtype):
    """Return whether `dtype` equals VarType.COMPLEX64 or VarType.COMPLEX128."""
    var_type = paddle.fluid.core.VarDesc.VarType
    if dtype == var_type.COMPLEX64:
        return True
    return dtype == var_type.COMPLEX128
def to_complex(dtype):
    """Map "float32"/"float64" to np.complex64/np.complex128.

    Any other value is returned unchanged.
    """
    if dtype == "float32":
        return np.complex64
    if dtype == "float64":
        return np.complex128
    return dtype
def conj_dynamic(func, dtype, np_input):
    """Run `func` on `np_input` in dynamic-graph mode on CPU.

    The output is summed and back-propagated; for a complex sum only its real
    part is back-propagated. `dtype` is accepted but not used here.

    Returns:
        (out as a numpy array, x.grad)
    """
    paddle.set_device("cpu")
    x = paddle.to_tensor(np_input)
    out = func(x)
    out.stop_gradient = False

    total = paddle.sum(out)
    # backward() needs a real-valued scalar here, so take the real part of a
    # complex sum first.
    loss = total.real() if is_complex(total.dtype) else total
    loss.backward()

    # NOTE(review): stop_gradient is only cleared on `out`, not on `x` --
    # confirm that x.grad is actually populated and not None.
    return out.numpy(), x.grad
def conj_static(func, shape, dtype, np_input):
    """Run `func` in static-graph mode on CPU and fetch output and x's gradient.

    Builds a fresh program inside a fresh scope, appends the backward pass for
    sum(func(x)), executes it with `np_input` fed as "x", then switches static
    mode off again.

    Returns:
        (fetched value of out, fetched value of "<x.name>@GRAD")
    """
    paddle.enable_static()
    paddle.set_device("cpu")

    scope_ctx = static.scope_guard(static.Scope())
    with scope_ctx, static.program_guard(static.Program()):
        x = static.data(name="x", shape=shape, dtype=dtype)
        x.stop_gradient = False
        out = func(x)
        static.append_backward(paddle.sum(out))

        executor = static.Executor()
        executor.run(static.default_startup_program())

        fetch_targets = [out.name, x.name + "@GRAD"]
        out_v, x_grad_v = executor.run(static.default_main_program(),
                                       feed={"x": np_input},
                                       fetch_list=fetch_targets)

    paddle.disable_static()
    return out_v, x_grad_v
class TestCustomConjJit(unittest.TestCase):
    """Compares the JIT-built `custom_conj` op with `paddle.conj`."""

    def setUp(self):
        self.dtypes = ['float32', 'float64']
        self.shape = [2, 20, 2, 3]

    def check_output(self, out, pd_out, name):
        """Assert that `out` and `pd_out` are exactly equal arrays."""
        message = "custom op {}: {},\n paddle api {}: {}".format(
            name, out, name, pd_out)
        self.assertTrue(np.array_equal(out, pd_out), message)

    def run_dynamic(self, dtype, np_input):
        """Check output and x's gradient against paddle.conj in dynamic mode."""
        custom_out, custom_grad = conj_dynamic(custom_ops.custom_conj, dtype,
                                               np_input)
        ref_out, ref_grad = conj_dynamic(paddle.conj, dtype, np_input)

        self.check_output(custom_out, ref_out, "out")
        self.check_output(custom_grad, ref_grad, "x's grad")

    def run_static(self, dtype, np_input):
        """Check output and x's gradient against paddle.conj in static mode."""
        custom_out, custom_grad = conj_static(custom_ops.custom_conj,
                                              self.shape, dtype, np_input)
        ref_out, ref_grad = conj_static(paddle.conj, self.shape, dtype,
                                        np_input)

        self.check_output(custom_out, ref_out, "out")
        self.check_output(custom_grad, ref_grad, "x's grad")

    def test_dynamic(self):
        for dtype in self.dtypes:
            real_input = np.random.random(self.shape).astype(dtype)
            self.run_dynamic(dtype, real_input)

    def test_static(self):
        for dtype in self.dtypes:
            real_input = np.random.random(self.shape).astype(dtype)
            self.run_static(dtype, real_input)

    # complex only used in dynamic mode now
    def test_complex_dynamic(self):
        for dtype in self.dtypes:
            # Draw the real part first, then the imaginary part.
            real_part = np.random.random(self.shape).astype(dtype)
            imag_part = np.random.random(self.shape).astype(dtype)
            self.run_dynamic(to_complex(dtype), real_part + 1j * imag_part)
if __name__ == "__main__":
unittest.main()
......@@ -55,6 +55,11 @@ class TestJitDispatch(unittest.TestCase):
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_integer, dtype)
def test_dispatch_complex(self):
    """Run the complex-only dispatch op for complex64 and complex128."""
    for dtype in ("complex64", "complex128"):
        self.run_dispatch_test(dispatch_op.dispatch_test_complex, dtype)
def test_dispatch_float_and_integer(self):
dtypes = [
"float32", "float64", "int32", "int64", "int8", "uint8", "int16"
......@@ -63,6 +68,21 @@ class TestJitDispatch(unittest.TestCase):
self.run_dispatch_test(dispatch_op.dispatch_test_float_and_integer,
dtype)
def test_dispatch_float_and_complex(self):
    """Run the float+complex dispatch op for each dtype it lists."""
    op = dispatch_op.dispatch_test_float_and_complex
    for dtype in ("float32", "float64", "complex64", "complex128"):
        self.run_dispatch_test(op, dtype)
def test_dispatch_float_and_integer_and_complex(self):
    """Run the float+integer+complex dispatch op for each dtype it lists."""
    floating = ["float32", "float64"]
    integral = ["int32", "int64", "int8", "uint8", "int16"]
    complex_ = ["complex64", "complex128"]
    op = dispatch_op.dispatch_test_float_and_integer_and_complex
    for dtype in floating + integral + complex_:
        self.run_dispatch_test(op, dtype)
if __name__ == '__main__':
unittest.main()
......@@ -451,12 +451,30 @@ class InstallHeaders(Command):
('install_headers', 'install_dir'),
('force', 'force'))
def copy_data_type_headers(self, header):
    """Copy a complex64/complex128 header into `paddle/fluid/extension/include`.

    If `header` contains `platform<sep>complex64.h` or
    `platform<sep>complex128.h` (with the separator of the running OS), the
    destination directory under `self.install_dir` is created when missing
    and the result of `self.copy_file` is returned. For any other header
    nothing is copied and None is returned.
    """
    # os.path.join produces the same per-OS separator that would otherwise
    # have to be spelled out as literal '\\' or '/' paths.
    for file_name in ('complex64.h', 'complex128.h'):
        if os.path.join('platform', file_name) not in header:
            continue
        install_dir = os.path.join(self.install_dir, 'paddle', 'fluid',
                                   'extension', 'include')
        if not os.path.exists(install_dir):
            self.mkpath(install_dir)
        return self.copy_file(header, install_dir)
def mkdir_and_copy_file(self, header):
if 'pb.h' in header:
install_dir = re.sub('${PADDLE_BINARY_DIR}/', '', header)
elif 'third_party' not in header:
# framework
# paddle headers
install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header)
# For paddle data type headers, we also need to copy to `extension/include`,
# used for new custom operator
self.copy_data_type_headers(header)
else:
# third_party
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册