From a19154ca403da27bf8774c9a7aac93b09cd16f21 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 23 Feb 2021 07:56:17 -0600 Subject: [PATCH] [CustomOp] New custom operator extension mechanism in 2.0.1 (#31097) [CustomOp] New custom operator extension mechanism in 2.0.1 Cherry-pick New custom operator basic implementation related PRs --- paddle/extension.h | 18 + paddle/fluid/extension/include/all.h | 25 + paddle/fluid/extension/include/dispatch.h | 168 ++++ paddle/fluid/extension/include/dtype.h | 103 +++ paddle/fluid/extension/include/op_meta_info.h | 325 ++++++++ paddle/fluid/extension/include/place.h | 22 + paddle/fluid/extension/include/tensor.h | 95 +++ paddle/fluid/extension/src/op_meta_info.cc | 123 +++ paddle/fluid/extension/src/tensor.cc | 383 ++++++++++ paddle/fluid/framework/CMakeLists.txt | 10 +- paddle/fluid/framework/custom_operator.cc | 534 +++++++++++++ paddle/fluid/framework/custom_operator.h | 32 + paddle/fluid/framework/custom_tensor_test.cc | 249 ++++++ paddle/fluid/framework/custom_tensor_utils.h | 145 ++++ paddle/fluid/framework/data_type.cc | 4 + paddle/fluid/framework/data_type_transform.cc | 4 +- paddle/fluid/framework/op_meta_info_helper.h | 54 ++ paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/pybind.cc | 71 +- python/paddle/fluid/framework.py | 9 +- .../fluid/tests/custom_op/CMakeLists.txt | 12 +- .../paddle/fluid/tests/custom_op/__init__.py | 13 + .../fluid/tests/custom_op/dispatch_test_op.cc | 138 ++++ .../paddle/fluid/tests/custom_op/relu_op3.cc | 115 +++ .../paddle/fluid/tests/custom_op/relu_op3.cu | 87 +++ .../fluid/tests/custom_op/relu_op3_simple.cc | 43 ++ .../fluid/tests/custom_op/relu_op_simple.cc | 136 ++++ .../fluid/tests/custom_op/relu_op_simple.cu | 93 +++ .../fluid/tests/custom_op/setup_build.py | 37 + .../fluid/tests/custom_op/setup_install.py | 29 + .../tests/custom_op/setup_install_simple.py | 27 + .../fluid/tests/custom_op/test_check_abi.py | 135 ++++ .../fluid/tests/custom_op/test_custom_op.py | 12 +- .../fluid/tests/custom_op/test_dispatch.py | 79 ++ .../fluid/tests/custom_op/test_jit_load.py | 51 ++ .../fluid/tests/custom_op/test_setup_build.py | 69 ++ .../tests/custom_op/test_setup_install.py | 65 ++ .../custom_op/test_simple_custom_op_jit.py | 125 +++ .../custom_op/test_simple_custom_op_setup.py | 160 ++++ python/paddle/fluid/tests/custom_op/utils.py | 33 + python/paddle/utils/__init__.py | 2 + python/paddle/utils/cpp_extension/__init__.py | 30 + .../utils/cpp_extension/cpp_extension.py | 471 ++++++++++++ .../utils/cpp_extension/extension_utils.py | 722 ++++++++++++++++++ python/setup.py.in | 3 + 45 files changed, 5021 insertions(+), 42 deletions(-) create mode 100644 paddle/extension.h create mode 100644 paddle/fluid/extension/include/all.h create mode 100644 paddle/fluid/extension/include/dispatch.h create mode 100644 paddle/fluid/extension/include/dtype.h create mode 100644 paddle/fluid/extension/include/op_meta_info.h create mode 100644 paddle/fluid/extension/include/place.h create mode 100644 paddle/fluid/extension/include/tensor.h create mode 100644 paddle/fluid/extension/src/op_meta_info.cc create mode 100644 paddle/fluid/extension/src/tensor.cc create mode 100644 paddle/fluid/framework/custom_operator.cc create mode 100644 paddle/fluid/framework/custom_operator.h create mode 100644 paddle/fluid/framework/custom_tensor_test.cc create mode 100644 paddle/fluid/framework/custom_tensor_utils.h create mode 100644 paddle/fluid/framework/op_meta_info_helper.h create mode 100644 python/paddle/fluid/tests/custom_op/__init__.py create mode 100644 python/paddle/fluid/tests/custom_op/dispatch_test_op.cc create mode 100644 python/paddle/fluid/tests/custom_op/relu_op3.cc create mode 100644 python/paddle/fluid/tests/custom_op/relu_op3.cu create mode 100644 python/paddle/fluid/tests/custom_op/relu_op3_simple.cc create mode 100644 python/paddle/fluid/tests/custom_op/relu_op_simple.cc create mode 100644 python/paddle/fluid/tests/custom_op/relu_op_simple.cu create mode 100644 python/paddle/fluid/tests/custom_op/setup_build.py create mode 100644 python/paddle/fluid/tests/custom_op/setup_install.py create mode 100644 python/paddle/fluid/tests/custom_op/setup_install_simple.py create mode 100644 python/paddle/fluid/tests/custom_op/test_check_abi.py create mode 100644 python/paddle/fluid/tests/custom_op/test_dispatch.py create mode 100644 python/paddle/fluid/tests/custom_op/test_jit_load.py create mode 100644 python/paddle/fluid/tests/custom_op/test_setup_build.py create mode 100644 python/paddle/fluid/tests/custom_op/test_setup_install.py create mode 100644 python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py create mode 100644 python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py create mode 100644 python/paddle/fluid/tests/custom_op/utils.py create mode 100644 python/paddle/utils/cpp_extension/__init__.py create mode 100644 python/paddle/utils/cpp_extension/cpp_extension.py create mode 100644 python/paddle/utils/cpp_extension/extension_utils.py diff --git a/paddle/extension.h b/paddle/extension.h new file mode 100644 index 00000000000..1c64b92c5a3 --- /dev/null +++ b/paddle/extension.h @@ -0,0 +1,18 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +// All paddle apis in C++ frontend +#include "paddle/fluid/extension/include/all.h" diff --git a/paddle/fluid/extension/include/all.h b/paddle/fluid/extension/include/all.h new file mode 100644 index 00000000000..5aa61f8203e --- /dev/null +++ b/paddle/fluid/extension/include/all.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#if !defined(_MSC_VER) && __cplusplus < 199711L +#error C++11 or later compatible compiler is required to use Paddle. +#endif + +#include "paddle/fluid/extension/include/dispatch.h" +#include "paddle/fluid/extension/include/dtype.h" +#include "paddle/fluid/extension/include/op_meta_info.h" +#include "paddle/fluid/extension/include/place.h" +#include "paddle/fluid/extension/include/tensor.h" diff --git a/paddle/fluid/extension/include/dispatch.h b/paddle/fluid/extension/include/dispatch.h new file mode 100644 index 00000000000..c2297103952 --- /dev/null +++ b/paddle/fluid/extension/include/dispatch.h @@ -0,0 +1,168 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/extension/include/dtype.h" + +namespace paddle { + +///////// Basic Marco /////////// + +#define PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ + case enum_type: { \ + using HINT = type; \ + __VA_ARGS__(); \ + break; \ + } + +#define PD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ + PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__) + +///////// Floating Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ + __VA_ARGS__) \ + default: \ + throw std::runtime_error("function " #NAME \ + " not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + +///////// Integral Dispatch Marco /////////// + +#define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \ + __VA_ARGS__) \ + default: \ + throw std::runtime_error("function " #NAME \ + " not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + +///////// Complex Dispatch Marco /////////// + +#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, __VA_ARGS__) \ + default: \ + throw std::runtime_error("function " #NAME \ + " not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + +///////// Floating and Integral Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \ + __VA_ARGS__) \ + default: \ + throw std::runtime_error("function " #NAME \ + " not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + +///////// Floating and Complex Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, __VA_ARGS__) \ + default: \ + throw std::runtime_error("function " #NAME \ + " not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + +///////// Floating, Integral and Complex Dispatch Marco /////////// + +#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, __VA_ARGS__) \ + default: \ + throw std::runtime_error("function " #NAME \ + " not implemented for data type `" + \ + ::paddle::ToString(__dtype__) + "`"); \ + } \ + }() + +// TODO(chenweihang): Add more Marcos in the future if needed + +} // namespace paddle diff --git a/paddle/fluid/extension/include/dtype.h b/paddle/fluid/extension/include/dtype.h new file mode 100644 index 00000000000..c5d2e0f8205 --- /dev/null +++ b/paddle/fluid/extension/include/dtype.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { + +using float16 = paddle::platform::float16; +using bfloat16 = paddle::platform::bfloat16; +using complex64 = paddle::platform::complex64; +using complex128 = paddle::platform::complex128; + +enum DataType { + BOOL, + INT8, + UINT8, + INT16, + INT32, + INT64, + FLOAT16, + BFLOAT16, + FLOAT32, + FLOAT64, + COMPLEX64, + COMPLEX128, + // TODO(JiabinYang) support more data types if needed. +}; + +inline std::string ToString(DataType dtype) { + switch (dtype) { + case DataType::BOOL: + return "bool"; + case DataType::INT8: + return "int8_t"; + case DataType::UINT8: + return "uint8_t"; + case DataType::INT16: + return "int16_t"; + case DataType::INT32: + return "int32_t"; + case DataType::INT64: + return "int64_t"; + case DataType::FLOAT16: + return "float16"; + case DataType::BFLOAT16: + return "bfloat16"; + case DataType::FLOAT32: + return "float"; + case DataType::FLOAT64: + return "double"; + case DataType::COMPLEX64: + return "complex64"; + case DataType::COMPLEX128: + return "complex128"; + default: + throw std::runtime_error("Unsupported paddle enum data type."); + } +} + +#define PD_FOR_EACH_DATA_TYPE(_) \ + _(bool, DataType::BOOL) \ + _(int8_t, DataType::INT8) \ + _(uint8_t, DataType::UINT8) \ + _(int16_t, DataType::INT16) \ + _(int, DataType::INT32) \ + _(int64_t, DataType::INT64) \ + _(float16, DataType::FLOAT16) \ + _(bfloat16, DataType::BFLOAT16) \ + _(float, DataType::FLOAT32) \ + _(double, DataType::FLOAT64) \ + _(complex64, DataType::COMPLEX64) \ + _(complex128, DataType::COMPLEX128) + +template +struct DataTypeToCPPType; + +#define PD_SPECIALIZE_DataTypeToCPPType(cpp_type, data_type) \ + template <> \ + struct DataTypeToCPPType { \ + using type = cpp_type; \ + }; + +PD_FOR_EACH_DATA_TYPE(PD_SPECIALIZE_DataTypeToCPPType) + +#undef PD_SPECIALIZE_DataTypeToCPPType + +} // namespace paddle diff --git a/paddle/fluid/extension/include/op_meta_info.h b/paddle/fluid/extension/include/op_meta_info.h new file mode 100644 index 00000000000..920049e2390 --- /dev/null +++ b/paddle/fluid/extension/include/op_meta_info.h @@ -0,0 +1,325 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include + +#include "paddle/fluid/extension/include/tensor.h" + +/** + * Op Meta Info Related Define. + * + * Used to maintain operator core information. + * + */ + +namespace paddle { +namespace framework { +class OpMetaInfoHelper; +} // namespace framework + +using Tensor = paddle::Tensor; + +#define PD_DISABLE_COPY_AND_ASSIGN(classname) \ + private: \ + classname(const classname&) = delete; \ + classname(classname&&) = delete; \ + classname& operator=(const classname&) = delete; \ + classname& operator=(classname&&) = delete + +///////////////// Util Define and Function //////////////// + +inline std::string Grad(const std::string& var_name) { + std::string result; + result.reserve(var_name.size() + 5U); + result += var_name; + result += "@GRAD"; + return result; +} + +////////////////////// Kernel Function (PD_KERNEL) //////////////////////// + +// Record Op kernel core function +using KernelFunc = std::vector (*)(std::vector inputs, + std::vector attrs); + +template +struct TypeTag {}; + +template +struct KernelFuncImpl; + +template +struct KernelFuncImpl { + static Return Compute(std::vector inputs, + std::vector attrs) { + return ComputeCallHelper>::template Compute<0, 0>( + inputs, attrs); + } + + private: + template + struct ComputeCallHelper; + + // for Tensor input + template + struct ComputeCallHelper { + template + static Return Compute(std::vector inputs, + std::vector attrs, + const PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "Input tensor should appear before attributes."); + const Tensor& arg = inputs[in_idx]; + return ComputeCallHelper::template Compute( + inputs, attrs, pargs..., arg); + } + }; + + // TODO(chenweihang): add support for attribute input + // int attribute input (not used now) + template + struct ComputeCallHelper { + template + static Return Compute(std::vector inputs, + std::vector attrs, + const PreviousArgs&... pargs) { + try { + int arg = boost::any_cast(attrs[attr_idx]); + return ComputeCallHelper::template Compute( + inputs, attrs, pargs..., arg); + } catch (boost::bad_any_cast&) { + throw std::runtime_error( + "Attribute cast error in custom operator. Expected int value."); + } + } + }; + + // end: base template + template + struct ComputeCallHelper> { + template + static Return Compute(std::vector inputs, + std::vector attrs, const Args&... args) { + return impl_fn(args...); + } + }; +}; + +#define PD_KERNEL(...) \ + ::paddle::KernelFuncImpl::Compute + +/////////////// InferShape Function (PD_INFER_SHAPE) /////////////// + +// Record Op infershape core function +using InferShapeFunc = std::vector> (*)( + std::vector> input_shapes); + +template +struct InferShapeFuncImpl; + +template +struct InferShapeFuncImpl { + static Return InferShape(std::vector> input_shapes) { + return InferShapeCallHelper>::template InferShape<0>( + input_shapes); + } + + private: + template + struct InferShapeCallHelper; + + // only one type input: std::vector + template + struct InferShapeCallHelper, Tail...> { + template + static Return InferShape(std::vector> input_shapes, + const PreviousArgs&... pargs) { + std::vector arg = input_shapes[in_idx]; + return InferShapeCallHelper::template InferShape( + input_shapes, pargs..., arg); + } + }; + + // end: base template + template + struct InferShapeCallHelper> { + template + static Return InferShape(std::vector> input_shapes, + const Args&... args) { + return impl_fn(args...); + } + }; +}; + +#define PD_INFER_SHAPE(...) \ + ::paddle::InferShapeFuncImpl::InferShape + +/////////////// InferDataType Function (PD_INFER_DTYPE) /////////////// + +// Record Op Infer dtype core function +using InferDtypeFunc = + std::vector (*)(std::vector input_dtypes); + +template +struct InferDtypeFuncImpl; + +template +struct InferDtypeFuncImpl { + static Return InferDtype(std::vector input_dtypes) { + return InferDtypeCallHelper>::template InferDtype<0>( + input_dtypes); + } + + private: + template + struct InferDtypeCallHelper; + + // Only one type input now: DataType + template + struct InferDtypeCallHelper { + template + static Return InferDtype(std::vector input_dtypes, + const PreviousArgs&... pargs) { + DataType arg = input_dtypes[in_idx]; + return InferDtypeCallHelper::template InferDtype( + input_dtypes, pargs..., arg); + } + }; + + // end: base template + template + struct InferDtypeCallHelper> { + template + static Return InferDtype(std::vector input_dtypes, + const Args&... args) { + return impl_fn(args...); + } + }; +}; + +#define PD_INFER_DTYPE(...) \ + ::paddle::InferDtypeFuncImpl::InferDtype + +////////////////////// Op Meta Info ////////////////////// + +class OpMetaInfo { + public: + explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {} + OpMetaInfo& Inputs(std::vector&& inputs); + OpMetaInfo& Outputs(std::vector&& outputs); + OpMetaInfo& SetKernelFn(KernelFunc&& func); + OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func); + OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func); + + private: + friend class framework::OpMetaInfoHelper; + + // 1. desc info + std::string name_; + std::vector inputs_; + std::vector outputs_; + std::vector attrs_; + + // 2. func info + KernelFunc kernel_fn_; + InferShapeFunc infer_shape_fn_; + InferDtypeFunc infer_dtype_fn_; +}; + +//////////////// Op Meta Info Map ///////////////// + +class OpMetaInfoMap { + public: + // this function's impl should keep in header file. + // if move to cc file, meta info can not be added + // into map + static OpMetaInfoMap& Instance() { + static OpMetaInfoMap g_custom_op_meta_info_map; + return g_custom_op_meta_info_map; + } + + std::vector& operator[](const std::string& name); + + const std::unordered_map>& GetMap() + const; + + private: + OpMetaInfoMap() = default; + std::unordered_map> map_; + + PD_DISABLE_COPY_AND_ASSIGN(OpMetaInfoMap); +}; + +//////////////// Op Meta Info Builder ///////////////// + +class OpMetaInfoBuilder { + public: + explicit OpMetaInfoBuilder(std::string&& name); + OpMetaInfoBuilder& Inputs(std::vector&& inputs); + OpMetaInfoBuilder& Outputs(std::vector&& outputs); + OpMetaInfoBuilder& SetKernelFn(KernelFunc&& func); + OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc&& func); + OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc&& func); + OpMetaInfoBuilder& SetBackwardOp(const std::string& bwd_op_name); + + private: + // Forward Op name + std::string name_; + // Point to the currently constructed op meta info + OpMetaInfo* info_ptr_; +}; + +/////////////////////// Op register API ///////////////////////// + +// For inference: compile directly with framework +// Call after PD_BUILD_OP(...) +void RegisterAllCustomOperator(); + +// Using this api to load compiled custom operator's dynamic library and +// register Custom +// Operator into it +void LoadCustomOperatorLib(const std::string& dso_name); + +/////////////////////// Op register Macro ///////////////////////// + +#define PD_BUILD_OP_WITH_COUNTER(op_name, counter) \ + static ::paddle::OpMetaInfoBuilder __op_meta_info_##counter##__ = \ + ::paddle::OpMetaInfoBuilder(op_name) + +#define PD_BUILD_OP_INNER(op_name, counter) \ + PD_BUILD_OP_WITH_COUNTER(op_name, counter) + +#define PD_BUILD_OP(op_name) PD_BUILD_OP_INNER(op_name, __COUNTER__) + +} // namespace paddle + +///////////////////// C API /////////////////// + +#ifdef __cplusplus +extern "C" { +#endif + +// C-API to get global OpMetaInfoMap. +paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap(); + +#ifdef __cplusplus +} +#endif diff --git a/paddle/fluid/extension/include/place.h b/paddle/fluid/extension/include/place.h new file mode 100644 index 00000000000..91d4f41c213 --- /dev/null +++ b/paddle/fluid/extension/include/place.h @@ -0,0 +1,22 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { + +// TODO(yangjiabin): Add other place support in next PR +enum class PlaceType { kUNK = -1, kCPU, kGPU }; + +} // namespace paddle diff --git a/paddle/fluid/extension/include/tensor.h b/paddle/fluid/extension/include/tensor.h new file mode 100644 index 00000000000..a5ce0d1a585 --- /dev/null +++ b/paddle/fluid/extension/include/tensor.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/extension/include/dtype.h" +#include "paddle/fluid/extension/include/place.h" + +namespace paddle { +namespace framework { +class CustomTensorUtils; +} // namespace framework +class Tensor { + public: + /// \brief Construct a Tensor on target Place for CustomOp. + /// Generally it's only used for user to create Tensor. + explicit Tensor(const PlaceType& place); + /// \brief Reset the shape of the tensor. + /// Generally it's only used for the input tensor. + /// Reshape must be called before calling + /// mutable_data() or copy_to(const PlaceType& place) + /// \param shape The shape to set. + void reshape(const std::vector& shape); + + /// \brief Get the memory pointer in CPU or GPU with + /// specific data type. + /// Please Reshape the tensor first before call this. + /// It's usually used to get input data pointer. + /// \param place The place of the tensor this will + /// override the original place of current tensor. + template + T* mutable_data(const PlaceType& place); + + /// \brief Get the memory pointer in CPU or GPU with + /// specific data type. Please Reshape the tensor + /// first before call this.It's usually used to get + /// input data pointer. + template + T* mutable_data(); + + /// \brief Get the memory pointer directly. + /// It's usually used to get the output data pointer. + /// \return The tensor data buffer pointer. + template + T* data() const; + + /// \brief Copy the host memory to tensor data. + /// It's usually used to set the input tensor data. + /// \param PlaceType of target place, of which + /// the tensor will copy to. + + template + Tensor copy_to(const PlaceType& place) const; + + /// \brief Return the shape of the Tensor. + std::vector shape() const; + + /// \brief Return the data type of the tensor. + /// It's usually used to get the output tensor data type. + /// \return The data type of the tensor. + DataType type() const; + + /// \brief Get the size of current tensor. + /// Use this method to get the size of tensor + /// \return int64_t. + int64_t size() const; + + /// \brief Get the place of current tensor. + /// Use this method to get the place of tensor + /// \return Place. + const PlaceType& place() const; + + /// \brief Cast datatype from one to another + Tensor cast(const DataType& target_type) const; + + private: + friend class framework::CustomTensorUtils; + mutable std::shared_ptr tensor_; + mutable PlaceType place_; +}; + +} // namespace paddle diff --git a/paddle/fluid/extension/src/op_meta_info.cc b/paddle/fluid/extension/src/op_meta_info.cc new file mode 100644 index 00000000000..f31723e5ac8 --- /dev/null +++ b/paddle/fluid/extension/src/op_meta_info.cc @@ -0,0 +1,123 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/extension/include/op_meta_info.h" + +#include +#include +#include + +#include "paddle/fluid/framework/custom_operator.h" + +namespace paddle { + +////////////////////// Op Meta Info ////////////////////// + +OpMetaInfo& OpMetaInfo::Inputs(std::vector&& inputs) { + inputs_ = std::forward>(inputs); + return *this; +} +OpMetaInfo& OpMetaInfo::Outputs(std::vector&& outputs) { + outputs_ = std::forward>(outputs); + return *this; +} +OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) { + kernel_fn_ = std::forward(func); + return *this; +} +OpMetaInfo& OpMetaInfo::SetInferShapeFn(InferShapeFunc&& func) { + infer_shape_fn_ = std::forward(func); + return *this; +} +OpMetaInfo& OpMetaInfo::SetInferDtypeFn(InferDtypeFunc&& func) { + infer_dtype_fn_ = std::forward(func); + return *this; +} + +//////////////// Op Meta Info Map ///////////////// + +std::vector& OpMetaInfoMap::operator[](const std::string& name) { + return map_[name]; +} + +const std::unordered_map>& +OpMetaInfoMap::GetMap() const { + return map_; +} + +//////////////// Op Meta Info Builder ///////////////// + +OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name) { + name_ = std::forward(name); + auto& info_vector = OpMetaInfoMap::Instance()[name_]; + auto op_meta = OpMetaInfo(name_); + info_vector.emplace_back(std::move(op_meta)); + info_ptr_ = &(info_vector.back()); +} + +OpMetaInfoBuilder& OpMetaInfoBuilder::Inputs( + std::vector&& inputs) { + info_ptr_->Inputs(std::forward>(inputs)); + return *this; +} + +OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs( + std::vector&& outputs) { + info_ptr_->Outputs(std::forward>(outputs)); + return *this; +} + +OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc&& func) { + info_ptr_->SetKernelFn(std::forward(func)); + return *this; +} + +OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc&& func) { + info_ptr_->SetInferShapeFn(std::forward(func)); + return *this; +} + +OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc&& func) { + info_ptr_->SetInferDtypeFn(std::forward(func)); + return *this; +} + +OpMetaInfoBuilder& OpMetaInfoBuilder::SetBackwardOp( + const std::string& bwd_op_name) { + auto& info_vector = OpMetaInfoMap::Instance()[name_]; + auto op_meta = OpMetaInfo(bwd_op_name); + info_vector.emplace_back(std::move(op_meta)); + info_ptr_ = &(info_vector.back()); + return *this; +} + +/////////////////////// Op register API ///////////////////////// + +void RegisterAllCustomOperator() { + auto& op_meta_info_map = OpMetaInfoMap::Instance(); + framework::RegisterOperatorWithMetaInfoMap(op_meta_info_map); +} + +void LoadCustomOperatorLib(const std::string& dso_name) { + paddle::framework::LoadOpMetaInfoAndRegisterOp(dso_name); +} +} // namespace paddle + +extern "C" { + +paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() { + return paddle::OpMetaInfoMap::Instance(); +} + +} // end extern "C" diff --git a/paddle/fluid/extension/src/tensor.cc b/paddle/fluid/extension/src/tensor.cc new file mode 100644 index 00000000000..11d505a5aab --- /dev/null +++ b/paddle/fluid/extension/src/tensor.cc @@ -0,0 +1,383 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/extension/include/tensor.h" +#include +#include "paddle/fluid/framework/custom_tensor_utils.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/transform.h" + +namespace paddle { + +template +struct CastDataTypeFunctor { + HOSTDEVICE inline OutType operator()(InType in) const { + return static_cast(in); + } +}; + +template +struct CastDataType { + CastDataType(const framework::Tensor &in, framework::Tensor *out, + const platform::DeviceContext *ctx) + : in_(in), out_(out), ctx_(ctx) {} + const framework::Tensor in_; + framework::Tensor *out_; + const platform::DeviceContext *ctx_; + + template + void apply() { + auto *in_begin = in_.data(); + auto *in_end = in_begin + in_.numel(); + auto *out_begin = out_->mutable_data(in_.place()); + + if (platform::is_cpu_place(in_.place())) { + platform::Transform trans; + auto *context = static_cast(ctx_); + trans(*context, in_begin, in_end, out_begin, + CastDataTypeFunctor()); +#ifdef __NVCC__ + } else if (platform::is_gpu_place(in_.place())) { + platform::Transform trans; + auto *context = static_cast(ctx_); + trans(*context, in_begin, in_end, out_begin, + CastDataTypeFunctor()); + context->Wait(); +#endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Place type is not supported when casting data type.")); + } + } +}; +template +void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, + int64_t ele_size) { +#ifdef PADDLE_WITH_CUDA + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + int device_num = paddle::platform::GetCurrentDeviceId(); + platform::CUDAPlace gpu_place(device_num); + auto *dev_ctx = + static_cast(pool.Get(gpu_place)); + if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kCPU)) { + memory::Copy(platform::CPUPlace(), static_cast(dst), gpu_place, src, + ele_size, dev_ctx->stream()); + } else if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kGPU)) { + memory::Copy(gpu_place, static_cast(dst), gpu_place, src, ele_size, + dev_ctx->stream()); + } else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kGPU)) { + memory::Copy(gpu_place, static_cast(dst), platform::CPUPlace(), src, + ele_size, dev_ctx->stream()); + } else { + PADDLE_THROW(platform::errors::Unavailable( + "Only GPU related Copy can reach this func.")); + } + cudaStreamSynchronize(dev_ctx->stream()); +#endif +} + +#define GET_CASTED_TENSOR \ + if (!tensor_) { \ + tensor_ = std::make_shared(); \ + } \ + auto *tensor = static_cast(tensor_.get()); + +void Tensor::reshape(const std::vector &shape) { + GET_CASTED_TENSOR + tensor->Resize(framework::make_ddim(shape)); +} + +Tensor::Tensor(const PlaceType &place) + : tensor_(std::make_shared()), place_(place) {} + +template +T *Tensor::mutable_data(const PlaceType &place) { + place_ = place; + return mutable_data(); +} + +template +T *Tensor::mutable_data() { + GET_CASTED_TENSOR + PADDLE_ENFORCE_GT( + tensor->numel(), 0, + platform::errors::PreconditionNotMet( + "You should call Tensor::Reshape(const std::vector " + "&shape)" + "function before retrieving mutable_data from input tensor.")); + switch (static_cast(place_)) { + case static_cast(PlaceType::kCPU): { + return tensor->mutable_data(platform::CPUPlace()); + } +#ifdef PADDLE_WITH_CUDA + case static_cast(PlaceType::kGPU): { + int device_num = platform::GetCurrentDeviceId(); + return tensor->mutable_data(platform::CUDAPlace(device_num)); + } +#endif + default: + PADDLE_THROW(platform::errors::Unavailable( + "Custom operator unsupported place id(%d)", + static_cast(place_))); + } +} + +template +T *Tensor::data() const { + GET_CASTED_TENSOR; + auto *res = tensor->data(); + return res; +} + +DataType Tensor::type() const { + GET_CASTED_TENSOR; + auto type = tensor->type(); + if (type == framework::proto::VarType::FP32) { + return DataType::FLOAT32; + } else if (type == framework::proto::VarType::INT64) { + return DataType::INT64; + } else if (type == framework::proto::VarType::INT32) { + return DataType::INT32; + } else if (type == framework::proto::VarType::INT16) { + return DataType::INT16; + } else if (type == framework::proto::VarType::INT8) { + return DataType::INT8; + } else if (type == framework::proto::VarType::UINT8) { + return DataType::UINT8; + } else if (type == framework::proto::VarType::FP64) { + return DataType::FLOAT64; + } else if (type == framework::proto::VarType::BF16) { + return DataType::BFLOAT16; + } else if (type == framework::proto::VarType::FP16) { + return DataType::FLOAT16; + } else if (type == framework::proto::VarType::COMPLEX64) { + return DataType::COMPLEX64; + } else if (type == framework::proto::VarType::COMPLEX128) { + return DataType::COMPLEX128; + } else if (type == framework::proto::VarType::BOOL) { + return DataType::BOOL; + } + return DataType::FLOAT32; +} + +template +Tensor Tensor::copy_to(const PlaceType &target_place) const { + GET_CASTED_TENSOR; + PADDLE_ENFORCE_GE(tensor->numel(), 0, + platform::errors::PreconditionNotMet( + "You should call Tensor::Reshape(const " + "std::vector &shape)" + "function before copying data from cpu.")); + size_t ele_size = tensor->numel() * sizeof(T); + auto *p_src_data = tensor->data(); + auto src_place = place(); + Tensor target = Tensor(target_place); + target.reshape(shape()); + auto *p_target_data = target.template mutable_data(); + + if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) { + std::memcpy(static_cast(p_target_data), p_src_data, ele_size); + } else if ((src_place == PlaceType::kGPU) && + (target_place == PlaceType::kCPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); + } else if ((src_place == PlaceType::kCPU) && + (target_place == PlaceType::kGPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); + } else if ((src_place == PlaceType::kGPU) && + (target_place == PlaceType::kGPU)) { + GpuCopy(p_src_data, p_target_data, src_place, target_place, ele_size); + } else { + PADDLE_THROW(platform::errors::Unavailable( + "Not supported place transform of place: %d to place: %d", + static_cast(src_place), static_cast(target_place))); + } + return target; +} + +template Tensor Tensor::copy_to( + const PlaceType &target_place) const; +template Tensor Tensor::copy_to( + const PlaceType &target_place) const; +template Tensor Tensor::copy_to( + const PlaceType &target_place) const; +template Tensor Tensor::copy_to( + const PlaceType &target_place) const; +template Tensor Tensor::copy_to(const PlaceType &target_place) const; +template Tensor Tensor::copy_to(const PlaceType &target_place) const; +template Tensor Tensor::copy_to(const PlaceType &target_place) const; +template Tensor Tensor::copy_to(const PlaceType &target_place) const; +template Tensor Tensor::copy_to(const PlaceType &target_place) const; +template Tensor Tensor::copy_to(const PlaceType &target_place) const; +template Tensor Tensor::copy_to(const PlaceType &target_place) const; +template Tensor Tensor::copy_to(const PlaceType &target_place) const; + +template float *Tensor::data() const; +template double *Tensor::data() const; +template int64_t *Tensor::data() const; +template int32_t *Tensor::data() const; +template uint8_t *Tensor::data() const; +template int8_t *Tensor::data() const; +template paddle::platform::float16 *Tensor::data() + const; +template paddle::platform::bfloat16 *Tensor::data() + const; +template paddle::platform::complex128 * +Tensor::data() const; +template paddle::platform::complex64 * +Tensor::data() const; +template int16_t *Tensor::data() const; +template bool *Tensor::data() const; + +template float *Tensor::mutable_data(); +template double *Tensor::mutable_data(); +template int64_t *Tensor::mutable_data(); +template int32_t *Tensor::mutable_data(); +template uint8_t *Tensor::mutable_data(); +template int8_t *Tensor::mutable_data(); +template paddle::platform::float16 * +Tensor::mutable_data(); +template paddle::platform::bfloat16 * +Tensor::mutable_data(); +template paddle::platform::complex128 * +Tensor::mutable_data(); +template paddle::platform::complex64 * +Tensor::mutable_data(); +template int16_t *Tensor::mutable_data(); +template bool *Tensor::mutable_data(); + +template float *Tensor::mutable_data(const PlaceType &place); +template double *Tensor::mutable_data(const PlaceType &place); +template int64_t *Tensor::mutable_data(const PlaceType &place); +template int32_t *Tensor::mutable_data(const PlaceType &place); +template uint8_t *Tensor::mutable_data(const PlaceType &place); +template int8_t *Tensor::mutable_data(const PlaceType &place); +template paddle::platform::float16 * +Tensor::mutable_data(const PlaceType &place); +template paddle::platform::bfloat16 * +Tensor::mutable_data(const PlaceType &place); +template paddle::platform::complex128 * +Tensor::mutable_data(const PlaceType &place); +template paddle::platform::complex64 * +Tensor::mutable_data(const PlaceType &place); +template int16_t *Tensor::mutable_data(const PlaceType &place); +template bool *Tensor::mutable_data(const PlaceType &place); + +std::vector Tensor::shape() const { + GET_CASTED_TENSOR + return framework::vectorize(tensor->dims()); +} + +const PlaceType &Tensor::place() const { + GET_CASTED_TENSOR; + if (platform::is_cpu_place(tensor->place())) { + place_ = PlaceType::kCPU; + } else if (platform::is_gpu_place(tensor->place())) { + place_ = PlaceType::kGPU; + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Current Tensor hold unsupported Place Type, Please Init it" + "using Tensor::mutable_data(PaddlePlace) which T is" + "either Place::kCPU or Place::kGPU")); + } + return place_; +} + +Tensor Tensor::cast(const DataType &target_type) const { + GET_CASTED_TENSOR; + Tensor rlt = Tensor(place()); + rlt.reshape(this->shape()); + auto rlt_tensor_ = static_cast(rlt.tensor_.get()); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto ctx = pool.Get(tensor->place()); + auto src_type = tensor->type(); + auto dst_type = + framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type); + switch (src_type) { + case framework::proto::VarType::FP16: + framework::VisitDataType( + dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::BF16: + framework::VisitDataType(dst_type, CastDataType( + *tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::FP32: + framework::VisitDataType(dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::FP64: + framework::VisitDataType(dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::INT32: + framework::VisitDataType(dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::INT64: + framework::VisitDataType( + dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::BOOL: + framework::VisitDataType(dst_type, + CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::INT16: + framework::VisitDataType( + dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::UINT8: + framework::VisitDataType( + dst_type, CastDataType(*tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::COMPLEX64: + framework::VisitDataType(dst_type, CastDataType( + *tensor, rlt_tensor_, ctx)); + break; + case framework::proto::VarType::COMPLEX128: + framework::VisitDataType(dst_type, CastDataType( + *tensor, rlt_tensor_, ctx)); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when casting data type.", + framework::DataTypeToString(src_type))); + } + return rlt; +} + +int64_t Tensor::size() const { + GET_CASTED_TENSOR; + return tensor->numel(); +} + +namespace framework { + +void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) { + static_cast(dst)->ShareDataWith( + *static_cast(src.tensor_.get())); +} + +void CustomTensorUtils::ShareDataFrom(const void *src, + const paddle::Tensor &dst) { + if (!dst.tensor_) { + dst.tensor_ = std::make_shared(); + } + auto *tensor = static_cast(dst.tensor_.get()); + tensor->ShareDataWith(*static_cast(src)); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6caa352f7ad..2f4dcf465de 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -321,11 +321,17 @@ message(STATUS "branch: ${PADDLE_BRANCH}") configure_file(commit.h.in commit.h) -set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer) +cc_library(custom_tensor SRCS ../extension/src/tensor.cc DEPS lod_tensor) +cc_library(op_meta_info SRCS ../extension/src/op_meta_info.cc DEPS custom_tensor) +cc_library(custom_operator SRCS custom_operator.cc DEPS operator op_registry device_context dynamic_loader custom_tensor op_meta_info) +cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog) + +set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator) cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES}) cc_library(paddle_framework_shared - SHARED SRCS executor.cc operator.cc + SHARED SRCS executor.cc operator.cc custom_operator.cc ../extension/src/tensor.cc + ../extension/src/op_meta_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/c/c_api.cc ${CMAKE_SOURCE_DIR}/paddle/fluid/imperative/layer.cc DEPS ${FLUID_FRAMEWORK_MODULES}) diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc new file mode 100644 index 00000000000..1e2a77e915d --- /dev/null +++ b/paddle/fluid/framework/custom_operator.cc @@ -0,0 +1,534 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/custom_operator.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/extension/include/tensor.h" +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/c/c_api.h" +#include "paddle/fluid/framework/custom_tensor_utils.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/op_meta_info_helper.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace framework { + +namespace detail { + +// dynamic lib load func +template +static T* DynLoad(void* handle, std::string name) { + T* func = reinterpret_cast(dlsym(handle, name.c_str())); +#if !defined(_WIN32) + auto errorno = dlerror(); +#else + auto errorno = GetLastError(); +#endif // !_WIN32 + PADDLE_ENFORCE_NOT_NULL( + func, platform::errors::NotFound( + "Failed to load dynamic operator library, error message(%s).", + errorno)); + return func; +} + +inline bool IsGradVar(const std::string& var_name) { + std::string suffix = kGradVarSuffix; + return var_name.rfind(suffix) != std::string::npos; +} + +inline std::string NoGrad(const std::string& var_name) { + std::string suffix = kGradVarSuffix; + return var_name.substr(0, var_name.size() - kGradVarSuffixSize); +} + +inline bool IsMemberOf(const std::vector& vec, + const std::string& name) { + return std::find(vec.cbegin(), vec.cend(), name) != vec.cend(); +} + +} // namespace detail + +////////////////// Kernel Define //////////////////// + +// custom op kernel call function define +static void RunKernelFunc(const framework::ExecutionContext& ctx, + const paddle::KernelFunc& func, + const std::vector& inputs, + const std::vector& outputs) { + VLOG(1) << "Custom Operator: Start run KernelFunc."; + std::vector custom_ins; + for (auto& in_name : inputs) { + VLOG(1) << "Custom Operator: input name - " << in_name; + auto* x = ctx.Input(in_name); + PADDLE_ENFORCE_NOT_NULL(x, platform::errors::NotFound( + "Input tensor (%s) is nullptr.", in_name)); + PADDLE_ENFORCE_EQ(x->IsInitialized(), true, + platform::errors::InvalidArgument( + "Input tensor (%s) is not initialized.")); + auto custom_in = paddle::Tensor( + CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place())); + CustomTensorUtils::ShareDataFrom(static_cast(x), custom_in); + custom_ins.emplace_back(custom_in); + } + + std::vector attrs; + + VLOG(1) << "Run ComputeFunc."; + auto outs = func(custom_ins, attrs); + + VLOG(1) << "Custom Operator: Share outputs into ExecutionContext."; + for (size_t i = 0; i < outputs.size(); ++i) { + auto* true_out = ctx.Output(outputs[i]); + CustomTensorUtils::ShareDataTo(outs.at(i), true_out); + } +} + +//////////////////// Operator Define ///////////////// + +class CustomOperator : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + // Dummy infershape + // Because it is a pure virtual function, it must be implemented + void InferShape(framework::InferShapeContext* ctx) const override { + VLOG(1) << "Custom Operator: Dummy infer shape of custom operator."; + } + + /** + * NOTE: [Skip the Kernel Selection] + * Custom Op only registers one Op kernel on each device, so that the + * data type selection and promotion that depends on GetExpectedKernelType, + * as well as the adaptation of various other special situations, + * need users to implement, to avoid users needs to implement + * GetExpectedKernelType function when expanding other cases. + * The RAW type is used here as the data type, indicating that + * it can only be determined at runtime. + */ + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType(proto::VarType::RAW, ctx.GetPlace()); + } + + /** + * NOTE: [Skip Input Variable Cast for DataType] + * Because the kernel data type is RAW, we should skip the cast for + * data type difference when PrepareData. + */ + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const OpKernelType& expected_kernel_type) { + return OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, tensor.layout()); + } +}; + +class CustomOpMaker : public OpProtoAndCheckerMaker { + public: + explicit CustomOpMaker(const std::vector& inputs, + const std::vector& outputs, + const std::vector& attrs) + : inputs_(inputs), outputs_(outputs), attrs_(attrs) {} + + void Make() override { + for (auto& in_name : inputs_) { + AddInput(in_name, "The input " + in_name + "of Custom operator."); + } + for (auto& out_name : outputs_) { + AddOutput(out_name, "The output " + out_name + "of Custom Operator."); + } + // TODO(chenweihang): support attrs in later PR + AddComment(R"DOC( +Custom Operator. + +According to the Tensor operation function implemented by the user +independently of the framework, it is encapsulated into a framework +operator to adapt to various execution scenarios such as dynamic graph, +mode static graph mode, and inference mode. + +)DOC"); + } + + private: + std::vector inputs_; + std::vector outputs_; + std::vector attrs_; +}; + +template +class CustomGradOpMaker; + +template <> +class CustomGradOpMaker : public SingleGradOpMaker { + public: + explicit CustomGradOpMaker( + const OpDesc& fwd_op, const std::unordered_set& no_grad_set, + std::unordered_map* grad_to_var, + const std::vector& grad_block, const std::string& name, + const std::vector& inputs, + const std::vector& outputs) + : SingleGradOpMaker(fwd_op, no_grad_set, grad_to_var, grad_block), + name_(name), + inputs_(inputs), + outputs_(outputs) {} + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType(name_); + + auto fwd_op_inputs = this->InputNames(); + auto fwd_op_outputs = this->OutputNames(); + + for (auto& in_name : inputs_) { + VLOG(1) << "Custom Operator: GradOpDescMaker - input: " << in_name; + if (!detail::IsGradVar(in_name)) { + if (detail::IsMemberOf(fwd_op_inputs, in_name)) { + grad_op->SetInput(in_name, this->Input(in_name)); + } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) { + grad_op->SetInput(in_name, this->Output(in_name)); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The input tensor name `%s` is invalid, expected it is the input " + "or output of forward operator.", + in_name)); + } + } else { + grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name))); + } + } + for (auto& out_name : outputs_) { + VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name; + grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name))); + } + // TODO(chenweihang): support attrs in later PR + } + + private: + std::string name_; + std::vector inputs_; + std::vector outputs_; +}; + +template <> +class CustomGradOpMaker + : public SingleGradOpMaker { + public: + explicit CustomGradOpMaker( + const std::string& type, + const imperative::NameVarBaseMap& var_base_map_in, + const imperative::NameVarBaseMap& var_base_map_out, + const AttributeMap& attrs, + const std::map& inplace_map, + const std::string& name, const std::vector& inputs, + const std::vector& outputs) + : SingleGradOpMaker( + type, var_base_map_in, var_base_map_out, attrs, inplace_map), + name_(name), + inputs_(inputs), + outputs_(outputs) {} + + protected: + // TODO(chenweihang): The code is duplicated with the previous one, because + // ere OpMaker's Input, Output and other methods are protected. Putting the + // function implementation outside the class will cause the method to be + // uncallable, + // so it is still implemented in the class for the time being. + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType(name_); + + auto fwd_op_inputs = this->InputNames(); + auto fwd_op_outputs = this->OutputNames(); + + for (auto& in_name : inputs_) { + VLOG(1) << "Custom Operator: GradOpBaseMaker - input: " << in_name; + if (!detail::IsGradVar(in_name)) { + if (detail::IsMemberOf(fwd_op_inputs, in_name)) { + grad_op->SetInput(in_name, this->Input(in_name)); + } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) { + grad_op->SetInput(in_name, this->Output(in_name)); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The input tensor name `%s` is invalid, expected it is the input " + "or output of forward operator.", + in_name)); + } + } else { + grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name))); + } + } + for (auto& out_name : outputs_) { + VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name; + grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name))); + } + // TODO(chenweihang): support attrs in later PR + } + + private: + std::string name_; + std::vector inputs_; + std::vector outputs_; +}; + +//////////// Operator and Kernel Register ////////////// + +void RegisterOperatorKernelWithPlace(const std::string& name, + const paddle::KernelFunc& kernel_func, + const proto::VarType::Type type, + const PlaceType& place, + const std::vector& inputs, + const std::vector& outputs) { + OpKernelType key(type, + CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place)); + VLOG(1) << "Custom Operator: op kernel key: " << key; + OperatorWithKernel::AllOpKernels()[name][key] = + [kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) { + VLOG(1) << "Custom Operator: run custom kernel func in lambda."; + RunKernelFunc(ctx, kernel_func, inputs, outputs); + }; +} + +void RegisterOperatorKernel(const std::string& name, + const paddle::KernelFunc& kernel_func, + const std::vector& inputs, + const std::vector& outputs) { + VLOG(1) << "Custom Operator: op name in kernel: " << name; + // NOTE [ Dummy Op Kernel Key ] + // TODO(chenweihang): Because execute engine need get device context based + // op_kernel_key.place_, so we should register kernel for each + // device. But this is not entirely correct, if user only give a cpu kernel, + // but call api in gpu device, it will cause error. + RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, + PlaceType::kCPU, inputs, outputs); + RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, + PlaceType::kGPU, inputs, outputs); +} + +void RegisterOperatorWithMetaInfo( + const std::vector& op_meta_infos) { + /* Op register */ + OpInfo info; + + auto& base_op_meta = op_meta_infos.front(); + + auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta); + auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta); + auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta); + auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta); + auto& kernel_fn = OpMetaInfoHelper::GetKernelFn(base_op_meta); + auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta); + auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta); + + VLOG(1) << "Custom Operator: forward, op name: " << op_name; + VLOG(1) << "Custom Operator: forward, op inputs: " + << string::join_strings(op_inputs, ','); + VLOG(1) << "Custom Operator: forward, op outputs: " + << string::join_strings(op_outputs, ','); + + // Op + info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs, + const VariableNameMap& outputs, + const AttributeMap& attrs) { + return new CustomOperator(op_name, inputs, outputs, attrs); + }; + + // OpMaker + info.proto_ = new proto::OpProto; + info.proto_->set_type(op_name); + + info.checker_ = new OpAttrChecker(); + CustomOpMaker custom_maker(op_inputs, op_outputs, op_attrs); + custom_maker(info.proto_, info.checker_); + PADDLE_ENFORCE_EQ( + info.proto_->IsInitialized(), true, + platform::errors::PreconditionNotMet( + "Fail to initialize %s's OpProto, because %s is not initialized.", + op_name, info.proto_->InitializationErrorString())); + + // InferShape + PADDLE_ENFORCE_NOT_NULL( + infer_shape_func, + platform::errors::PreconditionNotMet( + "InferShapeFn is nullptr. Need to set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + info.infer_shape_ = [op_inputs, op_outputs, + infer_shape_func](InferShapeContext* ctx) { + std::vector> input_shapes; + + VLOG(1) << "Custom Operator: InferShape - get input ddim."; + for (auto& in_name : op_inputs) { + OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom"); + auto ddim = ctx->GetInputDim(in_name); + input_shapes.emplace_back(framework::vectorize(ddim)); + } + + VLOG(1) << "Custom Operator: InferShape - calc output ddim."; + auto output_shapes = infer_shape_func(input_shapes); + + VLOG(1) << "Custom Operator: InferShape - set output ddim."; + for (size_t i = 0; i < op_outputs.size(); ++i) { + ctx->SetOutputDim(op_outputs[i], framework::make_ddim(output_shapes[i])); + } + }; + + // Infer Dtype + PADDLE_ENFORCE_NOT_NULL( + infer_dtype_func, + platform::errors::PreconditionNotMet( + "InferDtypeFn is nullptr. Need to set the InferDtypeFn of custom " + "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))")); + info.infer_var_type_ = [op_inputs, op_outputs, + infer_dtype_func](InferVarTypeContext* ctx) { + std::vector input_dtypes; + + VLOG(1) << "Custom Operator: InferDtype - get input dtype."; + for (auto& in_name : op_inputs) { + auto dtype = ctx->GetInputDataType(in_name); + input_dtypes.emplace_back( + CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype)); + } + + VLOG(1) << "Custom Operator: InferDtype - infer output dtype."; + auto output_dtypes = infer_dtype_func(input_dtypes); + + VLOG(1) << "Custom Operator: InferDtype - set output dtype."; + for (size_t i = 0; i < op_outputs.size(); ++i) { + ctx->SetOutputDataType( + op_outputs[i], + CustomTensorUtils::ConvertEnumDTypeToInnerDType(output_dtypes[i])); + } + }; + + // Kernel func + RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs); + + // If grad op or double grad op exists + std::string cur_op_name = op_name; + for (size_t i = 1; i < op_meta_infos.size(); ++i) { + auto& cur_grad_op = op_meta_infos[i]; + + auto& grad_op_name = OpMetaInfoHelper::GetOpName(cur_grad_op); + auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op); + auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op); + auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op); + + VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name; + VLOG(1) << "Custom Operator: backward, op inputs: " + << string::join_strings(grad_op_inputs, ','); + VLOG(1) << "Custom Operator: backward, op outputs: " + << string::join_strings(grad_op_outputs, ','); + + // GradOpDescMaker + info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs]( + const OpDesc& fwd_op, + const std::unordered_set& no_grad_set, + std::unordered_map* grad_to_var, + const std::vector& grad_block) { + CustomGradOpMaker maker( + fwd_op, no_grad_set, grad_to_var, grad_block, grad_op_name, + grad_op_inputs, grad_op_outputs); + return maker(); + }; + + // GradOpBaseMaker + info.dygraph_grad_op_maker_ = [grad_op_name, grad_op_inputs, + grad_op_outputs]( + const std::string& type, + const imperative::NameVarBaseMap& var_base_map_in, + const imperative::NameVarBaseMap& var_base_map_out, + const framework::AttributeMap& attrs, + const std::map& inplace_map) { + CustomGradOpMaker maker( + type, var_base_map_in, var_base_map_out, attrs, inplace_map, + grad_op_name, grad_op_inputs, grad_op_outputs); + return maker(); + }; + + /* Grad op register */ + OpInfo grad_info; + + // Grad Op + grad_info.creator_ = []( + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) { + return new CustomOperator(type, inputs, outputs, attrs); + }; + + // Grad InferShape (gradient's shape is same with forward input default) + grad_info.infer_shape_ = [grad_op_outputs](InferShapeContext* ctx) { + for (auto& out_name : grad_op_outputs) { + ctx->ShareDim(detail::NoGrad(out_name), out_name); + } + }; + + // Kernel func + RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs, + grad_op_outputs); + + // update current info + OpInfoMap::Instance().Insert(cur_op_name, info); + cur_op_name = grad_op_name; + info = grad_info; + } + // insert last info + OpInfoMap::Instance().Insert(cur_op_name, info); +} + +void RegisterOperatorWithMetaInfoMap( + const paddle::OpMetaInfoMap& op_meta_info_map) { + auto& meta_info_map = op_meta_info_map.GetMap(); + + PADDLE_ENFORCE_EQ(meta_info_map.empty(), false, + platform::errors::PreconditionNotMet( + "No custom operator that needs to be registered.")); + VLOG(1) << "Custom Operator: size of op meta info map - " + << meta_info_map.size(); + // pair: {op_type, OpMetaInfo} + for (auto& pair : meta_info_map) { + VLOG(1) << "Custom Operator: pair first -> op name: " << pair.first; + RegisterOperatorWithMetaInfo(pair.second); + } +} + +////////////////////// User APIs /////////////////////// + +// load op api +void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { + void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name); + + typedef OpMetaInfoMap& get_op_meta_info_map_t(); + auto* get_op_meta_info_map = + detail::DynLoad(handle, "PD_GetOpMetaInfoMap"); + auto& op_meta_info_map = get_op_meta_info_map(); + + RegisterOperatorWithMetaInfoMap(op_meta_info_map); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/custom_operator.h b/paddle/fluid/framework/custom_operator.h new file mode 100644 index 00000000000..f2f97e5e582 --- /dev/null +++ b/paddle/fluid/framework/custom_operator.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/extension/include/op_meta_info.h" + +namespace paddle { +namespace framework { + +// Load custom op api: register op after user compiled +void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name); + +// Register custom op api: register op directly +void RegisterOperatorWithMetaInfoMap( + const paddle::OpMetaInfoMap& op_meta_info_map); + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/custom_tensor_test.cc b/paddle/fluid/framework/custom_tensor_test.cc new file mode 100644 index 00000000000..33b66245428 --- /dev/null +++ b/paddle/fluid/framework/custom_tensor_test.cc @@ -0,0 +1,249 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "paddle/fluid/extension/include/all.h" +#include "paddle/fluid/framework/custom_tensor_utils.h" +#include "paddle/fluid/framework/lod_tensor.h" + +template +paddle::Tensor InitCPUTensorForTest() { + std::vector tensor_shape{5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); + t1.reshape(tensor_shape); + auto* p_data_ptr = t1.mutable_data(paddle::PlaceType::kCPU); + for (int64_t i = 0; i < t1.size(); i++) { + p_data_ptr[i] = T(5); + } + return t1; +} + +template +void TestCopyTensor() { + auto t1 = InitCPUTensorForTest(); + auto t1_cpu_cp = t1.template copy_to(paddle::PlaceType::kCPU); + CHECK((paddle::PlaceType::kCPU == t1_cpu_cp.place())); + for (int64_t i = 0; i < t1.size(); i++) { + CHECK_EQ(t1_cpu_cp.template data()[i], T(5)); + } +#ifdef PADDLE_WITH_CUDA + VLOG(2) << "Do GPU copy test"; + auto t1_gpu_cp = t1_cpu_cp.template copy_to(paddle::PlaceType::kGPU); + CHECK((paddle::PlaceType::kGPU == t1_gpu_cp.place())); + auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to(paddle::PlaceType::kGPU); + CHECK((paddle::PlaceType::kGPU == t1_gpu_cp_cp.place())); + auto t1_gpu_cp_cp_cpu = + t1_gpu_cp.template copy_to(paddle::PlaceType::kCPU); + CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place())); + for (int64_t i = 0; i < t1.size(); i++) { + CHECK_EQ(t1_gpu_cp_cp_cpu.template data()[i], T(5)); + } +#endif +} + +void TestAPIPlace() { + std::vector tensor_shape = {5, 5}; +#ifdef PADDLE_WITH_CUDA + auto t1 = paddle::Tensor(paddle::PlaceType::kGPU); + t1.reshape(tensor_shape); + t1.mutable_data(); + CHECK((paddle::PlaceType::kGPU == t1.place())); +#endif + auto t2 = paddle::Tensor(paddle::PlaceType::kCPU); + t2.reshape(tensor_shape); + t2.mutable_data(); + CHECK((paddle::PlaceType::kCPU == t2.place())); +} + +void TestAPISizeAndShape() { + std::vector tensor_shape = {5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); + t1.reshape(tensor_shape); + CHECK_EQ(t1.size(), 25); + CHECK(t1.shape() == tensor_shape); +} + +template +paddle::DataType TestDtype() { + std::vector tensor_shape = {5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); + t1.reshape(tensor_shape); + t1.template mutable_data(); + return t1.type(); +} + +template +void TestCast(paddle::DataType data_type) { + std::vector tensor_shape = {5, 5}; + auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); + t1.reshape(tensor_shape); + t1.template mutable_data(); + auto t2 = t1.cast(data_type); + CHECK_EQ(t2.type(), data_type); +} + +void GroupTestCopy() { + VLOG(2) << "Float cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "int cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "int16 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "int8 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); + VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu"; + TestCopyTensor(); +} + +void GroupTestCast() { + VLOG(2) << "int cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "int32 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "int64 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "double cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "bfloat16 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "float16 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "bool cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "uint8 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "float cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "complex64 cast"; + TestCast(paddle::DataType::FLOAT32); + VLOG(2) << "complex128 cast"; + TestCast(paddle::DataType::FLOAT32); +} + +void GroupTestDtype() { + CHECK(TestDtype() == paddle::DataType::FLOAT32); + CHECK(TestDtype() == paddle::DataType::FLOAT64); + CHECK(TestDtype() == paddle::DataType::FLOAT16); + CHECK(TestDtype() == paddle::DataType::BFLOAT16); + CHECK(TestDtype() == + paddle::DataType::COMPLEX128); + CHECK(TestDtype() == + paddle::DataType::COMPLEX64); + CHECK(TestDtype() == paddle::DataType::INT32); + CHECK(TestDtype() == paddle::DataType::INT64); + CHECK(TestDtype() == paddle::DataType::INT16); + CHECK(TestDtype() == paddle::DataType::INT8); + CHECK(TestDtype() == paddle::DataType::UINT8); +} + +void GroupTestDtypeConvert() { + // enum -> proto + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::COMPLEX128) == + paddle::framework::proto::VarType::COMPLEX128); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::COMPLEX64) == + paddle::framework::proto::VarType::COMPLEX64); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::FLOAT64) == + paddle::framework::proto::VarType::FP64); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::FLOAT32) == + paddle::framework::proto::VarType::FP32); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::FLOAT16) == + paddle::framework::proto::VarType::FP16); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::BFLOAT16) == + paddle::framework::proto::VarType::BF16); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::UINT8) == + paddle::framework::proto::VarType::UINT8); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::INT8) == paddle::framework::proto::VarType::INT8); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::INT32) == + paddle::framework::proto::VarType::INT32); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::INT64) == + paddle::framework::proto::VarType::INT64); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::INT16) == + paddle::framework::proto::VarType::INT16); + CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType( + paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL); + // proto -> enum + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::COMPLEX128) == + paddle::DataType::COMPLEX128); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::COMPLEX64) == + paddle::DataType::COMPLEX64); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::FP64) == + paddle::DataType::FLOAT64); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::FP32) == + paddle::DataType::FLOAT32); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::FP16) == + paddle::DataType::FLOAT16); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::BF16) == + paddle::DataType::BFLOAT16); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::INT64) == + paddle::DataType::INT64); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::INT32) == + paddle::DataType::INT32); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::INT8) == paddle::DataType::INT8); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::UINT8) == + paddle::DataType::UINT8); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::INT16) == + paddle::DataType::INT16); + CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType( + paddle::framework::proto::VarType::BOOL) == paddle::DataType::BOOL); +} + +TEST(CustomTensor, copyTest) { + VLOG(2) << "TestCopy"; + GroupTestCopy(); + VLOG(2) << "TestDtype"; + GroupTestDtype(); + VLOG(2) << "TestShape"; + TestAPISizeAndShape(); + VLOG(2) << "TestPlace"; + TestAPIPlace(); + VLOG(2) << "TestCast"; + GroupTestCast(); + VLOG(2) << "TestDtypeConvert"; + GroupTestDtypeConvert(); +} diff --git a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h new file mode 100644 index 00000000000..4b465d3911d --- /dev/null +++ b/paddle/fluid/framework/custom_tensor_utils.h @@ -0,0 +1,145 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/extension/include/tensor.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { + +class CustomTensorUtils { + public: + /// \brief Share data TO another tensor. + /// Use this to pass tensor from op to op + /// \return void. + static void ShareDataTo(const paddle::Tensor& src, void* dst); + + /// \brief Share data FROM another tensor. + /// Use this to pass tensor from op to op + /// \return void. + static void ShareDataFrom(const void* src, const Tensor& dst); + + static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType( + const paddle::DataType& dtype) { + switch (dtype) { + case paddle::DataType::COMPLEX128: + return framework::proto::VarType::COMPLEX128; + case paddle::DataType::COMPLEX64: + return framework::proto::VarType::COMPLEX64; + case paddle::DataType::FLOAT64: + return framework::proto::VarType::FP64; + case paddle::DataType::FLOAT32: + return framework::proto::VarType::FP32; + case paddle::DataType::FLOAT16: + return framework::proto::VarType::FP16; + case paddle::DataType::BFLOAT16: + return framework::proto::VarType::BF16; + case paddle::DataType::UINT8: + return framework::proto::VarType::UINT8; + case paddle::DataType::INT8: + return framework::proto::VarType::INT8; + case paddle::DataType::INT32: + return framework::proto::VarType::INT32; + case paddle::DataType::INT64: + return framework::proto::VarType::INT64; + case paddle::DataType::INT16: + return framework::proto::VarType::INT16; + case paddle::DataType::BOOL: + return framework::proto::VarType::BOOL; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported data type code(%d) when casting enum data type into " + "paddle data type.", + static_cast(dtype))); + } + } + + static paddle::DataType ConvertInnerDTypeToEnumDType( + const framework::proto::VarType::Type& dtype) { + switch (dtype) { + case framework::proto::VarType::COMPLEX128: + return paddle::DataType::COMPLEX128; + case framework::proto::VarType::COMPLEX64: + return paddle::DataType::COMPLEX64; + case framework::proto::VarType::FP64: + return paddle::DataType::FLOAT64; + case framework::proto::VarType::FP32: + return paddle::DataType::FLOAT32; + case framework::proto::VarType::FP16: + return paddle::DataType::FLOAT16; + case framework::proto::VarType::BF16: + return paddle::DataType::BFLOAT16; + case framework::proto::VarType::INT64: + return paddle::DataType::INT64; + case framework::proto::VarType::INT32: + return paddle::DataType::INT32; + case framework::proto::VarType::INT8: + return paddle::DataType::INT8; + case framework::proto::VarType::UINT8: + return paddle::DataType::UINT8; + case framework::proto::VarType::INT16: + return paddle::DataType::INT16; + case framework::proto::VarType::BOOL: + return paddle::DataType::BOOL; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported data type `%s` when casting paddle data type into " + "enum data type.", + DataTypeToString(dtype))); + } + } + + // PaddlePlace <-> platform::Place + static platform::Place ConvertEnumPlaceToInnerPlace(const PlaceType& pc) { + if (pc == PlaceType::kCPU) { + return platform::Place(platform::CPUPlace()); + } else if (pc == PlaceType::kGPU) { +#ifdef PADDLE_WITH_CUDA + return platform::Place( + platform::CUDAPlace(platform::GetCurrentDeviceId())); +#endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported place type code(%d) when " + "casting enum place to paddle place.", + static_cast(pc))); + } + return platform::Place(); + } + + static PlaceType ConvertInnerPlaceToEnumPlace(const platform::Place& pc) { + if (platform::is_cpu_place(pc)) { + return PlaceType::kCPU; + } else if (platform::is_gpu_place(pc)) { +#ifdef PADDLE_WITH_CUDA + return PlaceType::kGPU; +#endif + } else { + PADDLE_THROW( + platform::errors::Unimplemented("Unsupported place type `%s` when " + "casting paddle place to enum place.", + pc)); + } + return PlaceType::kUNK; + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc index 0959a060515..05d3541d6a9 100644 --- a/paddle/fluid/framework/data_type.cc +++ b/paddle/fluid/framework/data_type.cc @@ -84,6 +84,10 @@ std::string DataTypeToString(const proto::VarType::Type type) { if (it != gDataTypeMap().proto_to_str_.end()) { return it->second; } + // deal with RAW type + if (type == proto::VarType::RAW) { + return "RAW(runtime decided type)"; + } PADDLE_THROW(platform::errors::Unimplemented( "Not support proto::VarType::Type(%d) as tensor type.", static_cast(type))); diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 30a2ac2c6f6..084c6e6816b 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -97,10 +97,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var, framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; case proto::VarType::INT16: - framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); + framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; case proto::VarType::UINT8: - framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); + framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; default: PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/fluid/framework/op_meta_info_helper.h b/paddle/fluid/framework/op_meta_info_helper.h new file mode 100644 index 00000000000..06d9c94172d --- /dev/null +++ b/paddle/fluid/framework/op_meta_info_helper.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "paddle/fluid/extension/include/op_meta_info.h" + +namespace paddle { +namespace framework { + +class OpMetaInfoHelper { + public: + static const std::string& GetOpName(const paddle::OpMetaInfo& info) { + return info.name_; + } + static const std::vector& GetInputs( + const paddle::OpMetaInfo& info) { + return info.inputs_; + } + static const std::vector& GetOutputs( + const paddle::OpMetaInfo& info) { + return info.outputs_; + } + static const std::vector& GetAttrs( + const paddle::OpMetaInfo& info) { + return info.attrs_; + } + static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) { + return info.kernel_fn_; + } + static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info) { + return info.infer_shape_fn_; + } + static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info) { + return info.infer_dtype_fn_; + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 2615b98d30d..792e181047f 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context - gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper) + gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator) if (WITH_NCCL) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 72b3c9645ba..0ebf77c4c29 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/feed_fetch_method.h" @@ -386,7 +387,7 @@ PYBIND11_MODULE(core_noavx, m) { PyCapsule_GetPointer(dltensor->ptr(), "dltensor")); PyCapsule_SetName(dltensor->ptr(), "used_dltensor"); DLTensor dl = dmt->dl_tensor; - Tensor tensor; + framework::Tensor tensor; if (dl.ctx.device_type == kDLCPU) { paddle::framework::TensorFromDLPack(dl, &tensor); @@ -524,77 +525,80 @@ PYBIND11_MODULE(core_noavx, m) { BindImperative(&m); - py::class_(m, "Tensor", py::buffer_protocol()) - .def("__array__", [](Tensor &self) { return TensorToPyArray(self); }) + py::class_(m, "Tensor", py::buffer_protocol()) + .def("__array__", + [](framework::Tensor &self) { return TensorToPyArray(self); }) .def("_is_initialized", - [](const Tensor &self) { return self.IsInitialized(); }) + [](const framework::Tensor &self) { return self.IsInitialized(); }) .def("_get_dims", - [](const Tensor &self) { return vectorize(self.dims()); }) + [](const framework::Tensor &self) { return vectorize(self.dims()); }) .def("_set_dims", - [](Tensor &self, const std::vector &dim) { + [](framework::Tensor &self, const std::vector &dim) { self.Resize(make_ddim(dim)); }) .def("_set_layout", - [](Tensor &self, const std::string &layout) { + [](framework::Tensor &self, const std::string &layout) { self.set_layout(StringToDataLayout(layout)); }) .def("_alloc_float", - [](Tensor &self, paddle::platform::CUDAPlace &place) { + [](framework::Tensor &self, paddle::platform::CUDAPlace &place) { self.mutable_data(place); }) .def("_alloc_float", - [](Tensor &self, paddle::platform::XPUPlace &place) { + [](framework::Tensor &self, paddle::platform::XPUPlace &place) { self.mutable_data(place); }) .def("_alloc_float", - [](Tensor &self, paddle::platform::CPUPlace &place) { + [](framework::Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); }) .def("_alloc_double", - [](Tensor &self, paddle::platform::CPUPlace &place) { + [](framework::Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); }) .def("_alloc_int", - [](Tensor &self, paddle::platform::CPUPlace &place) { + [](framework::Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); }) .def("_alloc_int", - [](Tensor &self, paddle::platform::XPUPlace &place) { + [](framework::Tensor &self, paddle::platform::XPUPlace &place) { self.mutable_data(place); }) .def("_alloc_int", - [](Tensor &self, paddle::platform::CUDAPlace &place) { + [](framework::Tensor &self, paddle::platform::CUDAPlace &place) { self.mutable_data(place); }) .def("_alloc_int", - [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) { + [](framework::Tensor &self, + paddle::platform::CUDAPinnedPlace &place) { self.mutable_data(place); }) .def("_alloc_float", - [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) { + [](framework::Tensor &self, + paddle::platform::CUDAPinnedPlace &place) { self.mutable_data(place); }) .def("_mutable_data", - [](Tensor &self, paddle::platform::CPUPlace &place, + [](framework::Tensor &self, paddle::platform::CPUPlace &place, paddle::framework::proto::VarType::Type type) { return reinterpret_cast(self.mutable_data(place, type)); }) .def("_mutable_data", - [](Tensor &self, paddle::platform::XPUPlace &place, + [](framework::Tensor &self, paddle::platform::XPUPlace &place, paddle::framework::proto::VarType::Type type) { return reinterpret_cast(self.mutable_data(place, type)); }) .def("_mutable_data", - [](Tensor &self, paddle::platform::CUDAPlace &place, + [](framework::Tensor &self, paddle::platform::CUDAPlace &place, paddle::framework::proto::VarType::Type type) { return reinterpret_cast(self.mutable_data(place, type)); }) .def("_mutable_data", - [](Tensor &self, paddle::platform::CUDAPinnedPlace &place, + [](framework::Tensor &self, paddle::platform::CUDAPinnedPlace &place, paddle::framework::proto::VarType::Type type) { return reinterpret_cast(self.mutable_data(place, type)); }) - .def("_clear", &Tensor::clear) + .def("_clear", &framework::Tensor::clear) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, @@ -626,7 +630,9 @@ PYBIND11_MODULE(core_noavx, m) { t.set(np.ndarray([5, 30]), fluid.CPUPlace()) )DOC") - .def("shape", [](Tensor &self) { return vectorize(self.dims()); }, R"DOC( + .def("shape", + [](framework::Tensor &self) { return vectorize(self.dims()); }, + R"DOC( Return the shape of LoDTensor. Returns: @@ -644,7 +650,7 @@ PYBIND11_MODULE(core_noavx, m) { print(t.shape()) # [5, 30] )DOC") .def("_to_dlpack", - [](Tensor &self) { + [](framework::Tensor &self) { DLPackTensor dlpack_tensor(self, 1); DLManagedTensor *dmt = dlpack_tensor.ToCudfCompatibleDLManagedTensor(); @@ -669,20 +675,22 @@ PYBIND11_MODULE(core_noavx, m) { .def("_get_float_element", TensorGetElement) .def("_set_double_element", TensorSetElement) .def("_get_double_element", TensorGetElement) - .def("_place", [](Tensor &self) { return self.place(); }) - .def("_dtype", [](Tensor &self) { return self.type(); }) + .def("_place", [](framework::Tensor &self) { return self.place(); }) + .def("_dtype", [](framework::Tensor &self) { return self.type(); }) .def("_layout", - [](Tensor &self) { return DataLayoutToString(self.layout()); }) - .def("_share_data_with", &Tensor::ShareDataWith) + [](framework::Tensor &self) { + return DataLayoutToString(self.layout()); + }) + .def("_share_data_with", &framework::Tensor::ShareDataWith) .def("__getitem__", PySliceTensor, py::return_value_policy::reference) - .def("__str__", [](const Tensor &self) { + .def("__str__", [](const framework::Tensor &self) { std::stringstream ostr; ostr << self; return ostr.str(); }); // TODO(cql): add reference: en_user_guide_lod_tensor - py::class_(m, "LoDTensor", R"DOC( + py::class_(m, "LoDTensor", R"DOC( LoDTensor is a Tensor with optional LoD (Level of Details) information, it can be used for variable-length sequences, see :ref:`user_guide_lod_tensor` for details. @@ -766,7 +774,8 @@ PYBIND11_MODULE(core_noavx, m) { t = fluid.LoDTensor() )DOC") - .def("__array__", [](Tensor &self) { return TensorToPyArray(self); }) + .def("__array__", + [](framework::Tensor &self) { return TensorToPyArray(self); }) .def("__init__", [](LoDTensor &instance, const std::vector> &recursive_sequence_lengths) { @@ -1724,6 +1733,8 @@ All parameter, weight, gradient are variables in Paddle. m.def("init_gflags", framework::InitGflags); m.def("init_glog", framework::InitGLOG); m.def("load_op_library", framework::LoadOpLib); + m.def("load_op_meta_info_and_register_op", + framework::LoadOpMetaInfoAndRegisterOp); m.def("init_devices", []() { framework::InitDevices(); }); m.def("is_compiled_with_cuda", IsCompiledWithCUDA); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b8de693b1fe..e59ee7186e9 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1988,9 +1988,13 @@ class OpProtoHolder(object): def update_op_proto(self): op_protos = get_all_op_protos() + custom_op_names = [] for proto in op_protos: if proto.type not in self.op_proto_map: self.op_proto_map[proto.type] = proto + custom_op_names.append(proto.type) + + return custom_op_names @staticmethod def generated_op_attr_names(): @@ -5699,6 +5703,9 @@ def load_op_library(lib_filename): Args: lib_filename (str): name of dynamic library. + + Returns: + list[str]: new registered custom op names. Examples: .. code-block:: python @@ -5708,7 +5715,7 @@ def load_op_library(lib_filename): """ core.load_op_library(lib_filename) - OpProtoHolder.instance().update_op_proto() + return OpProtoHolder.instance().update_op_proto() def switch_device(device): diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index bb74c37c043..df1dc75a38c 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -20,9 +20,15 @@ set_property(TARGET relu_op_shared PROPERTY LINK_LIBRARIES ${TARGET_LIBRARIES} file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -# for coverage -LIST(REMOVE_ITEM TEST_OPS test_custom_op) - foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() + +# Compiling .so will cost some time, but running process is very fast. +set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180) +set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180) +set_tests_properties(test_setup_build PROPERTIES TIMEOUT 180) +set_tests_properties(test_dispatch PROPERTIES TIMEOUT 180) + +set_tests_properties(test_simple_custom_op_setup PROPERTIES TIMEOUT 250) +set_tests_properties(test_simple_custom_op_jit PROPERTIES TIMEOUT 180) diff --git a/python/paddle/fluid/tests/custom_op/__init__.py b/python/paddle/fluid/tests/custom_op/__init__.py new file mode 100644 index 00000000000..6f0ea85344b --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc b/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc new file mode 100644 index 00000000000..e09ac2f87c8 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/dispatch_test_op.cc @@ -0,0 +1,138 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/extension.h" + +template +void assign_cpu_kernel(const data_t* x_data, + data_t* out_data, + int64_t x_numel) { + for (int i = 0; i < x_numel; ++i) { + out_data[i] = x_data[i]; + } +} + +std::vector> InferShape(std::vector x_shape) { + return {x_shape}; +} + +std::vector InferDType(paddle::DataType x_dtype) { + return {x_dtype}; +} + +std::vector DispatchTestInterger(const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_INTEGRAL_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel( + x.data(), out.mutable_data(), x.size()); + })); + + return {out}; +} + +PD_BUILD_OP("dispatch_test_integer") + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(DispatchTestInterger)) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDType)); + +std::vector DispatchTestComplex(const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_COMPLEX_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel( + x.data(), out.mutable_data(), x.size()); + })); + + return {out}; +} + +PD_BUILD_OP("dispatch_test_complex") + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(DispatchTestComplex)) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDType)); + +std::vector DispatchTestFloatAndInteger( + const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel( + x.data(), out.mutable_data(), x.size()); + })); + + return {out}; +} + +PD_BUILD_OP("dispatch_test_float_and_integer") + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger)) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDType)); + +std::vector DispatchTestFloatAndComplex( + const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel( + x.data(), out.mutable_data(), x.size()); + })); + + return {out}; +} + +PD_BUILD_OP("dispatch_test_float_and_complex") + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex)) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDType)); + +std::vector DispatchTestFloatAndIntegerAndComplex( + const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel( + x.data(), out.mutable_data(), x.size()); + })); + + return {out}; +} + +PD_BUILD_OP("dispatch_test_float_and_integer_and_complex") + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex)) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDType)); diff --git a/python/paddle/fluid/tests/custom_op/relu_op3.cc b/python/paddle/fluid/tests/custom_op/relu_op3.cc new file mode 100644 index 00000000000..ace9598c586 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/relu_op3.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class Relu3Op : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto in_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Y", in_dims); + } +}; + +class Relu3OpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddOutput("Y", "Output of relu_op"); + AddComment(R"DOC( +Relu3 Operator. +)DOC"); + } +}; + +class Relu3GradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto in_dims = ctx->GetInputDim(framework::GradVarName("Y")); + ctx->SetOutputDim(framework::GradVarName("X"), in_dims); + } +}; + +template +class Relu3GradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType("relu3_grad"); + op->SetInput("Y", this->Output("Y")); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +using Tensor = framework::Tensor; + +template +class Relu3Kernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_t = ctx.Input("X"); + auto* out_t = ctx.Output("Y"); + auto x = in_t->data(); + auto y = out_t->mutable_data(ctx.GetPlace()); + for (int i = 0; i < in_t->numel(); ++i) { + y[i] = std::max(static_cast(0.), x[i]); + } + } +}; + +template +class Relu3GradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dy_t = ctx.Input(framework::GradVarName("Y")); + auto* y_t = ctx.Input("Y"); + auto* dx_t = ctx.Output(framework::GradVarName("X")); + + auto dy = dy_t->data(); + auto y = y_t->data(); + auto dx = dx_t->mutable_data(ctx.GetPlace()); + + for (int i = 0; i < y_t->numel(); ++i) { + dx[i] = dy[i] * (y[i] > static_cast(0) ? 1. : 0.); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; +REGISTER_OPERATOR(relu3, + ops::Relu3Op, + ops::Relu3OpMaker, + ops::Relu3GradMaker, + ops::Relu3GradMaker); +REGISTER_OPERATOR(relu3_grad, ops::Relu3GradOp); +REGISTER_OP_CPU_KERNEL(relu3, + ops::Relu3Kernel, + ops::Relu3Kernel); +REGISTER_OP_CPU_KERNEL(relu3_grad, + ops::Relu3GradKernel, + ops::Relu3GradKernel); diff --git a/python/paddle/fluid/tests/custom_op/relu_op3.cu b/python/paddle/fluid/tests/custom_op/relu_op3.cu new file mode 100644 index 00000000000..8a229cafebb --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/relu_op3.cu @@ -0,0 +1,87 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__global__ void KeRelu3(const T* x, const int num, T* y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = max(x[i], static_cast(0.)); + } +} + +template +class Relu3CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_t = ctx.Input("X"); + auto* out_t = ctx.Output("Y"); + auto x = in_t->data(); + auto y = out_t->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + int num = in_t->numel(); + int block = 512; + int grid = (num + block - 1) / block; + KeRelu3<<>>(x, num, y); + } +}; + +template +__global__ void KeRelu3Grad(const T* y, const T* dy, const int num, T* dx) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.); + } +} + +template +class Relu3GradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dy_t = ctx.Input(framework::GradVarName("Y")); + auto* y_t = ctx.Input("Y"); + auto* dx_t = ctx.Output(framework::GradVarName("X")); + + auto dy = dy_t->data(); + auto y = y_t->data(); + auto dx = dx_t->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + int num = dy_t->numel(); + int block = 512; + int grid = (num + block - 1) / block; + KeRelu3Grad<<>>(y, dy, num, dx); + } +}; + +} // namespace operators +} // namespace paddle + +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(relu3, + paddle::operators::Relu3CUDAKernel, + paddle::operators::Relu3CUDAKernel); + +REGISTER_OP_CUDA_KERNEL(relu3_grad, + paddle::operators::Relu3GradCUDAKernel, + paddle::operators::Relu3GradCUDAKernel); diff --git a/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc b/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc new file mode 100644 index 00000000000..ec64bce1873 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +std::vector relu_cuda_forward(const paddle::Tensor& x); +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector ReluForward(const paddle::Tensor& x); + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector> ReluInferShape(std::vector x_shape); + +std::vector ReluInferDType(paddle::DataType x_dtype); + +// Reuse codes in `relu_op_simple.cc/cu` to register another custom operator +// to test jointly compile multi operators at same time. +PD_BUILD_OP("relu3") + .Inputs({"X"}) + .Outputs({"Out", "Fake_float64", "ZFake_int32"}) + .SetKernelFn(PD_KERNEL(ReluForward)) + .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType)) + .SetBackwardOp("relu3_grad") + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)); diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc b/python/paddle/fluid/tests/custom_op/relu_op_simple.cc new file mode 100644 index 00000000000..b02ecba6826 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/relu_op_simple.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/extension.h" + +template +void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) { + for (int i = 0; i < x_numel; ++i) { + out_data[i] = value; + } +} + +template +void relu_cpu_forward_kernel(const data_t* x_data, + data_t* out_data, + int64_t x_numel) { + for (int i = 0; i < x_numel; ++i) { + out_data[i] = std::max(static_cast(0.), x_data[i]); + } +} + +template +void relu_cpu_backward_kernel(const data_t* grad_out_data, + const data_t* out_data, + data_t* grad_x_data, + int64_t out_numel) { + for (int i = 0; i < out_numel; ++i) { + grad_x_data[i] = + grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); + } +} + +std::vector relu_cpu_forward(const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_FLOATING_TYPES( + x.type(), "relu_cpu_forward", ([&] { + relu_cpu_forward_kernel( + x.data(), out.mutable_data(x.place()), x.size()); + })); + // fake multi output: Fake_float64 with float64 dtype + auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU); + fake_float64.reshape(x.shape()); + + fill_constant_cpu_kernel( + fake_float64.mutable_data(x.place()), x.size(), 0.); + + // fake multi output: ZFake_int32 with int32 dtype + auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU); + zfake_int32.reshape(x.shape()); + + fill_constant_cpu_kernel( + zfake_int32.mutable_data(x.place()), x.size(), 1); + + return {out, fake_float64, zfake_int32}; +} + +std::vector relu_cpu_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU); + grad_x.reshape(x.shape()); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { + relu_cpu_backward_kernel( + grad_out.data(), + out.data(), + grad_x.mutable_data(x.place()), + out.size()); + })); + + return {grad_x}; +} + +std::vector relu_cuda_forward(const paddle::Tensor& x); +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector ReluForward(const paddle::Tensor& x) { + // TODO(chenweihang): Check Input + if (x.place() == paddle::PlaceType::kCPU) { + return relu_cpu_forward(x); + } else if (x.place() == paddle::PlaceType::kGPU) { + return relu_cuda_forward(x); + } else { + throw std::runtime_error("Not implemented."); + } +} + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + // TODO(chenweihang): Check Input + if (x.place() == paddle::PlaceType::kCPU) { + return relu_cpu_backward(x, out, grad_out); + } else if (x.place() == paddle::PlaceType::kGPU) { + return relu_cuda_backward(x, out, grad_out); + } else { + throw std::runtime_error("Not implemented."); + } +} + +std::vector> ReluInferShape(std::vector x_shape) { + return {x_shape, x_shape, x_shape}; +} + +std::vector ReluInferDType(paddle::DataType x_dtype) { + return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32}; +} + +PD_BUILD_OP("relu2") + .Inputs({"X"}) + .Outputs({"Out", "Fake_float64", "ZFake_int32"}) + .SetKernelFn(PD_KERNEL(ReluForward)) + .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType)) + .SetBackwardOp("relu2_grad") + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)); diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu b/python/paddle/fluid/tests/custom_op/relu_op_simple.cu new file mode 100644 index 00000000000..2ef6a5c1451 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/relu_op_simple.cu @@ -0,0 +1,93 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +template +__global__ void fill_constant_cuda_kernel(data_t* y, + const int num, + data_t value) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = value; + } +} + +template +__global__ void relu_cuda_forward_kernel(const data_t* x, + data_t* y, + const int num) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = max(x[i], static_cast(0.)); + } +} + +template +__global__ void relu_cuda_backward_kernel(const data_t* dy, + const data_t* y, + data_t* dx, + const int num) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.); + } +} + +std::vector relu_cuda_forward(const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kGPU); + out.reshape(x.shape()); + + int numel = x.size(); + int block = 512; + int grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_TYPES( + x.type(), "relu_cuda_forward_kernel", ([&] { + relu_cuda_forward_kernel<<>>( + x.data(), out.mutable_data(x.place()), numel); + })); + // fake multi output: Fake_1 + auto fake_float64 = paddle::Tensor(paddle::PlaceType::kGPU); + fake_float64.reshape(x.shape()); + fill_constant_cuda_kernel<<>>( + fake_float64.mutable_data(x.place()), numel, 0.); + // fake multi output: ZFake_1 + auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kGPU); + zfake_int32.reshape(x.shape()); + fill_constant_cuda_kernel<<>>( + zfake_int32.mutable_data(x.place()), numel, 1); + + return {out, fake_float64, zfake_int32}; +} + +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU); + grad_x.reshape(x.shape()); + + int numel = out.size(); + int block = 512; + int grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_TYPES( + out.type(), "relu_cuda_backward_kernel", ([&] { + relu_cuda_backward_kernel<<>>( + grad_out.data(), + out.data(), + grad_x.mutable_data(x.place()), + numel); + })); + + return {grad_x}; +} diff --git a/python/paddle/fluid/tests/custom_op/setup_build.py b/python/paddle/fluid/tests/custom_op/setup_build.py new file mode 100644 index 00000000000..16a74779307 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/setup_build.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from utils import paddle_includes, extra_compile_args +from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup +from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method + +# switch to old custom op method +use_new_custom_op_load_method(False) + +file_dir = os.path.dirname(os.path.abspath(__file__)) + +setup( + name='librelu2_op_from_setup', + ext_modules=[ + CUDAExtension( + sources=['relu_op3.cc', 'relu_op3.cu', 'relu_op.cc', + 'relu_op.cu'], # test for multi ops + include_dirs=paddle_includes, + extra_compile_args=extra_compile_args) + ], + cmdclass={ + 'build_ext': BuildExtension.with_options( + no_python_abi_suffix=True, output_dir=file_dir) # for unittest + }) diff --git a/python/paddle/fluid/tests/custom_op/setup_install.py b/python/paddle/fluid/tests/custom_op/setup_install.py new file mode 100644 index 00000000000..18fbfbaf8b6 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/setup_install.py @@ -0,0 +1,29 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from utils import paddle_includes, extra_compile_args +from paddle.utils.cpp_extension import CUDAExtension, setup +from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method + +# switch to old custom op method +use_new_custom_op_load_method(False) + +setup( + name='custom_relu2', + ext_modules=CUDAExtension( # test for not specific name here. + sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', + 'relu_op3.cu'], # test for multi ops + include_dirs=paddle_includes, + extra_compile_args=extra_compile_args)) diff --git a/python/paddle/fluid/tests/custom_op/setup_install_simple.py b/python/paddle/fluid/tests/custom_op/setup_install_simple.py new file mode 100644 index 00000000000..ed236ccbd4c --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/setup_install_simple.py @@ -0,0 +1,27 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from utils import paddle_includes, extra_compile_args +from paddle.utils.cpp_extension import CUDAExtension, setup + +setup( + name='simple_setup_relu2', + ext_modules=CUDAExtension( # test for not specific name here. + sources=[ + 'relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc' + ], # test for multi ops + include_dirs=paddle_includes, + extra_compile_args=extra_compile_args)) diff --git a/python/paddle/fluid/tests/custom_op/test_check_abi.py b/python/paddle/fluid/tests/custom_op/test_check_abi.py new file mode 100644 index 00000000000..b171fca2076 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_check_abi.py @@ -0,0 +1,135 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os +import warnings + +import paddle.utils.cpp_extension.extension_utils as utils + + +class TestABIBase(unittest.TestCase): + def test_environ(self): + compiler = 'gcc' + for flag in ['1', 'True', 'true']: + os.environ['PADDLE_SKIP_CHECK_ABI'] = flag + self.assertTrue(utils.check_abi_compatibility(compiler)) + + def del_environ(self): + key = 'PADDLE_SKIP_CHECK_ABI' + if key in os.environ: + del os.environ[key] + + +class TestCheckLinux(TestABIBase): + def test_expected_compiler(self): + if utils.OS_NAME.startswith('linux'): + gt = ['gcc', 'g++', 'gnu-c++', 'gnu-cc'] + self.assertListEqual(utils._expected_compiler_current_platform(), + gt) + + def test_gcc_version(self): + # clear environ + self.del_environ() + compiler = 'g++' + if utils.OS_NAME.startswith('linux'): + # all CI gcc version > 5.4.0 + self.assertTrue( + utils.check_abi_compatibility( + compiler, verbose=True)) + + def test_wrong_compiler_warning(self): + # clear environ + self.del_environ() + compiler = 'nvcc' # fake wrong compiler + if utils.OS_NAME.startswith('linux'): + with warnings.catch_warnings(record=True) as error: + flag = utils.check_abi_compatibility(compiler, verbose=True) + # check return False + self.assertFalse(flag) + # check Compiler Compatibility WARNING + self.assertTrue(len(error) == 1) + self.assertTrue( + "Compiler Compatibility WARNING" in str(error[0].message)) + + def test_exception(self): + # clear environ + self.del_environ() + compiler = 'python' # fake command + if utils.OS_NAME.startswith('linux'): + # to skip _expected_compiler_current_platform + def fake(): + return [compiler] + + # mock a fake function + raw_func = utils._expected_compiler_current_platform + utils._expected_compiler_current_platform = fake + with warnings.catch_warnings(record=True) as error: + flag = utils.check_abi_compatibility(compiler, verbose=True) + # check return False + self.assertFalse(flag) + # check ABI Compatibility WARNING + self.assertTrue(len(error) == 1) + self.assertTrue("Failed to check compiler version for" in + str(error[0].message)) + + # restore + utils._expected_compiler_current_platform = raw_func + + +class TestCheckMacOs(TestABIBase): + def test_expected_compiler(self): + if utils.OS_NAME.startswith('darwin'): + gt = ['clang', 'clang++'] + self.assertListEqual(utils._expected_compiler_current_platform(), + gt) + + def test_gcc_version(self): + # clear environ + self.del_environ() + + if utils.OS_NAME.startswith('darwin'): + # clang has no version limitation. + self.assertTrue(utils.check_abi_compatibility()) + + +class TestCheckWindows(TestABIBase): + def test_gcc_version(self): + # clear environ + self.del_environ() + + if utils.IS_WINDOWS: + # we skip windows now + self.assertTrue(utils.check_abi_compatibility()) + + +class TestJITCompilerException(unittest.TestCase): + def test_exception(self): + with self.assertRaisesRegexp(RuntimeError, + "Failed to check Python interpreter"): + file_path = os.path.abspath(__file__) + utils._jit_compile(file_path, interpreter='fake_cmd', verbose=True) + + +class TestRunCMDException(unittest.TestCase): + def test_exception(self): + for verbose in [True, False]: + with self.assertRaisesRegexp(RuntimeError, "Failed to run command"): + cmd = "fake cmd" + utils.run_cmd(cmd, verbose) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_custom_op.py b/python/paddle/fluid/tests/custom_op/test_custom_op.py index c9f7d0b7c96..1c0db0be154 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_op.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_op.py @@ -20,11 +20,16 @@ import contextlib import paddle import paddle.fluid as fluid - paddle.enable_static() -file_dir = os.path.dirname(os.path.abspath(__file__)) -fluid.load_op_library(os.path.join(file_dir, 'librelu2_op.so')) + +def load_so(so_name): + """ + Load .so file and parse custom op into OpInfoMap. + """ + file_dir = os.path.dirname(os.path.abspath(__file__)) + fluid.load_op_library(os.path.join(file_dir, so_name)) + from paddle.fluid.layer_helper import LayerHelper @@ -111,4 +116,5 @@ class CustomOpTest(unittest.TestCase): if __name__ == '__main__': + load_so(so_name='librelu2_op.so') unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_dispatch.py b/python/paddle/fluid/tests/custom_op/test_dispatch.py new file mode 100644 index 00000000000..1766a6042f3 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_dispatch.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import paddle +import numpy as np +from paddle.utils.cpp_extension import load +from utils import paddle_includes, extra_compile_args + +dispatch_op = load( + name='dispatch_op', + sources=['dispatch_test_op.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cflags=extra_compile_args) # add for Coverage CI + + +class TestJitDispatch(unittest.TestCase): + def setUp(self): + paddle.set_device('cpu') + + def run_dispatch_test(self, func, dtype): + np_x = np.ones([2, 2]).astype(dtype) + x = paddle.to_tensor(np_x) + out = func(x) + np_x = x.numpy() + np_out = out.numpy() + self.assertTrue(dtype in str(np_out.dtype)) + self.assertTrue( + np.array_equal(np_x, np_out), + "custom op x: {},\n custom op out: {}".format(np_x, np_out)) + + def test_dispatch_integer(self): + dtypes = ["int32", "int64", "int8", "uint8", "int16"] + for dtype in dtypes: + self.run_dispatch_test(dispatch_op.dispatch_test_integer, dtype) + + def test_dispatch_complex(self): + dtypes = ["complex64", "complex128"] + for dtype in dtypes: + self.run_dispatch_test(dispatch_op.dispatch_test_complex, dtype) + + def test_dispatch_float_and_integer(self): + dtypes = [ + "float32", "float64", "int32", "int64", "int8", "uint8", "int16" + ] + for dtype in dtypes: + self.run_dispatch_test(dispatch_op.dispatch_test_float_and_integer, + dtype) + + def test_dispatch_float_and_complex(self): + dtypes = ["float32", "float64", "complex64", "complex128"] + for dtype in dtypes: + self.run_dispatch_test(dispatch_op.dispatch_test_float_and_complex, + dtype) + + def test_dispatch_float_and_integer_and_complex(self): + dtypes = [ + "float32", "float64", "int32", "int64", "int8", "uint8", "int16", + "complex64", "complex128" + ] + for dtype in dtypes: + self.run_dispatch_test( + dispatch_op.dispatch_test_float_and_integer_and_complex, dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_jit_load.py b/python/paddle/fluid/tests/custom_op/test_jit_load.py new file mode 100644 index 00000000000..222c69f5edc --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_jit_load.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import paddle +import numpy as np +from paddle.utils.cpp_extension import load +from utils import paddle_includes, extra_compile_args +from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method + +# switch to old custom op method +use_new_custom_op_load_method(False) + +# Compile and load custom op Just-In-Time. +custom_module = load( + name='custom_relu2', + sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'], + interpreter='python', # add for unittest + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cflags=extra_compile_args, # add for Coverage CI + verbose=True # add for unittest +) + + +class TestJITLoad(unittest.TestCase): + def test_api(self): + raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32') + gt_data = np.array([[0, 1, 0], [1, 0, 0]]).astype('float32') + x = paddle.to_tensor(raw_data, dtype='float32') + # use custom api + out = custom_module.relu2(x) + out3 = custom_module.relu3(x) + + self.assertTrue(np.array_equal(out.numpy(), gt_data)) + self.assertTrue(np.array_equal(out3.numpy(), gt_data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_setup_build.py b/python/paddle/fluid/tests/custom_op/test_setup_build.py new file mode 100644 index 00000000000..1ef14c2e3aa --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_setup_build.py @@ -0,0 +1,69 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np +from test_custom_op import CustomOpTest, load_so +import paddle +from paddle.utils.cpp_extension.extension_utils import run_cmd +from paddle.fluid.layer_helper import LayerHelper +from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method + +# switch to old custom op method +use_new_custom_op_load_method(False) + + +def compile_so(): + """ + Compile .so file by running setup.py config. + """ + # build .so with setup.py + file_dir = os.path.dirname(os.path.abspath(__file__)) + cmd = 'cd {} && python setup_build.py build'.format(file_dir) + run_cmd(cmd) + + +# `setup.py build` only produce .so file containing multi operators. +# Python Interface should be added manually. `relu2` api is in `test_custom_op.py` +def relu3(x, name=None): + helper = LayerHelper("relu3", **locals()) + out = helper.create_variable( + type=x.type, name=name, dtype=x.dtype, persistable=False) + helper.append_op(type="relu3", inputs={"X": x}, outputs={"Y": out}) + return out + + +class TestCompileMultiOp(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_relu3(self): + raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32') + x = paddle.to_tensor(raw_data, dtype='float32') + # use custom api + out = relu3(x) + + self.assertTrue( + np.array_equal(out.numpy(), + np.array([[0, 1, 0], [1, 0, 0]]).astype('float32'))) + + def tearDown(self): + paddle.enable_static() + + +if __name__ == '__main__': + compile_so() + load_so(so_name='librelu2_op_from_setup.so') + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_setup_install.py b/python/paddle/fluid/tests/custom_op/test_setup_install.py new file mode 100644 index 00000000000..1fd7b8a06f9 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_setup_install.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import site +import unittest +import paddle +import subprocess +import numpy as np +from paddle.utils.cpp_extension.extension_utils import run_cmd +from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method + +# switch to old custom op method +use_new_custom_op_load_method(False) + + +class TestSetUpInstall(unittest.TestCase): + def setUp(self): + cur_dir = os.path.dirname(os.path.abspath(__file__)) + # compile, install the custom op egg into site-packages under background + cmd = 'cd {} && python setup_install.py install'.format(cur_dir) + run_cmd(cmd) + + # NOTE(Aurelius84): Normally, it's no need to add following codes for users. + # But we simulate to pip install in current process, so interpreter don't snap + # sys.path has been updated. So we update it manually. + + # See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3 + site_dir = site.getsitepackages()[0] + custom_egg_path = [ + x for x in os.listdir(site_dir) if 'custom_relu2' in x + ] + assert len(custom_egg_path) == 1, "Matched egg number is %d." % len( + custom_egg_path) + sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + def test_api(self): + # usage: import the package directly + import custom_relu2 + + raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32') + gt_data = np.array([[0, 1, 0], [1, 0, 0]]).astype('float32') + x = paddle.to_tensor(raw_data, dtype='float32') + # use custom api + out = custom_relu2.relu2(x) + out3 = custom_relu2.relu3(x) + + self.assertTrue(np.array_equal(out.numpy(), gt_data)) + self.assertTrue(np.array_equal(out3.numpy(), gt_data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py new file mode 100644 index 00000000000..2c0dc1a4ca6 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import paddle +import numpy as np +from paddle.utils.cpp_extension import load +from utils import paddle_includes, extra_compile_args +from test_simple_custom_op_setup import relu2_dynamic, relu2_static + +# Compile and load custom op Just-In-Time. +custom_module = load( + name='simple_jit_relu2', + sources=['relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cflags=extra_compile_args) # add for Coverage CI + + +class TestJITLoad(unittest.TestCase): + def setUp(self): + self.custom_ops = [custom_module.relu2, custom_module.relu3] + self.dtypes = ['float32', 'float64'] + self.devices = ['cpu', 'gpu'] + + def test_static(self): + for device in self.devices: + for dtype in self.dtypes: + x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + for custom_op in self.custom_ops: + out = relu2_static(custom_op, device, dtype, x) + pd_out = relu2_static(custom_op, device, dtype, x, False) + self.assertTrue( + np.array_equal(out, pd_out), + "custom op out: {},\n paddle api out: {}".format( + out, pd_out)) + + def test_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + for custom_op in self.custom_ops: + out, x_grad = relu2_dynamic(custom_op, device, dtype, x) + pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype, + x, False) + self.assertTrue( + np.array_equal(out, pd_out), + "custom op out: {},\n paddle api out: {}".format( + out, pd_out)) + self.assertTrue( + np.array_equal(x_grad, pd_x_grad), + "custom op x grad: {},\n paddle api x grad: {}".format( + x_grad, pd_x_grad)) + + +class TestMultiOutputDtypes(unittest.TestCase): + def setUp(self): + self.custom_op = custom_module.relu2 + self.dtypes = ['float32', 'float64'] + self.devices = ['cpu', 'gpu'] + + def test_static(self): + paddle.enable_static() + for device in self.devices: + for dtype in self.dtypes: + res = self.run_static(device, dtype) + self.check_multi_outputs(res) + paddle.disable_static() + + def test_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + paddle.set_device(device) + x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + x = paddle.to_tensor(x_data) + outs = self.custom_op(x) + + self.assertTrue(len(outs) == 3) + self.check_multi_outputs(outs, True) + + def check_multi_outputs(self, outs, is_dynamic=False): + out, zero_float64, one_int32 = outs + if is_dynamic: + zero_float64 = zero_float64.numpy() + one_int32 = one_int32.numpy() + # Fake_float64 + self.assertTrue('float64' in str(zero_float64.dtype)) + self.assertTrue( + np.array_equal(zero_float64, np.zeros([4, 8]).astype('float64'))) + # ZFake_int32 + self.assertTrue('int32' in str(one_int32.dtype)) + self.assertTrue( + np.array_equal(one_int32, np.ones([4, 8]).astype('int32'))) + + def run_static(self, device, dtype): + paddle.set_device(device) + x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + + with paddle.static.scope_guard(paddle.static.Scope()): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype) + outs = self.custom_op(x) + + exe = paddle.static.Executor() + exe.run(paddle.static.default_startup_program()) + res = exe.run(paddle.static.default_main_program(), + feed={'X': x_data}, + fetch_list=outs) + + return res + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py new file mode 100644 index 00000000000..cfa2db0ba24 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py @@ -0,0 +1,160 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import site +import unittest +import paddle +import paddle.static as static +import subprocess +import numpy as np +from paddle.utils.cpp_extension.extension_utils import run_cmd + + +def relu2_dynamic(func, device, dtype, np_x, use_func=True): + paddle.set_device(device) + + t = paddle.to_tensor(np_x) + t.stop_gradient = False + + out = func(t)[0] if use_func else paddle.nn.functional.relu(t) + out.stop_gradient = False + + out.backward() + + return out.numpy(), t.grad + + +def relu2_static(func, device, dtype, np_x, use_func=True): + paddle.enable_static() + paddle.set_device(device) + + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + x = static.data(name='X', shape=[None, 8], dtype=dtype) + x.stop_gradient = False + # out, fake_float64, fake_int32 + out = func(x)[0] if use_func else paddle.nn.functional.relu(x) + static.append_backward(out) + + exe = static.Executor() + exe.run(static.default_startup_program()) + # in static mode, x data has been covered by out + out_v = exe.run(static.default_main_program(), + feed={'X': np_x}, + fetch_list=[out.name]) + + paddle.disable_static() + return out_v + + +def relu2_static_pe(func, device, dtype, np_x, use_func=True): + paddle.enable_static() + paddle.set_device(device) + + places = static.cpu_places() if device is 'cpu' else static.cuda_places() + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + x = static.data(name='X', shape=[None, 8], dtype=dtype) + x.stop_gradient = False + out = func(x)[0] if use_func else paddle.nn.functional.relu(x) + static.append_backward(out) + + exe = static.Executor() + exe.run(static.default_startup_program()) + + # in static mode, x data has been covered by out + compiled_prog = static.CompiledProgram(static.default_main_program( + )).with_data_parallel( + loss_name=out.name, places=places) + out_v = exe.run(compiled_prog, + feed={'X': np_x}, + fetch_list=[out.name]) + + paddle.disable_static() + return out_v + + +class TestNewCustomOpSetUpInstall(unittest.TestCase): + def setUp(self): + cur_dir = os.path.dirname(os.path.abspath(__file__)) + # compile, install the custom op egg into site-packages under background + cmd = 'cd {} && python setup_install_simple.py install'.format(cur_dir) + run_cmd(cmd) + + # NOTE(Aurelius84): Normally, it's no need to add following codes for users. + # But we simulate to pip install in current process, so interpreter don't snap + # sys.path has been updated. So we update it manually. + + # See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3 + site_dir = site.getsitepackages()[0] + custom_egg_path = [ + x for x in os.listdir(site_dir) if 'simple_setup_relu2' in x + ] + assert len(custom_egg_path) == 1, "Matched egg number is %d." % len( + custom_egg_path) + sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + # usage: import the package directly + import simple_setup_relu2 + self.custom_ops = [simple_setup_relu2.relu2, simple_setup_relu2.relu3] + + self.dtypes = ['float32', 'float64'] + self.devices = ['cpu', 'gpu'] + + def test_static(self): + for device in self.devices: + for dtype in self.dtypes: + x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + for custom_op in self.custom_ops: + out = relu2_static(custom_op, device, dtype, x) + pd_out = relu2_static(custom_op, device, dtype, x, False) + self.assertTrue( + np.array_equal(out, pd_out), + "custom op out: {},\n paddle api out: {}".format( + out, pd_out)) + + def test_static_pe(self): + for device in self.devices: + for dtype in self.dtypes: + x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + for custom_op in self.custom_ops: + out = relu2_static_pe(custom_op, device, dtype, x) + pd_out = relu2_static_pe(custom_op, device, dtype, x, False) + self.assertTrue( + np.array_equal(out, pd_out), + "custom op out: {},\n paddle api out: {}".format( + out, pd_out)) + + def test_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + for custom_op in self.custom_ops: + out, x_grad = relu2_dynamic(custom_op, device, dtype, x) + pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype, + x, False) + self.assertTrue( + np.array_equal(out, pd_out), + "custom op out: {},\n paddle api out: {}".format( + out, pd_out)) + self.assertTrue( + np.array_equal(x_grad, pd_x_grad), + "custom op x grad: {},\n paddle api x grad: {}".format( + x_grad, pd_x_grad)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/utils.py b/python/paddle/fluid/tests/custom_op/utils.py new file mode 100644 index 00000000000..f293c751942 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import six +from distutils.sysconfig import get_python_lib +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS + +site_packages_path = get_python_lib() +# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. +# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find +# paddle include directory. Because the following path is generated after insalling +# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. +paddle_includes = [ + os.path.join(site_packages_path, 'paddle/include'), + os.path.join(site_packages_path, 'paddle/include/third_party') +] + +# TODO(Aurelius84): Memory layout is different if build paddle with PADDLE_WITH_MKLDNN=ON, +# and will lead to ABI problem on Coverage CI. We will handle it in next PR. +extra_compile_args = ['-DPADDLE_WITH_MKLDNN' + ] if six.PY2 and not IS_WINDOWS else [] diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py index faf0fd4984d..1db1b66426c 100644 --- a/python/paddle/utils/__init__.py +++ b/python/paddle/utils/__init__.py @@ -25,6 +25,8 @@ from ..fluid.framework import require_version from . import download +from . import cpp_extension + __all__ = ['dump_config', 'deprecated', 'download', 'run_check'] #TODO: define new api under this directory diff --git a/python/paddle/utils/cpp_extension/__init__.py b/python/paddle/utils/cpp_extension/__init__.py new file mode 100644 index 00000000000..024fbb6bf7c --- /dev/null +++ b/python/paddle/utils/cpp_extension/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .cpp_extension import CUDAExtension +from .cpp_extension import CppExtension +from .cpp_extension import BuildExtension +from .cpp_extension import load, setup + +from .extension_utils import parse_op_info +from .extension_utils import get_build_directory +from .extension_utils import load_op_meta_info_and_register_op + +from . import cpp_extension +from . import extension_utils + +__all__ = [ + 'CppExtension', 'CUDAExtension', 'BuildExtension', 'load', 'setup', + 'get_build_directory' +] diff --git a/python/paddle/utils/cpp_extension/cpp_extension.py b/python/paddle/utils/cpp_extension/cpp_extension.py new file mode 100644 index 00000000000..121c1626125 --- /dev/null +++ b/python/paddle/utils/cpp_extension/cpp_extension.py @@ -0,0 +1,471 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import six +import sys +import textwrap +import copy + +import setuptools +from setuptools.command.easy_install import easy_install +from setuptools.command.build_ext import build_ext + +from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag, bootstrap_context +from .extension_utils import is_cuda_file, prepare_unix_cflags, add_std_without_repeat, get_build_directory +from .extension_utils import _import_module_from_library, CustomOpInfo, _write_setup_file, _jit_compile, parse_op_name_from +from .extension_utils import check_abi_compatibility, log_v, IS_WINDOWS +from .extension_utils import use_new_custom_op_load_method + +CUDA_HOME = find_cuda_home() + + +def setup(**attr): + """ + Wrapper setuptools.setup function to valid `build_ext` command and + implement paddle api code injection by switching `write_stub` + function in bdist_egg with `custom_write_stub`. + + Its usage is almost same as `setuptools.setup` except for `ext_modules` + arguments. For compiling multi custom operators, all necessary source files + can be include into just one Extension (CppExtension/CUDAExtension). + Moreover, only one `name` argument is required in `setup` and no need to spcific + `name` in Extension. + + Example: + + >> from paddle.utils.cpp_extension import CUDAExtension, setup + >> setup(name='custom_module', + ext_modules=CUDAExtension( + sources=['relu_op.cc', 'relu_op.cu'], + include_dirs=[], # specific user-defined include dirs + extra_compile_args=[]) # specific user-defined compil arguments. + """ + cmdclass = attr.get('cmdclass', {}) + assert isinstance(cmdclass, dict) + # if not specific cmdclass in setup, add it automaticaly. + if 'build_ext' not in cmdclass: + cmdclass['build_ext'] = BuildExtension.with_options( + no_python_abi_suffix=True) + attr['cmdclass'] = cmdclass + + error_msg = """ + Required to specific `name` argument in paddle.utils.cpp_extension.setup. + It's used as `import XXX` when you want install and import your custom operators.\n + For Example: + # setup.py file + from paddle.utils.cpp_extension import CUDAExtension, setup + setup(name='custom_module', + ext_modules=CUDAExtension( + sources=['relu_op.cc', 'relu_op.cu']) + + # After running `python setup.py install` + from custom_module import relue + """ + # name argument is required + if 'name' not in attr: + raise ValueError(error_msg) + + ext_modules = attr.get('ext_modules', []) + if not isinstance(ext_modules, list): + ext_modules = [ext_modules] + assert len( + ext_modules + ) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extenion.".format( + len(ext_modules)) + # replace Extension.name with attr['name] to keep consistant with Package name. + for ext_module in ext_modules: + ext_module.name = attr['name'] + + attr['ext_modules'] = ext_modules + + # Add rename .so hook in easy_install + assert 'easy_install' not in cmdclass + cmdclass['easy_install'] = EasyInstallCommand + + # Always set zip_safe=False to make compatible in PY2 and PY3 + # See http://peak.telecommunity.com/DevCenter/setuptools#setting-the-zip-safe-flag + attr['zip_safe'] = False + + # switch `write_stub` to inject paddle api in .egg + with bootstrap_context(): + setuptools.setup(**attr) + + +def CppExtension(sources, *args, **kwargs): + """ + Returns setuptools.CppExtension instance for setup.py to make it easy + to specify compile flags while building C++ custommed op kernel. + + Args: + sources(list[str]): The C++/CUDA source file names + args(list[options]): list of config options used to compile shared library + kwargs(dict[option]): dict of config options used to compile shared library + + Returns: + Extension: An instance of setuptools.Extension + """ + kwargs = normalize_extension_kwargs(kwargs, use_cuda=False) + # Note(Aurelius84): While using `setup` and `jit`, the Extension `name` will + # be replaced as `setup.name` to keep consistant with package. Because we allow + # users can not specific name in Extension. + # See `paddle.utils.cpp_extension.setup` for details. + name = kwargs.get('name', None) + if name is None: + name = _generate_extension_name(sources) + + return setuptools.Extension(name, sources, *args, **kwargs) + + +def CUDAExtension(sources, *args, **kwargs): + """ + Returns setuptools.CppExtension instance for setup.py to make it easy + to specify compile flags while build CUDA custommed op kernel. + + Args: + sources(list[str]): The C++/CUDA source file names + args(list[options]): list of config options used to compile shared library + kwargs(dict[option]): dict of config options used to compile shared library + + Returns: + Extension: An instance of setuptools.Extension + """ + kwargs = normalize_extension_kwargs(kwargs, use_cuda=True) + # Note(Aurelius84): While using `setup` and `jit`, the Extension `name` will + # be replaced as `setup.name` to keep consistant with package. Because we allow + # users can not specific name in Extension. + # See `paddle.utils.cpp_extension.setup` for details. + name = kwargs.get('name', None) + if name is None: + name = _generate_extension_name(sources) + + return setuptools.Extension(name, sources, *args, **kwargs) + + +def _generate_extension_name(sources): + """ + Generate extension name by source files. + """ + assert len(sources) > 0, "source files is empty" + file_prefix = [] + for source in sources: + source = os.path.basename(source) + filename, _ = os.path.splitext(source) + # Use list to generate same order. + if filename not in file_prefix: + file_prefix.append(filename) + + return '_'.join(file_prefix) + + +class BuildExtension(build_ext, object): + """ + Inherited from setuptools.command.build_ext to customize how to apply + compilation process with share library. + """ + + @classmethod + def with_options(cls, **options): + """ + Returns a BuildExtension subclass containing use-defined options. + """ + + class cls_with_options(cls): + def __init__(self, *args, **kwargs): + kwargs.update(options) + cls.__init__(self, *args, **kwargs) + + return cls_with_options + + def __init__(self, *args, **kwargs): + """ + Attributes is initialized with following oreder: + + 1. super(self).__init__() + 2. initialize_options(self) + 3. the reset of current __init__() + 4. finalize_options(self) + + So, it is recommended to set attribute value in `finalize_options`. + """ + super(BuildExtension, self).__init__(*args, **kwargs) + self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", True) + self.output_dir = kwargs.get("output_dir", None) + # for compatible two custom op define method + use_new_custom_op_load_method( + kwargs.get("use_new_method", use_new_custom_op_load_method())) + + def initialize_options(self): + super(BuildExtension, self).initialize_options() + + def finalize_options(self): + super(BuildExtension, self).finalize_options() + # NOTE(Aurelius84): Set location of compiled shared library. + # Carefully to modify this because `setup.py build/install` + # and `load` interface rely on this attribute. + if self.output_dir is not None: + self.build_lib = self.output_dir + + def build_extensions(self): + self._check_abi() + for extension in self.extensions: + # check settings of compiler + if isinstance(extension.extra_compile_args, dict): + for compiler in ['cxx', 'nvcc']: + if compiler not in extension.extra_compile_args: + extension.extra_compile_args[compiler] = [] + # add determine compile flags + add_compile_flag(extension, '-std=c++11') + + # Consider .cu, .cu.cc as valid source extensions. + self.compiler.src_extensions += ['.cu', '.cu.cc'] + # Save the original _compile method for later. + if self.compiler.compiler_type == 'msvc' or IS_WINDOWS: + raise NotImplementedError("Not support on MSVC currently.") + else: + original_compile = self.compiler._compile + + def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs, + pp_opts): + """ + Monkey patch machanism to replace inner compiler to custom complie process on Unix platform. + """ + # use abspath to ensure no warning and don't remove deecopy because modify params + # with dict type is dangerous. + src = os.path.abspath(src) + cflags = copy.deepcopy(extra_postargs) + try: + original_compiler = self.compiler.compiler_so + # ncvv compile CUDA source + if is_cuda_file(src): + assert CUDA_HOME is not None + nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc') + self.compiler.set_executable('compiler_so', nvcc_cmd) + # {'nvcc': {}, 'cxx: {}} + if isinstance(cflags, dict): + cflags = cflags['nvcc'] + else: + cflags = prepare_unix_cflags(cflags) + # cxx compile Cpp source + elif isinstance(cflags, dict): + cflags = cflags['cxx'] + + add_std_without_repeat( + cflags, self.compiler.compiler_type, use_std14=False) + original_compile(obj, src, ext, cc_args, cflags, pp_opts) + finally: + # restore original_compiler + self.compiler.compiler_so = original_compiler + + def object_filenames_with_cuda(origina_func, build_directory): + """ + Decorated the function to add customized naming machanism. + Originally, both .cc/.cu will have .o object output that will + bring file override problem. Use .cu.o as CUDA object suffix. + """ + + def wrapper(source_filenames, strip_dir=0, output_dir=''): + try: + objects = origina_func(source_filenames, strip_dir, + output_dir) + for i, source in enumerate(source_filenames): + # modify xx.o -> xx.cu.o + if is_cuda_file(source): + old_obj = objects[i] + objects[i] = old_obj[:-1] + 'cu.o' + # if user set build_directory, output objects there. + if build_directory is not None: + objects = [ + os.path.join(build_directory, os.path.basename(obj)) + for obj in objects + ] + # ensure to use abspath + objects = [os.path.abspath(obj) for obj in objects] + finally: + self.compiler.object_filenames = origina_func + + return objects + + return wrapper + + # customized compile process + self.compiler._compile = unix_custom_single_compiler + self.compiler.object_filenames = object_filenames_with_cuda( + self.compiler.object_filenames, self.build_lib) + + self._record_op_info() + + print("Compiling user custom op, it will cost a few seconds.....") + build_ext.build_extensions(self) + + def get_ext_filename(self, fullname): + # for example: custommed_extension.cpython-37m-x86_64-linux-gnu.so + ext_name = super(BuildExtension, self).get_ext_filename(fullname) + if self.no_python_abi_suffix and six.PY3: + split_str = '.' + name_items = ext_name.split(split_str) + assert len( + name_items + ) > 2, "Expected len(name_items) > 2, but received {}".format( + len(name_items)) + name_items.pop(-2) + # custommed_extension.so + ext_name = split_str.join(name_items) + + return ext_name + + def _check_abi(self): + """ + Check ABI Compatibility. + """ + if hasattr(self.compiler, 'compiler_cxx'): + compiler = self.compiler.compiler_cxx[0] + elif IS_WINDOWS: + compiler = os.environ.get('CXX', 'cl') + raise NotImplementedError("We don't support Windows Currently.") + else: + compiler = os.environ.get('CXX', 'c++') + + check_abi_compatibility(compiler) + + def _record_op_info(self): + """ + Record custum op inforomation. + """ + # parse shared library abs path + outputs = self.get_outputs() + assert len(outputs) == 1 + # multi operators built into same one .so file + so_path = os.path.abspath(outputs[0]) + so_name = os.path.basename(so_path) + + for i, extension in enumerate(self.extensions): + sources = [os.path.abspath(s) for s in extension.sources] + op_names = parse_op_name_from(sources) + + for op_name in op_names: + CustomOpInfo.instance().add(op_name, + so_name=so_name, + so_path=so_path) + + +class EasyInstallCommand(easy_install, object): + """ + Extend easy_intall Command to control the behavior of naming shared library + file. + + NOTE(Aurelius84): This is a hook subclass inherited Command used to rename shared + library file after extracting egg-info into site-packages. + """ + + def __init__(self, *args, **kwargs): + super(EasyInstallCommand, self).__init__(*args, **kwargs) + + # NOTE(Aurelius84): Add args and kwargs to make compatible with PY2/PY3 + def run(self, *args, **kwargs): + super(EasyInstallCommand, self).run(*args, **kwargs) + # NOTE: To avoid failing import .so file instead of + # python file because they have same name, we rename + # .so shared library to another name. + for egg_file in self.outputs: + filename, ext = os.path.splitext(egg_file) + if ext == '.so': + new_so_path = filename + "_pd_" + ext + if not os.path.exists(new_so_path): + os.rename(r'%s' % egg_file, r'%s' % new_so_path) + assert os.path.exists(new_so_path) + + +def load(name, + sources, + extra_cflags=None, + extra_cuda_cflags=None, + extra_ldflags=None, + extra_include_paths=None, + build_directory=None, + interpreter=None, + verbose=False): + """ + An Interface to automatically compile C++/CUDA source files Just-In-Time + and return callable python function as other Paddle layers API. It will + append user defined custom op in background. + + This module will perform compiling, linking, api generation and module loading + processes for users. It does not require CMake or Ninja environment and only + g++/nvcc on Linux and clang++ on MacOS. Moreover, ABI compatibility will be + checked to ensure that compiler version on local machine is compatible with + pre-installed Paddle whl in python site-packages. For example if Paddle is built + with GCC5.4, the version of user's local machine should satisfy GCC >= 5.4. + Otherwise, a fatal error will occur because ABI compatibility. + + Args: + name(str): generated shared library file name. + sources(list[str]): custom op source files name with .cc/.cu suffix. + extra_cflag(list[str]): additional flags used to compile CPP files. By default + all basic and framework related flags have been included. + If your pre-insall Paddle supported MKLDNN, please add + '-DPADDLE_WITH_MKLDNN'. Default None. + extra_cuda_cflags(list[str]): additonal flags used to compile CUDA files. See + https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html + for details. Default None. + extra_ldflags(list[str]): additonal flags used to link shared library. See + https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html for details. + Default None. + extra_include_paths(list[str]): additional include path used to search header files. + Default None. + build_directory(str): specific directory path to put shared library file. If set None, + it will use `PADDLE_EXTENSION_DIR` from os.environ. Use + `paddle.utils.cpp_extension.get_build_directory()` to see the location. + interpreter(str): alias or full interpreter path to specific which one to use if have installed multiple. + If set None, will use `python` as default interpreter. + verbose(bool): whether to verbose compiled log information + + Returns: + custom api: A callable python function with same signature as CustomOp Kernel defination. + + Example: + + >> from paddle.utils.cpp_extension import load + >> relu2 = load(name='relu2', + sources=['relu_op.cc', 'relu_op.cu']) + >> x = paddle.rand([4, 10]], dtype='float32') + >> out = relu2(x) + """ + + if build_directory is None: + build_directory = get_build_directory(verbose) + + # ensure to use abs path + build_directory = os.path.abspath(build_directory) + log_v("build_directory: {}".format(build_directory), verbose) + + file_path = os.path.join(build_directory, "setup.py") + sources = [os.path.abspath(source) for source in sources] + + # TODO(Aurelius84): split cflags and cuda_flags + if extra_cflags is None: extra_cflags = [] + if extra_cuda_cflags is None: extra_cuda_cflags = [] + compile_flags = extra_cflags + extra_cuda_cflags + log_v("additonal compile_flags: [{}]".format(' '.join(compile_flags)), + verbose) + + # write setup.py file and compile it + _write_setup_file(name, sources, file_path, extra_include_paths, + compile_flags, extra_ldflags, verbose) + _jit_compile(file_path, interpreter, verbose) + + # import as callable python api + custom_op_api = _import_module_from_library(name, build_directory, verbose) + + return custom_op_api diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py new file mode 100644 index 00000000000..52c17d77bd4 --- /dev/null +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -0,0 +1,722 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import six +import sys +import copy +import glob +import logging +import collections +import textwrap +import warnings +import subprocess + +from contextlib import contextmanager +from setuptools.command import bdist_egg + +from .. import load_op_library +from ...fluid import core +from ...fluid.framework import OpProtoHolder +from ...sysconfig import get_include, get_lib + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) +logger = logging.getLogger("utils.cpp_extension") + +OS_NAME = sys.platform +IS_WINDOWS = OS_NAME.startswith('win') +NVCC_COMPILE_FLAGS = [ + '-ccbin', 'cc', '-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU', '-DPADDLE_USE_DSO', + '-Xcompiler', '-fPIC', '-w', '--expt-relaxed-constexpr', '-O3', '-DNVCC' +] + +GCC_MINI_VERSION = (5, 4, 0) +# Give warning if using wrong compiler +WRONG_COMPILER_WARNING = ''' + ************************************* + * Compiler Compatibility WARNING * + ************************************* + +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +Found that your compiler ({user_compiler}) is not compatible with the compiler +built Paddle for this platform, which is {paddle_compiler} on {platform}. Please +use {paddle_compiler} to compile your custom op. Or you may compile Paddle from +source using {user_compiler}, and then also use it compile your custom op. + +See https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/2.0/install/compile/linux-compile.html +for help with compiling Paddle from source. + +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +''' +# Give warning if used compiler version is incompatible +ABI_INCOMPATIBILITY_WARNING = ''' + ********************************** + * ABI Compatibility WARNING * + ********************************** + +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +Found that your compiler ({user_compiler} == {version}) may be ABI-incompatible with pre-installed Paddle! +Please use compiler that is ABI-compatible with GCC >= 5.4 (Recommended 8.2). + +See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html for ABI Compatibility +information + +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +''' +USING_NEW_CUSTOM_OP_LOAD_METHOD = True + + +# NOTE(chenweihang): In order to be compatible with +# the two custom op define method, after removing +# old method, we can remove them together +def use_new_custom_op_load_method(*args): + global USING_NEW_CUSTOM_OP_LOAD_METHOD + if len(args) == 0: + return USING_NEW_CUSTOM_OP_LOAD_METHOD + else: + assert len(args) == 1 and isinstance(args[0], bool) + USING_NEW_CUSTOM_OP_LOAD_METHOD = args[0] + + +@contextmanager +def bootstrap_context(): + """ + Context to manage how to write `__bootstrap__` code in .egg + """ + origin_write_stub = bdist_egg.write_stub + bdist_egg.write_stub = custom_write_stub + yield + + bdist_egg.write_stub = origin_write_stub + + +def load_op_meta_info_and_register_op(lib_filename): + if USING_NEW_CUSTOM_OP_LOAD_METHOD: + core.load_op_meta_info_and_register_op(lib_filename) + else: + core.load_op_library(lib_filename) + return OpProtoHolder.instance().update_op_proto() + + +def custom_write_stub(resource, pyfile): + """ + Customized write_stub function to allow us to inject generated python + api codes into egg python file. + """ + _stub_template = textwrap.dedent(""" + import os + import sys + import types + import paddle + + def inject_ext_module(module_name, api_names): + if module_name in sys.modules: + return sys.modules[module_name] + + new_module = types.ModuleType(module_name) + for api_name in api_names: + setattr(new_module, api_name, eval(api_name)) + + return new_module + + def __bootstrap__(): + cur_dir = os.path.dirname(os.path.abspath(__file__)) + so_path = os.path.join(cur_dir, "{resource}") + + assert os.path.exists(so_path) + + # load custom op shared library with abs path + new_custom_ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path) + m = inject_ext_module(__name__, new_custom_ops) + + __bootstrap__() + + {custom_api} + """).lstrip() + + # Parse registerring op information + _, op_info = CustomOpInfo.instance().last() + so_path = op_info.so_path + + new_custom_ops = load_op_meta_info_and_register_op(so_path) + assert len( + new_custom_ops + ) > 0, "Required at least one custom operators, but received len(custom_op) = %d" % len( + new_custom_ops) + + # NOTE: To avoid importing .so file instead of python file because they have same name, + # we rename .so shared library to another name, see EasyInstallCommand. + filename, ext = os.path.splitext(resource) + resource = filename + "_pd_" + ext + + api_content = [] + for op_name in new_custom_ops: + api_content.append(_custom_api_content(op_name)) + + with open(pyfile, 'w') as f: + f.write( + _stub_template.format( + resource=resource, custom_api='\n\n'.join(api_content))) + + +OpInfo = collections.namedtuple('OpInfo', ['so_name', 'so_path']) + + +class CustomOpInfo: + """ + A global Singleton map to record all compiled custom ops information. + """ + + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr( + self.__class__, + '_instance'), 'Please use `instance()` to get CustomOpInfo object!' + # NOTE(Aurelius84): Use OrderedDict to save more order information + self.op_info_map = collections.OrderedDict() + + def add(self, op_name, so_name, so_path=None): + self.op_info_map[op_name] = OpInfo(so_name, so_path) + + def last(self): + """ + Return the lastest insert custom op info. + """ + assert len(self.op_info_map) > 0 + return next(reversed(self.op_info_map.items())) + + +def prepare_unix_cflags(cflags): + """ + Prepare all necessary compiled flags for nvcc compiling CUDA files. + """ + cflags = NVCC_COMPILE_FLAGS + cflags + get_cuda_arch_flags(cflags) + + return cflags + + +def add_std_without_repeat(cflags, compiler_type, use_std14=False): + """ + Append -std=c++11/14 in cflags if without specific it before. + """ + cpp_flag_prefix = '/std:' if compiler_type == 'msvc' else '-std=' + if not any(cpp_flag_prefix in flag for flag in cflags): + suffix = 'c++14' if use_std14 else 'c++11' + cpp_flag = cpp_flag_prefix + suffix + cflags.append(cpp_flag) + + +def get_cuda_arch_flags(cflags): + """ + For an arch, say "6.1", the added compile flag will be + ``-gencode=arch=compute_61,code=sm_61``. + For an added "+PTX", an additional + ``-gencode=arch=compute_xx,code=compute_xx`` is added. + """ + # TODO(Aurelius84): + return [] + + +def normalize_extension_kwargs(kwargs, use_cuda=False): + """ + Normalize include_dirs, library_dir and other attributes in kwargs. + """ + assert isinstance(kwargs, dict) + # append necessary include dir path of paddle + include_dirs = kwargs.get('include_dirs', []) + include_dirs.extend(find_paddle_includes(use_cuda)) + kwargs['include_dirs'] = include_dirs + + # append necessary lib path of paddle + library_dirs = kwargs.get('library_dirs', []) + library_dirs.extend(find_paddle_libraries(use_cuda)) + kwargs['library_dirs'] = library_dirs + + # add runtime library dirs + runtime_library_dirs = kwargs.get('runtime_library_dirs', []) + runtime_library_dirs.extend(find_paddle_libraries(use_cuda)) + kwargs['runtime_library_dirs'] = runtime_library_dirs + + # append compile flags + extra_compile_args = kwargs.get('extra_compile_args', []) + extra_compile_args.extend(['-g', '-w']) # diable warnings + kwargs['extra_compile_args'] = extra_compile_args + + # append link flags + extra_link_args = kwargs.get('extra_link_args', []) + extra_link_args.append('-lpaddle_framework') + if use_cuda: + extra_link_args.append('-lcudart') + + kwargs['extra_link_args'] = extra_link_args + + kwargs['language'] = 'c++' + return kwargs + + +def find_paddle_includes(use_cuda=False): + """ + Return Paddle necessary include dir path. + """ + # pythonXX/site-packages/paddle/include + paddle_include_dir = get_include() + third_party_dir = os.path.join(paddle_include_dir, 'third_party') + + include_dirs = [paddle_include_dir, third_party_dir] + + return include_dirs + + +def find_cuda_includes(): + + cuda_home = find_cuda_home() + if cuda_home is None: + raise ValueError( + "Not found CUDA runtime, please use `export CUDA_HOME=XXX` to specific it." + ) + + return [os.path.join(cuda_home, 'lib64')] + + +def find_cuda_home(): + """ + Use heuristic method to find cuda path + """ + # step 1. find in $CUDA_HOME or $CUDA_PATH + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + + # step 2. find path by `which nvcc` + if cuda_home is None: + which_cmd = 'where' if IS_WINDOWS else 'which' + try: + with open(os.devnull, 'w') as devnull: + nvcc_path = subprocess.check_output( + [which_cmd, 'nvcc'], stderr=devnull) + if six.PY3: + nvcc_path = nvcc_path.decode() + nvcc_path = nvcc_path.rstrip('\r\n') + # for example: /usr/local/cuda/bin/nvcc + cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + except: + if IS_WINDOWS: + # search from default NVIDIA GPU path + candidate_paths = glob.glob( + 'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') + if len(candidate_paths) > 0: + cuda_home = candidate_paths[0] + else: + cuda_home = "/usr/local/cuda" + # step 3. check whether path is valid + if not os.path.exists(cuda_home) and core.is_compiled_with_cuda(): + cuda_home = None + warnings.warn( + "Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it." + ) + + return cuda_home + + +def find_paddle_libraries(use_cuda=False): + """ + Return Paddle necessary library dir path. + """ + # pythonXX/site-packages/paddle/libs + paddle_lib_dirs = [get_lib()] + if use_cuda: + cuda_dirs = find_cuda_includes() + paddle_lib_dirs.extend(cuda_dirs) + return paddle_lib_dirs + + +def add_compile_flag(extension, flag): + extra_compile_args = copy.deepcopy(extension.extra_compile_args) + if isinstance(extra_compile_args, dict): + for args in extra_compile_args.values(): + args.append(flag) + else: + extra_compile_args.append(flag) + + extension.extra_compile_args = extra_compile_args + + +def is_cuda_file(path): + + cuda_suffix = set(['.cu']) + items = os.path.splitext(path) + assert len(items) > 1 + return items[-1] in cuda_suffix + + +def get_build_directory(verbose=False): + """ + Return paddle extension root directory, default specific by `PADDLE_EXTENSION_DIR` + """ + root_extensions_directory = os.environ.get('PADDLE_EXTENSION_DIR') + if root_extensions_directory is None: + dir_name = "paddle_extensions" + if OS_NAME.startswith('linux'): + root_extensions_directory = os.path.join( + os.path.expanduser('~/.cache'), dir_name) + else: + # TODO(Aurelius84): consider wind32/macOs + raise NotImplementedError("Only support Linux now.") + + log_v("$PADDLE_EXTENSION_DIR is not set, using path: {} by default.". + format(root_extensions_directory), verbose) + + if not os.path.exists(root_extensions_directory): + os.makedirs(root_extensions_directory) + + return root_extensions_directory + + +def parse_op_info(op_name): + """ + Parse input names and outpus detail information from registered custom op + from OpInfoMap. + """ + from paddle.fluid.framework import OpProtoHolder + if op_name not in OpProtoHolder.instance().op_proto_map: + raise ValueError( + "Please load {} shared library file firstly by `paddle.utils.cpp_extension.load_op_meta_info_and_register_op(...)`". + format(op_name)) + op_proto = OpProtoHolder.instance().get_op_proto(op_name) + + in_names = [x.name for x in op_proto.inputs] + out_names = [x.name for x in op_proto.outputs] + + return in_names, out_names + + +def _import_module_from_library(module_name, build_directory, verbose=False): + """ + Load .so shared library and import it as callable python module. + """ + # TODO(Aurelius84): Consider file suffix is .dll on Windows Platform. + ext_path = os.path.join(build_directory, module_name + '.so') + if not os.path.exists(ext_path): + raise FileNotFoundError("Extension path: {} does not exist.".format( + ext_path)) + + # load custom op_info and kernels from .so shared library + log_v('loading shared library from: {}'.format(ext_path), verbose) + op_names = load_op_meta_info_and_register_op(ext_path) + + # generate Python api in ext_path + return _generate_python_module(module_name, op_names, build_directory, + verbose) + + +def _generate_python_module(module_name, + op_names, + build_directory, + verbose=False): + """ + Automatically generate python file to allow import or load into as module + """ + api_file = os.path.join(build_directory, module_name + '.py') + log_v("generate api file: {}".format(api_file), verbose) + + # write into .py file + api_content = [_custom_api_content(op_name) for op_name in op_names] + with open(api_file, 'w') as f: + f.write('\n\n'.join(api_content)) + + # load module + custom_module = _load_module_from_file(api_file, verbose) + return custom_module + + +def _custom_api_content(op_name): + params_str, ins_str, outs_str = _get_api_inputs_str(op_name) + + API_TEMPLATE = textwrap.dedent(""" + from paddle.fluid.layer_helper import LayerHelper + + def {op_name}({inputs}): + helper = LayerHelper("{op_name}", **locals()) + + # prepare inputs and output + ins = {ins} + outs = {{}} + out_names = {out_names} + for out_name in out_names: + # Set 'float32' temporarily, and the actual dtype of output variable will be inferred + # in runtime. + outs[out_name] = helper.create_variable(dtype='float32') + + helper.append_op(type="{op_name}", inputs=ins, outputs=outs) + + res = [outs[out_name] for out_name in out_names] + + return res[0] if len(res)==1 else res + """).lstrip() + + # generate python api file + api_content = API_TEMPLATE.format( + op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str) + + return api_content + + +def _load_module_from_file(api_file_path, verbose=False): + """ + Load module from python file. + """ + if not os.path.exists(api_file_path): + raise FileNotFoundError("File : {} does not exist.".format( + api_file_path)) + + # Unique readable module name to place custom api. + log_v('import module from file: {}'.format(api_file_path), verbose) + ext_name = "_paddle_cpp_extension_" + if six.PY2: + import imp + module = imp.load_source(ext_name, api_file_path) + else: + from importlib import machinery + loader = machinery.SourceFileLoader(ext_name, api_file_path) + module = loader.load_module() + + return module + + +def _get_api_inputs_str(op_name): + """ + Returns string of api parameters and inputs dict. + """ + in_names, out_names = parse_op_info(op_name) + # e.g: x, y, z + params_str = ','.join([p.lower() for p in in_names]) + # e.g: {'X': x, 'Y': y, 'Z': z} + ins_str = "{%s}" % ','.join( + ["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names]) + # e.g: ['Out', 'Index'] + outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names]) + return params_str, ins_str, outs_str + + +def _write_setup_file(name, + sources, + file_path, + include_dirs, + compile_flags, + link_args, + verbose=False): + """ + Automatically generate setup.py and write it into build directory. + """ + template = textwrap.dedent(""" + import os + from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup + from paddle.utils.cpp_extension import get_build_directory + setup( + name='{name}', + ext_modules=[ + {prefix}Extension( + sources={sources}, + include_dirs={include_dirs}, + extra_compile_args={extra_compile_args}, + extra_link_args={extra_link_args})], + cmdclass={{"build_ext" : BuildExtension.with_options( + output_dir=get_build_directory(), + no_python_abi_suffix=True, + use_new_method={use_new_method}) + }})""").lstrip() + + with_cuda = False + if any([is_cuda_file(source) for source in sources]): + with_cuda = True + log_v("with_cuda: {}".format(with_cuda), verbose) + + content = template.format( + name=name, + prefix='CUDA' if with_cuda else 'Cpp', + sources=list2str(sources), + include_dirs=list2str(include_dirs), + extra_compile_args=list2str(compile_flags), + extra_link_args=list2str(link_args), + use_new_method=use_new_custom_op_load_method()) + + log_v('write setup.py into {}'.format(file_path), verbose) + with open(file_path, 'w') as f: + f.write(content) + + +def list2str(args): + """ + Convert list[str] into string. For example: [x, y] -> "['x', 'y']" + """ + if args is None: return '[]' + assert isinstance(args, (list, tuple)) + args = ["'{}'".format(arg) for arg in args] + return '[' + ','.join(args) + ']' + + +def _jit_compile(file_path, interpreter=None, verbose=False): + """ + Build shared library in subprocess + """ + ext_dir = os.path.dirname(file_path) + setup_file = os.path.basename(file_path) + + if interpreter is None: + interpreter = 'python' + try: + py_path = subprocess.check_output(['which', interpreter]) + py_version = subprocess.check_output([interpreter, '-V']) + if six.PY3: + py_path = py_path.decode() + py_version = py_version.decode() + log_v("Using Python interpreter: {}, version: {}".format( + py_path.strip(), py_version.strip()), verbose) + except Exception: + _, error, _ = sys.exc_info() + raise RuntimeError( + 'Failed to check Python interpreter with `{}`, errors: {}'.format( + interpreter, error)) + + compile_cmd = 'cd {} && {} {} build'.format(ext_dir, interpreter, + setup_file) + print("Compiling user custom op, it will cost a few seconds.....") + run_cmd(compile_cmd, verbose) + + +def parse_op_name_from(sources): + """ + Parse registerring custom op name from sources. + """ + + def regex(content): + if USING_NEW_CUSTOM_OP_LOAD_METHOD: + pattern = re.compile(r'PD_BUILD_OP\(([^,\)]+)\)') + else: + pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),') + + content = re.sub(r'\s|\t|\n', '', content) + op_name = pattern.findall(content) + op_name = set([re.sub('_grad', '', name) for name in op_name]) + + return op_name + + op_names = set() + for source in sources: + with open(source, 'r') as f: + content = f.read() + op_names |= regex(content) + + return list(op_names) + + +def run_cmd(command, verbose=False): + """ + Execute command with subprocess. + """ + # logging + log_v("execute command: {}".format(command), verbose) + try: + from subprocess import DEVNULL # py3 + except ImportError: + DEVNULL = open(os.devnull, 'wb') + + # execute command + try: + if verbose: + return subprocess.check_call( + command, shell=True, stderr=subprocess.STDOUT) + else: + return subprocess.check_call(command, shell=True, stdout=DEVNULL) + except Exception: + _, error, _ = sys.exc_info() + raise RuntimeError("Failed to run command: {}, errors: {}".format( + compile, error)) + + +def check_abi_compatibility(compiler, verbose=False): + """ + Check whether GCC version on user local machine is compatible with Paddle in + site-packages. + """ + # TODO(Aurelius84): After we support windows, remove IS_WINDOWS in following code. + if os.environ.get('PADDLE_SKIP_CHECK_ABI') in ['True', 'true', '1' + ] or IS_WINDOWS: + return True + + cmd_out = subprocess.check_output( + ['which', compiler], stderr=subprocess.STDOUT) + compiler_path = os.path.realpath(cmd_out.decode() + if six.PY3 else cmd_out).strip() + # step 1. if not found any suitable compiler, raise error + if not any(name in compiler_path + for name in _expected_compiler_current_platform()): + warnings.warn( + WRONG_COMPILER_WARNING.format( + user_compiler=compiler, + paddle_compiler=_expected_compiler_current_platform()[0], + platform=OS_NAME)) + return False + + # clang++ have no ABI compatibility problem + if OS_NAME.startswith('darwin'): + return True + try: + if OS_NAME.startswith('linux'): + version_info = subprocess.check_output( + [compiler, '-dumpfullversion']) + if six.PY3: + version_info = version_info.decode() + version = version_info.strip().split('.') + assert len(version) == 3 + # check version compatibility + if tuple(map(int, version)) >= GCC_MINI_VERSION: + return True + else: + warnings.warn( + ABI_INCOMPATIBILITY_WARNING.format( + user_compiler=compiler, version=version_info.strip())) + # TODO(Aurelius84): check version compatibility on windows + elif IS_WINDOWS: + warnings.warn("We don't support Windows now.") + except Exception: + _, error, _ = sys.exc_info() + warnings.warn('Failed to check compiler version for {}: {}'.format( + compiler, error)) + + return False + + +def _expected_compiler_current_platform(): + """ + Returns supported compiler string on current platform + """ + expect_compilers = ['clang', 'clang++'] if OS_NAME.startswith( + 'darwin') else ['gcc', 'g++', 'gnu-c++', 'gnu-cc'] + return expect_compilers + + +def log_v(info, verbose): + """ + Print log information on stdout. + """ + if verbose: + logging.info(info) diff --git a/python/setup.py.in b/python/setup.py.in index 652b81b25c8..f662e21a7be 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -139,6 +139,7 @@ write_distributed_training_mode_py(filename='@PADDLE_BINARY_DIR@/python/paddle/f packages=['paddle', 'paddle.libs', 'paddle.utils', + 'paddle.utils.cpp_extension', 'paddle.dataset', 'paddle.reader', 'paddle.distributed', @@ -378,6 +379,8 @@ def find_files(pattern, root): yield os.path.join(dirpath, filename) headers = ( + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle')) + + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/extension')) + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/framework')) + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/imperative')) + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/memory')) + -- GitLab