diff --git a/CMakeLists.txt b/CMakeLists.txt index 740a9cef1dda768cc225bebb9d52a15465177f39..3683bfd31a1252e6b956fdf5fba9d5b51f96699c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -293,6 +293,8 @@ set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") +add_definitions(-DPADDLE_DLL_EXPORT) + if(ON_INFER) # you can trun off the paddle fluid and inference lib by set ON_INFER=OFF message(STATUS "On inference mode, will take place some specific optimization.") diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 8e1e2f460f4960ec8d4265a06cd57b5aa43017d8..e2807ae392ec5915e2f9d7faa18cec30a8d09e74 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -792,17 +792,15 @@ function(py_test TARGET_NAME) if(WITH_COVERAGE) add_test(NAME ${TARGET_NAME} - COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true - FLAGS_cpu_deterministic=true - PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS} - COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data - ${PYTHON_EXECUTABLE} -m coverage run --branch -p ${py_test_SRCS} ${py_test_ARGS} - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true + FLAGS_cpu_deterministic=true ${py_test_ENVS} + COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data + ${PYTHON_EXECUTABLE} -m coverage run --branch -p ${py_test_SRCS} ${py_test_ARGS} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) else() add_test(NAME ${TARGET_NAME} COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true - FLAGS_cpu_deterministic=true - PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS} + FLAGS_cpu_deterministic=true ${py_test_ENVS} ${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) endif() diff --git a/paddle/fluid/extension/include/all.h b/paddle/fluid/extension/include/all.h index 5aa61f8203e75320cfdf11ed34fe9a7462548c60..e2a3bc38c5f4ab3ee1d126159b7961d979a33c06 100644 --- a/paddle/fluid/extension/include/all.h +++ b/paddle/fluid/extension/include/all.h @@ -18,6 +18,12 @@ limitations under the License. */ #error C++11 or later compatible compiler is required to use Paddle. #endif +#ifdef _WIN32 +#ifndef NOMINMAX +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#endif +#endif + #include "paddle/fluid/extension/include/dispatch.h" #include "paddle/fluid/extension/include/dtype.h" #include "paddle/fluid/extension/include/op_meta_info.h" diff --git a/paddle/fluid/extension/include/dll_decl.h b/paddle/fluid/extension/include/dll_decl.h new file mode 100644 index 0000000000000000000000000000000000000000..3dbea5e6dffc271cd2edc4e399d96e18e259d936 --- /dev/null +++ b/paddle/fluid/extension/include/dll_decl.h @@ -0,0 +1,27 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if defined(_WIN32)
+#ifndef PD_DLL_DECL
+#ifdef PADDLE_DLL_EXPORT
+#define PD_DLL_DECL __declspec(dllexport)
+#else
+#define PD_DLL_DECL __declspec(dllimport)
+#endif  // PADDLE_DLL_EXPORT
+#endif  // PD_DLL_DECL
+#else
+#define PD_DLL_DECL
+#endif  // _WIN32
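How the new macro behaves, in brief: the top-level CMakeLists.txt change above defines PADDLE_DLL_EXPORT globally, so PD_DLL_DECL expands to __declspec(dllexport) while Paddle itself is compiled and to __declspec(dllimport) when user code includes the installed headers; on non-Windows platforms it expands to nothing. A minimal illustration (MyExportedType is hypothetical, not part of the patch):

// Illustrative only: one annotation point covers all three build contexts.
#include "paddle/fluid/extension/include/dll_decl.h"

class PD_DLL_DECL MyExportedType {  // building Paddle:       dllexport
 public:                            // building an extension: dllimport
  void Run();                       // Linux/macOS:           expands to nothing
};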
diff --git a/paddle/fluid/extension/include/op_meta_info.h b/paddle/fluid/extension/include/op_meta_info.h
index 920049e2390ed38b12f3466fb35bf37c77dfbbe2..1bc044f647fbae0c4666ecda9e2a2fc3dc8ef214 100644
--- a/paddle/fluid/extension/include/op_meta_info.h
+++ b/paddle/fluid/extension/include/op_meta_info.h
@@ -14,12 +14,14 @@ limitations under the License. */
 
 #pragma once
 
+#include <iostream>
 #include <string>
 #include <unordered_map>
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/extension/include/dll_decl.h"
 #include "paddle/fluid/extension/include/tensor.h"
 
 /**
@@ -31,7 +33,7 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
-class OpMetaInfoHelper;
+class PD_DLL_DECL OpMetaInfoHelper;
 }  // namespace framework
 
 using Tensor = paddle::Tensor;
@@ -43,6 +45,26 @@ using Tensor = paddle::Tensor;
   classname& operator=(const classname&) = delete; \
   classname& operator=(classname&&) = delete
 
+#if defined _WIN32
+#define HANDLE_THE_ERROR try {
+#define END_HANDLE_THE_ERROR            \
+  }                                     \
+  catch (const std::exception& e) {     \
+    std::cerr << e.what() << std::endl; \
+    throw e;                            \
+  }
+#else
+#define HANDLE_THE_ERROR
+#define END_HANDLE_THE_ERROR
+#endif
+
+#define PD_THROW(err_msg)              \
+  do {                                 \
+    HANDLE_THE_ERROR                   \
+    throw std::runtime_error(err_msg); \
+    END_HANDLE_THE_ERROR               \
+  } while (0)
+
 ///////////////// Util Define and Function ////////////////
 
 inline std::string Grad(const std::string& var_name) {
@@ -59,6 +81,26 @@ inline std::string Grad(const std::string& var_name) {
 using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs,
                                            std::vector<boost::any> attrs);
 
+#define PD_SPECIALIZE_ComputeCallHelper(attr_type)                          \
+  template <typename... Tail>                                               \
+  struct ComputeCallHelper<attr_type, Tail...> {                            \
+    template <int in_idx, int attr_idx, typename... PreviousArgs>           \
+    static Return Compute(std::vector<Tensor> inputs,                       \
+                          std::vector<boost::any> attrs,                    \
+                          const PreviousArgs&... pargs) {                   \
+      try {                                                                 \
+        attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]);        \
+        return ComputeCallHelper<Tail...>::template Compute<in_idx,         \
+                                                            attr_idx + 1>(  \
+            inputs, attrs, pargs..., arg);                                  \
+      } catch (boost::bad_any_cast&) {                                      \
+        PD_THROW(                                                           \
+            "Attribute cast error in custom operator. Expected " #attr_type \
+            " value.");                                                     \
+      }                                                                     \
+    }                                                                       \
+  }
+
 template <typename T>
 struct TypeTag {};
 
@@ -92,26 +134,20 @@ struct KernelFuncImpl {
     }
   };
 
-  // TODO(chenweihang): add support for attribute input
-  // int attribute input (not used now)
-  template <typename... Tail>
-  struct ComputeCallHelper<int, Tail...> {
-    template <int in_idx, int attr_idx, typename... PreviousArgs>
-    static Return Compute(std::vector<Tensor> inputs,
-                          std::vector<boost::any> attrs,
-                          const PreviousArgs&... pargs) {
-      try {
-        int arg = boost::any_cast<int>(attrs[attr_idx]);
-        return ComputeCallHelper<Tail...>::template Compute<in_idx,
-                                                            attr_idx + 1>(
-            inputs, attrs, pargs..., arg);
-      } catch (boost::bad_any_cast&) {
-        throw std::runtime_error(
-            "Attribute cast error in custom operator. Expected int value.");
-      }
-    }
-  };
-
+  PD_SPECIALIZE_ComputeCallHelper(bool);
+  PD_SPECIALIZE_ComputeCallHelper(int);
+  PD_SPECIALIZE_ComputeCallHelper(float);
+  PD_SPECIALIZE_ComputeCallHelper(int64_t);
+  PD_SPECIALIZE_ComputeCallHelper(std::string);
+  PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
+  PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
+  PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
+  PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
+  // TODO(chenweihang): support other attribute type if needed.
+  // Why not support other attribute type here?
+  // - boost::blank, std::vector<bool> and std::vector<double>
+  //   are not used in op
+  // - BlockDesc* and std::vector<BlockDesc*> are used in framework
   // end: base template
   template <typename T>
   struct ComputeCallHelper<TypeTag<T>> {
@@ -220,13 +256,26 @@ struct InferDtypeFuncImpl {
 
 ////////////////////// Op Meta Info //////////////////////
 
-class OpMetaInfo {
+class PD_DLL_DECL OpMetaInfo {
  public:
   explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}
+
+  // format: {"<name1>", "<name2>", ...}
   OpMetaInfo& Inputs(std::vector<std::string>&& inputs);
+
+  // format: {"<name1>", "<name2>", ...}
   OpMetaInfo& Outputs(std::vector<std::string>&& outputs);
+
+  // format: {"<name1>:<type1>", "<name2>:<type2>", ...}
+  OpMetaInfo& Attrs(std::vector<std::string>&& attrs);
+
+  // format: PD_KERNEL(...)
   OpMetaInfo& SetKernelFn(KernelFunc&& func);
+
+  // format: PD_INFER_SHAPE(...)
   OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func);
+
+  // format: PD_INFER_DTYPE(...)
   OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);
 
  private:
@@ -246,7 +295,7 @@ class OpMetaInfo {
 
 //////////////// Op Meta Info Map /////////////////
 
-class OpMetaInfoMap {
+class PD_DLL_DECL OpMetaInfoMap {
  public:
   // this function's impl should keep in header file.
   // if move to cc file, meta info can not be added
@@ -270,14 +319,15 @@ class OpMetaInfoMap {
 
 //////////////// Op Meta Info Builder /////////////////
 
-class OpMetaInfoBuilder {
+class PD_DLL_DECL OpMetaInfoBuilder {
  public:
   explicit OpMetaInfoBuilder(std::string&& name);
   OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
   OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
-  OpMetaInfoBuilder& SetKernelFn(KernelFunc&& func);
-  OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc&& func);
-  OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc&& func);
+  OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
+  OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
+  OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
+  OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
   OpMetaInfoBuilder& SetBackwardOp(const std::string& bwd_op_name);
 
  private:
@@ -317,8 +367,12 @@ void LoadCustomOperatorLib(const std::string& dso_name);
 extern "C" {
 #endif
 
+#if defined(_WIN32)
 // C-API to get global OpMetaInfoMap.
-paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap();
+__declspec(dllexport) inline paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() {
+  return paddle::OpMetaInfoMap::Instance();
+}
+#endif  // _WIN32
 
 #ifdef __cplusplus
 }
diff --git a/paddle/fluid/extension/include/tensor.h b/paddle/fluid/extension/include/tensor.h
index a5ce0d1a5858b0422e6187bf2ca0e7198b87ed57..47af4dc70a15ffde980daa65ce769f5e2371058c 100644
--- a/paddle/fluid/extension/include/tensor.h
+++ b/paddle/fluid/extension/include/tensor.h
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <vector>
 
+#include "paddle/fluid/extension/include/dll_decl.h"
 #include "paddle/fluid/extension/include/dtype.h"
 #include "paddle/fluid/extension/include/place.h"
 
@@ -23,7 +24,7 @@ namespace paddle {
 namespace framework {
 class CustomTensorUtils;
 }  // namespace framework
-class Tensor {
+class PD_DLL_DECL Tensor {
  public:
   /// \brief Construct a Tensor on target Place for CustomOp.
   /// Generally it's only used for user to create Tensor.
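With the header changes above, a user-defined operator can now carry attributes through the builder chain. A hedged sketch of a registration this enables (the op name and kernel are hypothetical; each Attrs entry uses the "<name>:<type>" format that custom_operator.cc parses):

// Hypothetical user code, assuming paddle/extension.h is on the include path.
#include <vector>
#include "paddle/extension.h"

std::vector<paddle::Tensor> ScaleKernel(const paddle::Tensor& x, float scale);

PD_BUILD_OP("my_scale")
    .Inputs({"X"})
    .Outputs({"Out"})
    .Attrs({"scale: float"})  // parsed into name "scale", type "float"
    .SetKernelFn(PD_KERNEL(ScaleKernel));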
diff --git a/paddle/fluid/extension/src/op_meta_info.cc b/paddle/fluid/extension/src/op_meta_info.cc
index f31723e5ac83675884f950c1c4e8917c220bc474..d362282b8d9d24c287e51643d3aca72d9fd36c50 100644
--- a/paddle/fluid/extension/src/op_meta_info.cc
+++ b/paddle/fluid/extension/src/op_meta_info.cc
@@ -32,6 +32,10 @@ OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) {
   outputs_ = std::forward<std::vector<std::string>>(outputs);
   return *this;
 }
+OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) {
+  attrs_ = std::forward<std::vector<std::string>>(attrs);
+  return *this;
+}
 OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
   kernel_fn_ = std::forward<KernelFunc>(func);
   return *this;
 }
@@ -78,17 +82,22 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
   return *this;
 }
 
-OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc&& func) {
+OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
+  info_ptr_->Attrs(std::forward<std::vector<std::string>>(attrs));
+  return *this;
+}
+
+OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
   info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
   return *this;
 }
 
-OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc&& func) {
+OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
   info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
   return *this;
 }
 
-OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc&& func) {
+OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
   info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
   return *this;
 }
@@ -114,10 +123,17 @@ void LoadCustomOperatorLib(const std::string& dso_name) {
 }
 }  // namespace paddle
 
+#ifdef __cplusplus
 extern "C" {
+#endif
 
+#ifndef _WIN32
+// C-API to get global OpMetaInfoMap.
 paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() {
   return paddle::OpMetaInfoMap::Instance();
 }
+#endif
 
+#ifdef __cplusplus
 }  // end extern "C"
+#endif
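The tensor.cc changes below attach PD_DLL_DECL to every explicit template instantiation. On MSVC, exporting the class is not enough: an explicit instantiation must itself carry the export attribute, otherwise the specialization stays internal to the DLL and extension modules fail to link against it. A minimal sketch of the pattern (Identity is a made-up function):

// Illustrative only, not part of the patch.
template <typename T>
T Identity(T v) {
  return v;
}
// Without PD_DLL_DECL here, MSVC would not export Identity<int> from the DLL:
template PD_DLL_DECL int Identity<int>(int);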
diff --git a/paddle/fluid/extension/src/tensor.cc b/paddle/fluid/extension/src/tensor.cc
index 11d505a5aab4f4d33926162445cffd3f5ca4db32..39ed27486411080c167c19f02e9adebb2c2c1d90 100644
--- a/paddle/fluid/extension/src/tensor.cc
+++ b/paddle/fluid/extension/src/tensor.cc
@@ -207,73 +207,87 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
   return target;
 }
 
-template Tensor Tensor::copy_to<paddle::platform::float16>(
-    const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<paddle::platform::bfloat16>(
-    const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<paddle::platform::complex128>(
-    const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<paddle::platform::complex64>(
-    const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<float>(const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<double>(const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<int64_t>(const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<int32_t>(const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<uint8_t>(const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<int8_t>(const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
-template Tensor Tensor::copy_to<bool>(const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor
+Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor
+Tensor::copy_to<paddle::platform::bfloat16>(const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor
+Tensor::copy_to<paddle::platform::complex128>(const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor
+Tensor::copy_to<paddle::platform::complex64>(const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor Tensor::copy_to<float>(
+    const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor Tensor::copy_to<double>(
+    const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor Tensor::copy_to<int64_t>(
+    const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor Tensor::copy_to<int32_t>(
+    const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor Tensor::copy_to<uint8_t>(
+    const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor Tensor::copy_to<int8_t>(
+    const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor Tensor::copy_to<int16_t>(
+    const PlaceType &target_place) const;
+template PD_DLL_DECL Tensor Tensor::copy_to<bool>(
+    const PlaceType &target_place) const;
 
-template float *Tensor::data<float>() const;
-template double *Tensor::data<double>() const;
-template int64_t *Tensor::data<int64_t>() const;
-template int32_t *Tensor::data<int32_t>() const;
-template uint8_t *Tensor::data<uint8_t>() const;
-template int8_t *Tensor::data<int8_t>() const;
-template paddle::platform::float16 *Tensor::data<paddle::platform::float16>()
-    const;
-template paddle::platform::bfloat16 *Tensor::data<paddle::platform::bfloat16>()
-    const;
-template paddle::platform::complex128 *
-Tensor::data<paddle::platform::complex128>() const;
-template paddle::platform::complex64 *
-Tensor::data<paddle::platform::complex64>() const;
-template int16_t *Tensor::data<int16_t>() const;
-template bool *Tensor::data<bool>() const;
+template PD_DLL_DECL float *Tensor::data<float>() const;
+template PD_DLL_DECL double *Tensor::data<double>() const;
+template PD_DLL_DECL int64_t *Tensor::data<int64_t>() const;
+template PD_DLL_DECL int32_t *Tensor::data<int32_t>() const;
+template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
+template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
+template PD_DLL_DECL paddle::platform::float16 *
+Tensor::data<paddle::platform::float16>() const;
+template PD_DLL_DECL paddle::platform::bfloat16 *
+Tensor::data<paddle::platform::bfloat16>() const;
+template PD_DLL_DECL paddle::platform::complex128 *
+Tensor::data<paddle::platform::complex128>() const;
+template PD_DLL_DECL paddle::platform::complex64 *
+Tensor::data<paddle::platform::complex64>() const;
+template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
+template PD_DLL_DECL bool *Tensor::data<bool>() const;
 
-template float *Tensor::mutable_data<float>();
-template double *Tensor::mutable_data<double>();
-template int64_t *Tensor::mutable_data<int64_t>();
-template int32_t *Tensor::mutable_data<int32_t>();
-template uint8_t *Tensor::mutable_data<uint8_t>();
-template int8_t *Tensor::mutable_data<int8_t>();
-template paddle::platform::float16 *
-Tensor::mutable_data<paddle::platform::float16>();
-template paddle::platform::bfloat16 *
-Tensor::mutable_data<paddle::platform::bfloat16>();
-template paddle::platform::complex128 *
-Tensor::mutable_data<paddle::platform::complex128>();
-template paddle::platform::complex64 *
-Tensor::mutable_data<paddle::platform::complex64>();
-template int16_t *Tensor::mutable_data<int16_t>();
-template bool *Tensor::mutable_data<bool>();
+template PD_DLL_DECL float *Tensor::mutable_data<float>();
+template PD_DLL_DECL double *Tensor::mutable_data<double>();
+template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>();
+template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>();
+template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
+template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
+template PD_DLL_DECL paddle::platform::float16 *
+Tensor::mutable_data<paddle::platform::float16>();
+template PD_DLL_DECL paddle::platform::bfloat16 *
+Tensor::mutable_data<paddle::platform::bfloat16>();
+template PD_DLL_DECL paddle::platform::complex128 *
+Tensor::mutable_data<paddle::platform::complex128>();
+template PD_DLL_DECL paddle::platform::complex64 *
+Tensor::mutable_data<paddle::platform::complex64>();
+template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
+template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
 
-template float *Tensor::mutable_data<float>(const PlaceType &place);
-template double *Tensor::mutable_data<double>(const PlaceType &place);
-template int64_t *Tensor::mutable_data<int64_t>(const PlaceType &place);
-template int32_t *Tensor::mutable_data<int32_t>(const PlaceType &place);
-template uint8_t *Tensor::mutable_data<uint8_t>(const PlaceType &place);
-template int8_t *Tensor::mutable_data<int8_t>(const PlaceType &place);
-template paddle::platform::float16 *
-Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
-template paddle::platform::bfloat16 *
-Tensor::mutable_data<paddle::platform::bfloat16>(const PlaceType &place);
-template paddle::platform::complex128 *
-Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
-template paddle::platform::complex64 *
-Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
-template int16_t *Tensor::mutable_data<int16_t>(const PlaceType &place);
-template bool *Tensor::mutable_data<bool>(const PlaceType &place);
+template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place);
+template PD_DLL_DECL double *Tensor::mutable_data<double>(
+    const PlaceType &place);
+template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>(
+    const PlaceType &place);
+template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>(
+    const PlaceType &place);
+template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>(
+    const PlaceType &place);
+template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
+    const PlaceType &place);
+template PD_DLL_DECL paddle::platform::float16 *
+Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
+template PD_DLL_DECL paddle::platform::bfloat16 *
+Tensor::mutable_data<paddle::platform::bfloat16>(const PlaceType &place);
+template PD_DLL_DECL paddle::platform::complex128 *
+Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
+template PD_DLL_DECL paddle::platform::complex64 *
+Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
+template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
+    const PlaceType &place);
+template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
 
 std::vector<int64_t> Tensor::shape() const {
   GET_CASTED_TENSOR
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 2f4dcf465de780ac963680140b9d02a0324531b1..4074218c7ae6f89de5661a5bc4380f8f15790c55 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -321,9 +321,9 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
 
 configure_file(commit.h.in commit.h)
 
-cc_library(custom_tensor SRCS ../extension/src/tensor.cc DEPS lod_tensor)
+cc_library(custom_tensor SRCS ../extension/src/tensor.cc DEPS lod_tensor memory enforce)
 cc_library(op_meta_info SRCS ../extension/src/op_meta_info.cc DEPS custom_tensor)
-cc_library(custom_operator SRCS custom_operator.cc DEPS operator op_registry device_context dynamic_loader custom_tensor op_meta_info)
+cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info)
 cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog)
 
 set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator)
@@ -346,9 +346,12 @@ if (LINUX)
 endif()
 
 if (WIN32)
+  set(FLUID_FRAMEWORK_IMPORT_LIB
+    ${PADDLE_BINARY_DIR}/paddle/fluid/framework/${CMAKE_BUILD_TYPE}/paddle_framework.lib
+    CACHE INTERNAL "Fluid framework lib")
   set(FLUID_FRAMEWORK_SHARED_LIB
-    ${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_framework.dll
-    CACHE INTERNAL "Fluid framework lib")
+    ${PADDLE_BINARY_DIR}/paddle/fluid/framework/${CMAKE_BUILD_TYPE}/paddle_framework.dll
+    CACHE INTERNAL "Fluid framework dll")
 endif()
 
 if(APPLE)
@@ -359,3 +362,37 @@ endif()
 if(WITH_TESTING)
   set_tests_properties(selected_rows_test PROPERTIES TIMEOUT 120)
 endif()
+
+# New custom op extension mechanism related
+
+# if not deps `layer`, will cause: undefined symbol: _ZN6paddle10imperative7VarBase9name_set_
+set(PADDLE_CUSTOM_OP_MODULES custom_tensor op_meta_info custom_operator layer)
+
+cc_library(paddle_custom_op_shared
+  SHARED SRCS custom_operator.cc ../extension/src/tensor.cc ../extension/src/op_meta_info.cc
+  ${CMAKE_SOURCE_DIR}/paddle/fluid/imperative/layer.cc
+  DEPS ${PADDLE_CUSTOM_OP_MODULES})
+get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
+set_target_properties(paddle_custom_op_shared PROPERTIES OUTPUT_NAME paddle_custom_op)
+target_link_libraries(paddle_custom_op_shared ${os_dependency_modules})
+
+if (LINUX)
+  set(PADDLE_CUSTOM_OP_SHARED_LIB
+    ${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_custom_op.so
+    CACHE INTERNAL "Paddle custom op lib")
+endif()
+
+if (WIN32)
+  set(PADDLE_CUSTOM_OP_IMPORT_LIB
+    ${PADDLE_BINARY_DIR}/paddle/fluid/framework/${CMAKE_BUILD_TYPE}/paddle_custom_op.lib
+    CACHE INTERNAL "Paddle custom op import lib")
+  set(PADDLE_CUSTOM_OP_SHARED_LIB
+    ${PADDLE_BINARY_DIR}/paddle/fluid/framework/${CMAKE_BUILD_TYPE}/paddle_custom_op.dll
+    CACHE INTERNAL "Paddle custom op dll")
+endif()
+
+if(APPLE)
+  set(PADDLE_CUSTOM_OP_SHARED_LIB
+    ${PADDLE_BINARY_DIR}/paddle/fluid/framework/paddle_custom_op.dylib
+    CACHE INTERNAL "Paddle custom op lib")
+endif()
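custom_operator.cc below consumes the "<name>:<type>" strings registered through .Attrs(). A standalone mirror of the parsing contract, runnable on its own (the space-only trimming is a simplification of the string::trim_spaces helper the real code uses):

#include <cassert>
#include <string>
#include <vector>

// Sketch: split on the first ':' and trim spaces, yielding {name, type}.
std::vector<std::string> ParseAttrStrSketch(const std::string& attr) {
  auto pos = attr.find_first_of(':');
  auto trim = [](const std::string& s) {
    auto b = s.find_first_not_of(' ');
    auto e = s.find_last_not_of(' ');
    return b == std::string::npos ? std::string() : s.substr(b, e - b + 1);
  };
  return {trim(attr.substr(0, pos)), trim(attr.substr(pos + 1))};
}

int main() {
  auto parts = ParseAttrStrSketch("scale: float");
  assert(parts[0] == "scale" && parts[1] == "float");
  return 0;
}

diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index 1e2a77e915dea4e19046c68e176ba49637ece9ac..03a8cc366e7f2e8bb3baa2dd65ee609533cb8137 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -73,6 +73,24 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
   return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
 }
 
+std::vector<std::string> ParseAttrStr(const std::string& attr) {
+  auto split_pos = attr.find_first_of(":");
+  PADDLE_ENFORCE_NE(split_pos, std::string::npos,
+                    platform::errors::InvalidArgument(
+                        "Invalid attribute string format. Attribute string "
+                        "format is `<name>:<type>`."));
+
+  std::vector<std::string> rlt;
+  // 1. name
+  rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos)));
+  // 2. type
+  rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));
+
+  VLOG(1) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
+
+  return rlt;
+}
+
 }  // namespace detail
 
 ////////////////// Kernel Define ////////////////////
@@ -81,7 +99,8 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
 static void RunKernelFunc(const framework::ExecutionContext& ctx,
                           const paddle::KernelFunc& func,
                           const std::vector<std::string>& inputs,
-                          const std::vector<std::string>& outputs) {
+                          const std::vector<std::string>& outputs,
+                          const std::vector<std::string>& attrs) {
   VLOG(1) << "Custom Operator: Start run KernelFunc.";
   std::vector<paddle::Tensor> custom_ins;
   for (auto& in_name : inputs) {
@@ -98,10 +117,43 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
     custom_ins.emplace_back(custom_in);
   }
 
-  std::vector<boost::any> attrs;
+  std::vector<boost::any> custom_attrs;
+  for (auto& attr_str : attrs) {
+    auto attr_name_and_type = detail::ParseAttrStr(attr_str);
+    auto attr_name = attr_name_and_type[0];
+    auto attr_type_str = attr_name_and_type[1];
+    if (attr_type_str == "bool") {
+      custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
+    } else if (attr_type_str == "int") {
+      custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
+    } else if (attr_type_str == "float") {
+      custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
+    } else if (attr_type_str == "int64_t") {
+      custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
+    } else if (attr_type_str == "std::string") {
+      custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
+    } else if (attr_type_str == "std::vector<int>") {
+      custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
+    } else if (attr_type_str == "std::vector<float>") {
+      custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
+    } else if (attr_type_str == "std::vector<int64_t>") {
+      custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
+    } else if (attr_type_str == "std::vector<std::string>") {
+      custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported `%s` type value as custom attribute now. "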
" + "Supported data types include `bool`, `int`, `float`, " + "`int64_t`, `std::string`, `std::vector`, " + "`std::vector`, `std::vector, " + "`std::vector`, Please check whether " + "the attribute data type and data type string are matched.", + attr_type_str)); + } + } VLOG(1) << "Run ComputeFunc."; - auto outs = func(custom_ins, attrs); + auto outs = func(custom_ins, custom_attrs); VLOG(1) << "Custom Operator: Share outputs into ExecutionContext."; for (size_t i = 0; i < outputs.size(); ++i) { @@ -164,7 +216,51 @@ class CustomOpMaker : public OpProtoAndCheckerMaker { for (auto& out_name : outputs_) { AddOutput(out_name, "The output " + out_name + "of Custom Operator."); } - // TODO(chenweihang): support attrs in later PR + for (auto& attr : attrs_) { + auto attr_name_and_type = detail::ParseAttrStr(attr); + auto attr_name = attr_name_and_type[0]; + auto attr_type_str = attr_name_and_type[1]; + if (attr_type_str == "bool") { + AddAttr(attr_name, "custom operator bool attribute.") + .SetDefault(false); + } else if (attr_type_str == "int") { + AddAttr(attr_name, "custom operator int attribute.").SetDefault(1); + } else if (attr_type_str == "float") { + AddAttr(attr_name, "custom operator float attribute.") + .SetDefault(1.0f); + } else if (attr_type_str == "int64_t") { + AddAttr(attr_name, "custom operator int64_t attribute.") + .SetDefault(1); + } else if (attr_type_str == "std::string") { + AddAttr(attr_name, "custom operator int attribute.") + .SetDefault(""); + } else if (attr_type_str == "std::vector") { + AddAttr>(attr_name, + "custom operator std::vector attribute.") + .SetDefault({}); + } else if (attr_type_str == "std::vector") { + AddAttr>( + attr_name, "custom operator std::vector attribute.") + .SetDefault({}); + } else if (attr_type_str == "std::vector") { + AddAttr>( + attr_name, "custom operator std::vector attribute.") + .SetDefault({}); + } else if (attr_type_str == "std::vector") { + AddAttr>( + attr_name, "custom operator std::vector attribute.") + .SetDefault({}); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported `%s` type value as custom attribute now. " + "Supported data types include `bool`, `int`, `float`, " + "`int64_t`, `std::string`, `std::vector`, " + "`std::vector`, `std::vector, " + "`std::vector`, Please check whether " + "the attribute data type and data type string are matched.", + attr_type_str)); + } + } AddComment(R"DOC( Custom Operator. 
@@ -227,7 +323,7 @@ class CustomGradOpMaker : public SingleGradOpMaker<T> {
       VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name;
       grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
     }
-    // TODO(chenweihang): support attrs in later PR
+    grad_op->SetAttrMap(this->Attrs());
   }
 
  private:
@@ -287,7 +383,7 @@ class CustomGradOpMaker
       VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
       grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
     }
-    // TODO(chenweihang): support attrs in later PR
+    grad_op->SetAttrMap(this->Attrs());
   }
 
  private:
@@ -303,21 +399,24 @@ void RegisterOperatorKernelWithPlace(const std::string& name,
                                      const proto::VarType::Type type,
                                      const PlaceType& place,
                                      const std::vector<std::string>& inputs,
-                                     const std::vector<std::string>& outputs) {
+                                     const std::vector<std::string>& outputs,
+                                     const std::vector<std::string>& attrs) {
   OpKernelType key(type,
                    CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place));
   VLOG(1) << "Custom Operator: op kernel key: " << key;
   OperatorWithKernel::AllOpKernels()[name][key] =
-      [kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) {
+      [kernel_func, inputs, outputs,
+       attrs](const framework::ExecutionContext& ctx) {
         VLOG(1) << "Custom Operator: run custom kernel func in lambda.";
-        RunKernelFunc(ctx, kernel_func, inputs, outputs);
+        RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
       };
 }
 
 void RegisterOperatorKernel(const std::string& name,
                             const paddle::KernelFunc& kernel_func,
                             const std::vector<std::string>& inputs,
-                            const std::vector<std::string>& outputs) {
+                            const std::vector<std::string>& outputs,
+                            const std::vector<std::string>& attrs) {
   VLOG(1) << "Custom Operator: op name in kernel: " << name;
   // NOTE [ Dummy Op Kernel Key ]
   // TODO(chenweihang): Because execute engine need get device context based
@@ -325,9 +424,11 @@ void RegisterOperatorKernel(const std::string& name,
   // device. But this is not entirely correct, if user only give a cpu kernel,
   // but call api in gpu device, it will cause error.
RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, - PlaceType::kCPU, inputs, outputs); + PlaceType::kCPU, inputs, outputs, attrs); +#ifdef PADDLE_WITH_CUDA RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, - PlaceType::kGPU, inputs, outputs); + PlaceType::kGPU, inputs, outputs, attrs); +#endif } void RegisterOperatorWithMetaInfo( @@ -350,6 +451,8 @@ void RegisterOperatorWithMetaInfo( << string::join_strings(op_inputs, ','); VLOG(1) << "Custom Operator: forward, op outputs: " << string::join_strings(op_outputs, ','); + VLOG(1) << "Custom Operator: forward, op attrs: " + << string::join_strings(op_attrs, ','); // Op info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs, @@ -426,7 +529,7 @@ void RegisterOperatorWithMetaInfo( }; // Kernel func - RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs); + RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs); // If grad op or double grad op exists std::string cur_op_name = op_name; @@ -436,6 +539,7 @@ void RegisterOperatorWithMetaInfo( auto& grad_op_name = OpMetaInfoHelper::GetOpName(cur_grad_op); auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op); auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op); + auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op); auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op); VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name; @@ -489,7 +593,7 @@ void RegisterOperatorWithMetaInfo( // Kernel func RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs, - grad_op_outputs); + grad_op_outputs, grad_op_attrs); // update current info OpInfoMap::Instance().Insert(cur_op_name, info); diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index d95b812e930c7cf2d2135fd4f9d098fa0e76a7c8..8a017385efb0419da18ecb5d7355a2cf20e153f6 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -378,9 +378,6 @@ void* GetOpDsoHandle(const std::string& dso_name) { #if defined(__APPLE__) || defined(__OSX__) PADDLE_THROW(platform::errors::Unimplemented( "Create custom cpp op outside framework do not support Apple.")); -#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - PADDLE_THROW(platform::errors::Unimplemented( - "Create custom cpp op outside framework do not support Windows.")); #else return GetDsoHandleFromSearchPath(FLAGS_op_dir, dso_name); #endif diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat index f3cc61e8fbf038c660fc6e643851ad4205affd05..5f12b428ec83c1b6bd4bc234d5c0e3fad8d2ac59 100644 --- a/paddle/scripts/paddle_build.bat +++ b/paddle/scripts/paddle_build.bat @@ -114,23 +114,25 @@ rem ------pre install python requirement---------- where python where pip pip install wheel --user -pip install -r %work_dir%\python\requirements.txt --user pip install -r %work_dir%\python\unittest_py\requirements.txt --user +pip install -r %work_dir%\python\requirements.txt --user + if %ERRORLEVEL% NEQ 0 ( echo pip install requirements.txt failed! exit /b 7 ) rem ------pre install clcache and init config---------- -pip install clcache --user +rem pip install clcache --user +pip uninstall -y clcache :: set USE_CLCACHE to enable clcache -set USE_CLCACHE=1 +rem set USE_CLCACHE=1 :: In some scenarios, CLCACHE_HARDLINK can save one file copy. 
-set CLCACHE_HARDLINK=1 +rem set CLCACHE_HARDLINK=1 :: If it takes more than 1000s to obtain the right to use the cache, an error will be reported -set CLCACHE_OBJECT_CACHE_TIMEOUT_MS=1000000 +rem set CLCACHE_OBJECT_CACHE_TIMEOUT_MS=1000000 :: set maximum cache size to 20G -clcache.exe -M 21474836480 +rem clcache.exe -M 21474836480 rem ------show summary of current environment---------- python %work_dir%\tools\summary_env.py @@ -194,11 +196,28 @@ set start=%start:~4,10% @ECHO ON if not defined CUDA_TOOLKIT_ROOT_DIR set CUDA_TOOLKIT_ROOT_DIR=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0 -set PATH=%CUDA_TOOLKIT_ROOT_DIR%\bin;%CUDA_TOOLKIT_ROOT_DIR%\libnvvp;%PATH% -set CUDA_PATH=%CUDA_TOOLKIT_ROOT_DIR% +set PATH=%TENSORRT_ROOT:/=\%\lib;%CUDA_TOOLKIT_ROOT_DIR%\bin;%CUDA_TOOLKIT_ROOT_DIR%\libnvvp;%PATH% rem ------set third_party cache dir------ +: clear third party cache every once in a while +for /F %%# in ('wmic os get localdatetime^|findstr 20') do set datetime=%%# +set day_now=%datetime:~6,2% +set day_before=-1 +set /p day_before=< %cache_dir%\day.txt +if %day_now% NEQ %day_before% ( + echo %day_now% > %cache_dir%\day.txt + type %cache_dir%\day.txt + if %day_now% EQU 25 ( + rmdir %cache_dir%\third_party_GPU/ /s/q + rmdir %cache_dir%\third_party/ /s/q + ) + if %day_now% EQU 10 ( + rmdir %cache_dir%\third_party_GPU/ /s/q + rmdir %cache_dir%\third_party/ /s/q + ) +) + if "%WITH_TPCACHE%"=="OFF" ( set THIRD_PARTY_PATH=%work_dir:\=/%/build/third_party goto :cmake_impl @@ -263,6 +282,9 @@ echo Build third_party successfully! set build_times=1 :build_paddle +:: reset clcache zero stats for collect PR's actual hit rate +rem clcache.exe -z + echo Build Paddle the %build_times% time: if "%WITH_CLCACHE%"=="OFF" ( msbuild /m:%PARALLEL_PROJECT_COUNT% /p:Configuration=Release /verbosity:minimal paddle.sln @@ -281,6 +303,11 @@ if %ERRORLEVEL% NEQ 0 ( ) echo Build Paddle successfully! +echo 0 > %cache_dir%\error_code.txt +type %cache_dir%\error_code.txt + +:: ci will collect clcache hit rate +rem goto :collect_clcache_hits goto:eof @@ -319,13 +346,14 @@ set /p PADDLE_WHL_FILE_WIN=< whl_file.txt @ECHO ON pip uninstall -y paddlepaddle pip uninstall -y paddlepaddle-gpu -pip install -U %PADDLE_WHL_FILE_WIN% --user +pip install %PADDLE_WHL_FILE_WIN% --user if %ERRORLEVEL% NEQ 0 ( call paddle_winci\Scripts\deactivate.bat 2>NUL echo pip install whl package failed! exit /b 1 ) + set CUDA_VISIBLE_DEVICES=0 python %work_dir%\paddle\scripts\installation_validate.py goto:eof @@ -383,7 +411,7 @@ if "%WITH_GPU%"=="ON" ( :parallel_test_base_gpu echo ======================================== -echo Running GPU unit tests... +echo Running GPU unit tests in parallel way ... echo ======================================== setlocal enabledelayedexpansion @@ -451,6 +479,7 @@ goto:eof echo ======================================== echo Running CPU unit tests in parallel way ... 
echo ======================================== + ctest.exe -E "(%disable_ut_quickly%)" -LE %nightly_label% --output-on-failure -C Release -j 8 --repeat until-pass:4 after-timeout:4 goto:eof @@ -622,6 +651,7 @@ taskkill /f /im vctip.exe 2>NUL taskkill /f /im cvtres.exe 2>NUL taskkill /f /im rc.exe 2>NUL wmic process where name="op_function_generator.exe" call terminate 2>NUL +wmic process where name="python.exe" call terminate 2>NUL taskkill /f /im python.exe 2>NUL echo 0 > %cache_dir%\error_code.txt type %cache_dir%\error_code.txt diff --git a/python/paddle/fluid/tests/CMakeLists.txt b/python/paddle/fluid/tests/CMakeLists.txt index bee49945f0074f2e8dc1af9662878ec495d25644..899d6ae7f0e314ee05ae67ef5639d9f79410a38f 100644 --- a/python/paddle/fluid/tests/CMakeLists.txt +++ b/python/paddle/fluid/tests/CMakeLists.txt @@ -9,7 +9,9 @@ endforeach() add_subdirectory(unittests) add_subdirectory(book) -if(NOT APPLE AND NOT WIN32) +# TODO: support New Custom OP on Mac +if(NOT APPLE) add_subdirectory(custom_op) endif() + set_tests_properties(test_beam_search_decoder PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index df1dc75a38c8362ce33af5854a5c2f3c14635609..3f85f4ef50a223949ef60678b61e97be29aea471 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -1,4 +1,47 @@ -if (WITH_GPU) +# New custom OP can support Windows/Linux now +if(WITH_GPU) + # 'test_custom_relu_op_setup/jit' compile .cc and .cu file + py_test(test_custom_relu_op_setup SRCS test_custom_relu_op_setup.py) + py_test(test_custom_relu_op_jit SRCS test_custom_relu_op_jit.py) + + # Compiling shared library will cost some time, but running process is very fast. 
+ set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250) + set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180) +endif() + +py_test(test_sysconfig SRCS test_sysconfig.py) + +# 'test_dispatch' compile .cc file +py_test(test_dispatch_jit SRCS test_dispatch_jit.py) +set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 120) + +py_test(test_multi_out_jit SRCS test_multi_out_jit.py) +set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120) + +py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) +set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120) + +if(NOT LINUX) + return() +endif() + +# TODO(zhouwei): support test_check_abi and abi check on Windows +py_test(test_check_abi SRCS test_check_abi.py) + +# Old custom OP only support Linux, only run on Linux +py_test(test_custom_op SRCS test_custom_op.py) +py_test(test_jit_load SRCS test_jit_load.py) +py_test(test_setup_install SRCS test_setup_install.py) +py_test(test_setup_build SRCS test_setup_build.py) + +set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180) +set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180) +set_tests_properties(test_setup_build PROPERTIES TIMEOUT 180) + + +if(WITH_ROCM) + hip_library(relu_op_shared SHARED SRCS relu_op.cc relu_op.cu DEPS paddle_framework_shared) +elseif(WITH_GPU) nv_library(relu_op_shared SHARED SRCS relu_op.cc relu_op.cu DEPS paddle_framework_shared) else() cc_library(relu_op_shared SHARED SRCS relu_op.cc DEPS paddle_framework_shared) @@ -16,19 +59,3 @@ get_target_property(TARGET_LIBRARIES relu_op_shared LINK_LIBRARIES) LIST(REMOVE_ITEM TARGET_LIBRARIES glog) LIST(REMOVE_ITEM TARGET_LIBRARIES gflags) set_property(TARGET relu_op_shared PROPERTY LINK_LIBRARIES ${TARGET_LIBRARIES} ) - -file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") -string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") - -foreach(src ${TEST_OPS}) - py_test(${src} SRCS ${src}.py) -endforeach() - -# Compiling .so will cost some time, but running process is very fast. -set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180) -set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180) -set_tests_properties(test_setup_build PROPERTIES TIMEOUT 180) -set_tests_properties(test_dispatch PROPERTIES TIMEOUT 180) - -set_tests_properties(test_simple_custom_op_setup PROPERTIES TIMEOUT 250) -set_tests_properties(test_simple_custom_op_jit PROPERTIES TIMEOUT 180) diff --git a/python/paddle/fluid/tests/custom_op/attr_test_op.cc b/python/paddle/fluid/tests/custom_op/attr_test_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..474d3d2d4e2b3b566620a11d41564fb662bd35e3 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/attr_test_op.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "paddle/extension.h"
+
+template <typename data_t>
+void assign_cpu_kernel(const data_t* x_data,
+                       data_t* out_data,
+                       int64_t x_numel) {
+  for (int i = 0; i < x_numel; ++i) {
+    out_data[i] = x_data[i];
+  }
+}
+
+std::vector<paddle::Tensor> AttrTestForward(
+    const paddle::Tensor& x,
+    bool bool_attr,
+    int int_attr,
+    float float_attr,
+    int64_t int64_attr,
+    std::string str_attr,
+    std::vector<int> int_vec_attr,
+    std::vector<float> float_vec_attr,
+    std::vector<int64_t> int64_vec_attr,
+    std::vector<std::string> str_vec_attr) {
+  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
+  out.reshape(x.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      x.type(), "assign_cpu_kernel", ([&] {
+        assign_cpu_kernel<data_t>(
+            x.data<data_t>(), out.mutable_data<data_t>(), x.size());
+      }));
+
+  // Check attrs value
+  if (bool_attr != true) {
+    throw std::runtime_error("bool_attr value error.");
+  }
+  if (int_attr != 10) {
+    throw std::runtime_error("int_attr value error.");
+  }
+  if (std::abs(float_attr - 3.14) > 1e-6) {
+    throw std::runtime_error("float_attr value error.");
+  }
+  if (int64_attr != 10000000000) {
+    throw std::runtime_error("int64_attr value error.");
+  }
+  if (str_attr != "StrAttr") {
+    throw std::runtime_error("str_attr value error.");
+  }
+
+  if (int_vec_attr.size() != 3) {
+    throw std::runtime_error("int_vec_attr size error.");
+  } else {
+    for (auto& value : int_vec_attr) {
+      if (value != 10) {
+        throw std::runtime_error("int_vec_attr value error.");
+      }
+    }
+  }
+
+  if (float_vec_attr.size() != 3) {
+    throw std::runtime_error("float_vec_attr size error.");
+  } else {
+    for (auto& value : float_vec_attr) {
+      if (std::abs(value - 3.14) > 1e-6) {
+        throw std::runtime_error("float_vec_attr value error.");
+      }
+    }
+  }
+
+  if (int64_vec_attr.size() != 3) {
+    throw std::runtime_error("int64_vec_attr size error.");
+  } else {
+    for (auto& value : int64_vec_attr) {
+      if (value != 10000000000) {
+        throw std::runtime_error("int64_vec_attr value error.");
+      }
+    }
+  }
+
+  if (str_vec_attr.size() != 3) {
+    throw std::runtime_error("str_vec_attr size error.");
+  } else {
+    for (auto& value : str_vec_attr) {
+      if (value != "StrAttr") {
+        throw std::runtime_error("str_vec_attr value error.");
+      }
+    }
+  }
+
+  return {out};
+}
+
+// The attrs of backward op must be the subset of attrs of forward op
+std::vector<paddle::Tensor> AttrTestBackward(
+    const paddle::Tensor& grad_out,
+    int int_attr,
+    std::vector<float> float_vec_attr,
+    std::vector<std::string> str_vec_attr) {
+  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
+  grad_x.reshape(grad_out.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(grad_out.type(), "assign_cpu_kernel", ([&] {
+                               assign_cpu_kernel<data_t>(
+                                   grad_out.data<data_t>(),
+                                   grad_x.mutable_data<data_t>(),
+                                   grad_out.size());
+                             }));
+
+  if (int_attr != 10) {
+    throw std::runtime_error("int_attr value error.");
+  }
+
+  if (float_vec_attr.size() != 3) {
+    throw std::runtime_error("float_vec_attr size error.");
+  } else {
+    for (auto& value : float_vec_attr) {
+      if (std::abs(value - 3.14) > 1e-6) {
+        throw std::runtime_error("float_vec_attr value error.");
+      }
+    }
+  }
+
+  if (str_vec_attr.size() != 3) {
+    throw std::runtime_error("str_vec_attr size error.");
+  } else {
+    for (auto& value : str_vec_attr) {
+      if (value != "StrAttr") {
+        throw std::runtime_error("str_vec_attr value error.");
+      }
+    }
+  }
+
+  return {grad_x};
+}
+
+std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
+  return {x_shape};
+}
+
+std::vector<paddle::DataType> InferDType(paddle::DataType x_dtype) {
+  return {x_dtype};
+}
+
"float_attr: float", + "int64_attr: int64_t", + "str_attr: std::string", + "int_vec_attr: std::vector", + "float_vec_attr: std::vector", + "int64_vec_attr: std::vector", + "str_vec_attr: std::vector"}) + .SetKernelFn(PD_KERNEL(AttrTestForward)) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDType)) + .SetBackwardOp("attr_test_grad") + .Inputs({paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .Attrs({"int_attr: int", + "float_vec_attr: std::vector", + "str_vec_attr: std::vector"}) + .SetKernelFn(PD_KERNEL(AttrTestBackward)); diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc similarity index 81% rename from python/paddle/fluid/tests/custom_op/relu_op_simple.cc rename to python/paddle/fluid/tests/custom_op/custom_relu_op.cc index b02ecba6826fa0b9dc4bcc6db07b9b8717f834aa..0e358e24ae3e814b3fd21d010c478812aa0b8340 100644 --- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc @@ -17,13 +17,6 @@ #include "paddle/extension.h" -template -void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) { - for (int i = 0; i < x_numel; ++i) { - out_data[i] = value; - } -} - template void relu_cpu_forward_kernel(const data_t* x_data, data_t* out_data, @@ -53,21 +46,8 @@ std::vector relu_cpu_forward(const paddle::Tensor& x) { relu_cpu_forward_kernel( x.data(), out.mutable_data(x.place()), x.size()); })); - // fake multi output: Fake_float64 with float64 dtype - auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU); - fake_float64.reshape(x.shape()); - - fill_constant_cpu_kernel( - fake_float64.mutable_data(x.place()), x.size(), 0.); - - // fake multi output: ZFake_int32 with int32 dtype - auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU); - zfake_int32.reshape(x.shape()); - - fill_constant_cpu_kernel( - zfake_int32.mutable_data(x.place()), x.size(), 1); - return {out, fake_float64, zfake_int32}; + return {out}; } std::vector relu_cpu_backward(const paddle::Tensor& x, @@ -117,16 +97,16 @@ std::vector ReluBackward(const paddle::Tensor& x, } std::vector> ReluInferShape(std::vector x_shape) { - return {x_shape, x_shape, x_shape}; + return {x_shape}; } std::vector ReluInferDType(paddle::DataType x_dtype) { - return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32}; + return {x_dtype}; } -PD_BUILD_OP("relu2") +PD_BUILD_OP("custom_relu") .Inputs({"X"}) - .Outputs({"Out", "Fake_float64", "ZFake_int32"}) + .Outputs({"Out"}) .SetKernelFn(PD_KERNEL(ReluForward)) .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType)) diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu similarity index 75% rename from python/paddle/fluid/tests/custom_op/relu_op_simple.cu rename to python/paddle/fluid/tests/custom_op/custom_relu_op.cu index 2ef6a5c1451e7409faca951cf9542c20cef59466..a9ce5176070939be24a8e6d965faa60b6f391bff 100644 --- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu @@ -14,16 +14,6 @@ #include "paddle/extension.h" -template -__global__ void fill_constant_cuda_kernel(data_t* y, - const int num, - data_t value) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int i = gid; i < num; i += blockDim.x * gridDim.x) { - y[i] = value; - } -} - template __global__ void relu_cuda_forward_kernel(const data_t* x, data_t* y, 
diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
similarity index 75%
rename from python/paddle/fluid/tests/custom_op/relu_op_simple.cu
rename to python/paddle/fluid/tests/custom_op/custom_relu_op.cu
index 2ef6a5c1451e7409faca951cf9542c20cef59466..a9ce5176070939be24a8e6d965faa60b6f391bff 100644
--- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -14,16 +14,6 @@
 
 #include "paddle/extension.h"
 
-template <typename data_t>
-__global__ void fill_constant_cuda_kernel(data_t* y,
-                                          const int num,
-                                          data_t value) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
-    y[i] = value;
-  }
-}
-
 template <typename data_t>
 __global__ void relu_cuda_forward_kernel(const data_t* x,
                                          data_t* y,
@@ -57,18 +47,8 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
         relu_cuda_forward_kernel<data_t><<<grid, block>>>(
            x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
       }));
-  // fake multi output: Fake_1
-  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kGPU);
-  fake_float64.reshape(x.shape());
-  fill_constant_cuda_kernel<double><<<grid, block>>>(
-      fake_float64.mutable_data<double>(x.place()), numel, 0.);
-  // fake multi output: ZFake_1
-  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kGPU);
-  zfake_int32.reshape(x.shape());
-  fill_constant_cuda_kernel<int32_t><<<grid, block>>>(
-      zfake_int32.mutable_data<int32_t>(x.place()), numel, 1);
 
-  return {out, fake_float64, zfake_int32};
+  return {out};
 }
 
 std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
diff --git a/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op_dup.cc
similarity index 92%
rename from python/paddle/fluid/tests/custom_op/relu_op3_simple.cc
rename to python/paddle/fluid/tests/custom_op/custom_relu_op_dup.cc
index ec64bce18736b7b3ef67b5ba58f9429913b3a92e..7319bdd76264508ef485ba80382aac6dcbbeb4b6 100644
--- a/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op_dup.cc
@@ -29,11 +29,11 @@ std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape);
 
 std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
 
-// Reuse codes in `relu_op_simple.cc/cu` to register another custom operator
+// Reuse codes in `custom_relu_op.cc/cu` to register another custom operator
 // to test jointly compile multi operators at same time.
-PD_BUILD_OP("relu3")
+PD_BUILD_OP("custom_relu_dup")
     .Inputs({"X"})
-    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
+    .Outputs({"Out"})
     .SetKernelFn(PD_KERNEL(ReluForward))
     .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
diff --git a/python/paddle/fluid/tests/custom_op/setup_install_simple.py b/python/paddle/fluid/tests/custom_op/custom_relu_setup.py
similarity index 79%
rename from python/paddle/fluid/tests/custom_op/setup_install_simple.py
rename to python/paddle/fluid/tests/custom_op/custom_relu_setup.py
index ed236ccbd4c59c830eaa04761b23135fb8bc16d2..598b850c876e2ff341d47e1afcf7ec6534163865 100644
--- a/python/paddle/fluid/tests/custom_op/setup_install_simple.py
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_setup.py
@@ -17,11 +17,14 @@ import os
 from utils import paddle_includes, extra_compile_args
 from paddle.utils.cpp_extension import CUDAExtension, setup
 
+# custom_relu_op_dup.cc is only used for multi ops test,
+# not a new op, if you want to test only one op, remove this
+# source file
 setup(
-    name='simple_setup_relu2',
+    name='custom_relu_module_setup',
     ext_modules=CUDAExtension(  # test for not specific name here.
         sources=[
-            'relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'
+            'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
         ],  # test for multi ops
         include_dirs=paddle_includes,
         extra_compile_args=extra_compile_args))
diff --git a/python/paddle/fluid/tests/custom_op/multi_out_test_op.cc b/python/paddle/fluid/tests/custom_op/multi_out_test_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bece0f49845a5ae3fd006ccf383adb78f043bd4b
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_op/multi_out_test_op.cc
@@ -0,0 +1,76 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <vector>
+
+#include "paddle/extension.h"
+
+template <typename data_t>
+void assign_cpu_kernel(const data_t* x_data,
+                       data_t* out_data,
+                       int64_t x_numel) {
+  for (int i = 0; i < x_numel; ++i) {
+    out_data[i] = x_data[i];
+  }
+}
+
+template <typename data_t>
+void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) {
+  for (int i = 0; i < x_numel; ++i) {
+    out_data[i] = value;
+  }
+}
+
+std::vector<paddle::Tensor> MultiOutCPU(const paddle::Tensor& x) {
+  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
+  out.reshape(x.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      x.type(), "assign_cpu_kernel", ([&] {
+        assign_cpu_kernel<data_t>(
+            x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
+      }));
+
+  // fake multi output: Fake_float64 with float64 dtype
+  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU);
+  fake_float64.reshape(x.shape());
+
+  fill_constant_cpu_kernel<double>(
+      fake_float64.mutable_data<double>(x.place()), x.size(), 0.);
+
+  // fake multi output: ZFake_int32 with int32 dtype
+  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU);
+  zfake_int32.reshape(x.shape());
+
+  fill_constant_cpu_kernel<int32_t>(
+      zfake_int32.mutable_data<int32_t>(x.place()), x.size(), 1);
+
+  return {out, fake_float64, zfake_int32};
+}
+
+std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
+  return {x_shape, x_shape, x_shape};
+}
+
+std::vector<paddle::DataType> InferDtype(paddle::DataType x_dtype) {
+  return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
+}
+
+PD_BUILD_OP("multi_out")
+    .Inputs({"X"})
+    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
+    .SetKernelFn(PD_KERNEL(MultiOutCPU))
+    .SetInferShapeFn(PD_INFER_SHAPE(InferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..754f76cab86f083923423652055be191982e5b14
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+import numpy as np
+
+import paddle
+from paddle.utils.cpp_extension import load, get_build_directory
+from utils import paddle_includes, extra_compile_args
+from paddle.utils.cpp_extension.extension_utils import run_cmd
+
+# Because Windows don't use docker, the shared lib already exists in the
+# cache dir, it will not be compiled again unless the shared lib is removed.
+file = '{}\\custom_attrs_jit\\custom_attrs_jit.pyd'.format(get_build_directory( +)) +if os.name == 'nt' and os.path.isfile(file): + cmd = 'del {}'.format(file) + run_cmd(cmd, True) + +# Compile and load custom op Just-In-Time. +custom_attrs = load( + name='custom_attrs_jit', + sources=['attr_test_op.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_compile_args, # add for Coverage CI + verbose=True) + + +class TestJitCustomAttrs(unittest.TestCase): + def test_attr_value(self): + paddle.set_device('cpu') + # prepare test value + bool_attr = True + int_attr = 10 + float_attr = 3.14 + int64_attr = 10000000000 + str_attr = "StrAttr" + int_vec_attr = [10, 10, 10] + float_vec_attr = [3.14, 3.14, 3.14] + int64_vec_attr = [10000000000, 10000000000, 10000000000] + str_vec_attr = ["StrAttr", "StrAttr", "StrAttr"] + + x = paddle.ones([2, 2], dtype='float32') + x.stop_gradient = False + out = custom_attrs.attr_test( + x, bool_attr, int_attr, float_attr, int64_attr, str_attr, + int_vec_attr, float_vec_attr, int64_vec_attr, str_vec_attr) + out.stop_gradient = False + out.backward() + + self.assertTrue(np.array_equal(x.numpy(), out.numpy())) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..9c108a799d955f9c92be057e0720b07a2c16ed9b --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import unittest +import paddle +import numpy as np +from paddle.utils.cpp_extension import load, get_build_directory +from paddle.utils.cpp_extension.extension_utils import run_cmd +from utils import paddle_includes, extra_compile_args +from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = '{}\\custom_relu_module_jit\\custom_relu_module_jit.pyd'.format( + get_build_directory()) +if os.name == 'nt' and os.path.isfile(file): + cmd = 'del {}'.format(file) + run_cmd(cmd, True) + +# Compile and load custom op Just-In-Time. 
+# custom_relu_op_dup.cc is only used for the multi-op test;
+# it does not define a new op. If you want to test a single op,
+# remove this source file.
+custom_module = load(
+    name='custom_relu_module_jit',
+    sources=[
+        'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
+    ],
+    extra_include_paths=paddle_includes,  # add for Coverage CI
+    extra_cxx_cflags=extra_compile_args,  # add for Coverage CI
+    extra_cuda_cflags=extra_compile_args,  # add for Coverage CI
+    verbose=True)
+
+
+class TestJITLoad(unittest.TestCase):
+    def setUp(self):
+        self.custom_ops = [
+            custom_module.custom_relu, custom_module.custom_relu_dup
+        ]
+        self.dtypes = ['float32', 'float64']
+        self.devices = ['cpu', 'gpu']
+
+    def test_static(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                for custom_op in self.custom_ops:
+                    out = custom_relu_static(custom_op, device, dtype, x)
+                    pd_out = custom_relu_static(custom_op, device, dtype, x,
+                                                False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
+
+    def test_dynamic(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                for custom_op in self.custom_ops:
+                    out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
+                                                      x)
+                    pd_out, pd_x_grad = custom_relu_dynamic(custom_op, device,
+                                                            dtype, x, False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
+                    self.assertTrue(
+                        np.array_equal(x_grad, pd_x_grad),
+                        "custom op x grad: {},\n paddle api x grad: {}".format(
+                            x_grad, pd_x_grad))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
similarity index 50%
rename from python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py
rename to python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
index cfa2db0ba24a49a20b825e47d2b90077c3b6d463..6781915e021c92f4c0f6a25e9f42ab940a3035d2 100644
--- a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
@@ -23,13 +23,13 @@ import numpy as np
 from paddle.utils.cpp_extension.extension_utils import run_cmd
 
 
-def relu2_dynamic(func, device, dtype, np_x, use_func=True):
+def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
     paddle.set_device(device)
 
     t = paddle.to_tensor(np_x)
     t.stop_gradient = False
 
-    out = func(t)[0] if use_func else paddle.nn.functional.relu(t)
+    out = func(t) if use_func else paddle.nn.functional.relu(t)
     out.stop_gradient = False
 
     out.backward()
@@ -37,7 +37,12 @@ def relu2_dynamic(func, device, dtype, np_x, use_func=True):
     return out.numpy(), t.grad
 
 
-def relu2_static(func, device, dtype, np_x, use_func=True):
+def custom_relu_static(func,
+                       device,
+                       dtype,
+                       np_x,
+                       use_func=True,
+                       test_infer=False):
     paddle.enable_static()
     paddle.set_device(device)
 
@@ -45,8 +50,7 @@ def relu2_static(func, device, dtype, np_x, use_func=True):
     with static.program_guard(static.Program()):
         x = static.data(name='X', shape=[None, 8], dtype=dtype)
         x.stop_gradient = False
-        # out, fake_float64, fake_int32
-        out = func(x)[0] if use_func else paddle.nn.functional.relu(x)
+        out = func(x) if use_func else paddle.nn.functional.relu(x)
         static.append_backward(out)
 
         exe = static.Executor()
@@ -60,7 +64,7 @@ def relu2_static(func, device, dtype, np_x, use_func=True):
     return out_v
 
 
-def relu2_static_pe(func, device, dtype, np_x, use_func=True):
+def custom_relu_static_pe(func, device, dtype, np_x, use_func=True):
     paddle.enable_static()
     paddle.set_device(device)
 
@@ -69,7 +73,7 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True):
     with static.program_guard(static.Program()):
         x = static.data(name='X', shape=[None, 8], dtype=dtype)
         x.stop_gradient = False
-        out = func(x)[0] if use_func else paddle.nn.functional.relu(x)
+        out = func(x) if use_func else paddle.nn.functional.relu(x)
         static.append_backward(out)
 
         exe = static.Executor()
@@ -87,11 +91,58 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True):
     return out_v
 
 
+def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
+    paddle.set_device(device)
+
+    with static.scope_guard(static.Scope()):
+        with static.program_guard(static.Program()):
+            # simple module
+            data = static.data(
+                name='data', shape=[None, 1, 28, 28], dtype='float32')
+            label = static.data(name='label', shape=[None, 1], dtype='int64')
+
+            hidden = static.nn.fc(data, size=128)
+            hidden = func(hidden)
+            hidden = static.nn.fc(hidden, size=128)
+            predict = static.nn.fc(hidden, size=10, activation='softmax')
+            loss = paddle.nn.functional.cross_entropy(input=hidden, label=label)
+            avg_loss = paddle.mean(loss)
+
+            opt = paddle.optimizer.SGD(learning_rate=0.1)
+            opt.minimize(avg_loss)
+
+            # run the startup program
+            exe = static.Executor()
+            exe.run(static.default_startup_program())
+
+            # train
+            for i in range(4):
+                avg_loss_v = exe.run(static.default_main_program(),
+                                     feed={'data': np_data,
+                                           'label': np_label},
+                                     fetch_list=[avg_loss])
+
+            # save inference model
+            static.save_inference_model(path_prefix, [data], [predict], exe)
+
+            # get the predictions of the trained program
+            predict_v = exe.run(static.default_main_program(),
+                                feed={'data': np_data,
+                                      'label': np_label},
+                                fetch_list=[predict])
+
+    return predict_v
+
+
 class TestNewCustomOpSetUpInstall(unittest.TestCase):
     def setUp(self):
         cur_dir = os.path.dirname(os.path.abspath(__file__))
         # compile, install the custom op egg into site-packages under background
-        cmd = 'cd {} && python setup_install_simple.py install'.format(cur_dir)
+        if os.name == 'nt':
+            cmd = 'cd /d {} && python custom_relu_setup.py install'.format(
+                cur_dir)
+        else:
+            cmd = 'cd {} && python custom_relu_setup.py install'.format(cur_dir)
         run_cmd(cmd)
 
         # NOTE(Aurelius84): Normally, it's no need to add following codes for users.
@@ -99,28 +150,42 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
         # sys.path has been updated. So we update it manually.
 
         # See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
-        site_dir = site.getsitepackages()[0]
+        if os.name == 'nt':
+            # NOTE(zhouwei25): getsitepackages on windows will return a list: [python install dir, site packages dir]
+            site_dir = site.getsitepackages()[1]
+        else:
+            site_dir = site.getsitepackages()[0]
         custom_egg_path = [
-            x for x in os.listdir(site_dir) if 'simple_setup_relu2' in x
+            x for x in os.listdir(site_dir) if 'custom_relu_module_setup' in x
         ]
         assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
            custom_egg_path)
         sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
 
         # usage: import the package directly
-        import simple_setup_relu2
-        self.custom_ops = [simple_setup_relu2.relu2, simple_setup_relu2.relu3]
+        import custom_relu_module_setup
+        # `custom_relu_dup` is the same operator as `custom_relu`
+        self.custom_ops = [
+            custom_relu_module_setup.custom_relu,
+            custom_relu_module_setup.custom_relu_dup
+        ]
 
         self.dtypes = ['float32', 'float64']
         self.devices = ['cpu', 'gpu']
 
+        # config seed
+        SEED = 2021
+        paddle.seed(SEED)
+        paddle.framework.random._manual_program_seed(SEED)
+
     def test_static(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
-                    out = relu2_static(custom_op, device, dtype, x)
-                    pd_out = relu2_static(custom_op, device, dtype, x, False)
+                    out = custom_relu_static(custom_op, device, dtype, x)
+                    pd_out = custom_relu_static(custom_op, device, dtype, x,
+                                                False)
                     self.assertTrue(
                         np.array_equal(out, pd_out),
                         "custom op out: {},\n paddle api out: {}".format(
@@ -131,8 +196,9 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
-                    out = relu2_static_pe(custom_op, device, dtype, x)
-                    pd_out = relu2_static_pe(custom_op, device, dtype, x, False)
+                    out = custom_relu_static_pe(custom_op, device, dtype, x)
+                    pd_out = custom_relu_static_pe(custom_op, device, dtype, x,
+                                                   False)
                     self.assertTrue(
                         np.array_equal(out, pd_out),
                         "custom op out: {},\n paddle api out: {}".format(
@@ -143,9 +209,10 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
-                    out, x_grad = relu2_dynamic(custom_op, device, dtype, x)
-                    pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype,
-                                                      x, False)
+                    out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
+                                                      x)
+                    pd_out, pd_x_grad = custom_relu_dynamic(custom_op, device,
+                                                            dtype, x, False)
                     self.assertTrue(
                         np.array_equal(out, pd_out),
                         "custom op out: {},\n paddle api out: {}".format(
@@ -155,6 +222,28 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                         "custom op x grad: {},\n paddle api x grad: {}".format(
                             x_grad, pd_x_grad))
 
+    def test_static_save_and_load_inference_model(self):
+        paddle.enable_static()
+        np_data = np.random.random((1, 1, 28, 28)).astype("float32")
+        np_label = np.random.random((1, 1)).astype("int64")
+        path_prefix = "custom_op_inference/custom_relu"
+        for device in self.devices:
+            predict = custom_relu_static_inference(
+                self.custom_ops[0], device, np_data, np_label, path_prefix)
+            # load inference model
+            with static.scope_guard(static.Scope()):
+                exe = static.Executor()
+                [inference_program, feed_target_names,
+                 fetch_targets] = static.load_inference_model(path_prefix, exe)
+                predict_infer = exe.run(inference_program,
+                                        feed={feed_target_names[0]: np_data},
+                                        fetch_list=fetch_targets)
+                self.assertTrue(
+                    np.array_equal(predict, predict_infer),
+                    "custom op predict: {},\n custom op infer predict: {}".
+                    format(predict, predict_infer))
+        paddle.disable_static()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/custom_op/test_dispatch.py b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
similarity index 84%
rename from python/paddle/fluid/tests/custom_op/test_dispatch.py
rename to python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
index 1766a6042f395f34a39fc6da8d93646ca6b50597..54d317c37faa9019cfec93f18c5e88cedf5ddc9f 100644
--- a/python/paddle/fluid/tests/custom_op/test_dispatch.py
+++ b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
@@ -16,14 +16,23 @@ import os
 import unittest
 import paddle
 import numpy as np
-from paddle.utils.cpp_extension import load
+from paddle.utils.cpp_extension import load, get_build_directory
 from utils import paddle_includes, extra_compile_args
+from paddle.utils.cpp_extension.extension_utils import run_cmd
+
+# Because Windows doesn't use docker, the shared lib already exists in the
+# cache dir; it will not be compiled again unless the shared lib is removed.
+file = '{}\\dispatch_op\\dispatch_op.pyd'.format(get_build_directory())
+if os.name == 'nt' and os.path.isfile(file):
+    cmd = 'del {}'.format(file)
+    run_cmd(cmd, True)
 
 dispatch_op = load(
     name='dispatch_op',
     sources=['dispatch_test_op.cc'],
     extra_include_paths=paddle_includes,  # add for Coverage CI
-    extra_cflags=extra_compile_args)  # add for Coverage CI
+    extra_cxx_cflags=extra_compile_args,
+    verbose=True)
 
 
 class TestJitDispatch(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/custom_op/test_jit_load.py b/python/paddle/fluid/tests/custom_op/test_jit_load.py
index 222c69f5edcc56747bf3819ac4bfd5b1915e3ded..ccb9544433488113bbda734ebc6e443cfbdb4be7 100644
--- a/python/paddle/fluid/tests/custom_op/test_jit_load.py
+++ b/python/paddle/fluid/tests/custom_op/test_jit_load.py
@@ -29,7 +29,8 @@ custom_module = load(
     sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'],
     interpreter='python',  # add for unittest
     extra_include_paths=paddle_includes,  # add for Coverage CI
-    extra_cflags=extra_compile_args,  # add for Coverage CI
+    extra_cxx_cflags=extra_compile_args,  # add for Coverage CI
+    extra_cuda_cflags=extra_compile_args,  # add for split cpp/cuda flags
     verbose=True  # add for unittest
 )
 
diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py b/python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
similarity index 60%
rename from python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py
rename to python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
index 2c0dc1a4ca6a119c1dc9dd0bf8add15e677aaf43..79d366cc4af448191d073bdd141ff6e3a1a9d379 100644
--- a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
@@ -13,81 +13,54 @@
 # limitations under the License.
 
 import os
+import subprocess
 import unittest
-import paddle
 import numpy as np
+
+import paddle
 from paddle.utils.cpp_extension import load
+from paddle.utils.cpp_extension import get_build_directory
+from paddle.utils.cpp_extension.extension_utils import run_cmd
 from utils import paddle_includes, extra_compile_args
-from test_simple_custom_op_setup import relu2_dynamic, relu2_static
+
+# Because Windows doesn't use docker, the shared lib already exists in the
+# cache dir; it will not be compiled again unless the shared lib is removed.
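The renamed test below exercises the three-output kernel registered in multi_out_test_op.cc at the top of this patch. For orientation, a minimal sketch of consuming that op once it is built (a hedged illustration: compile flags and include paths are omitted, and it assumes the environment described in the tests):

    import numpy as np
    import paddle
    from paddle.utils.cpp_extension import load

    # Build and load the op from the C++ source shown earlier in this patch.
    multi_out_module = load(
        name='multi_out_jit', sources=['multi_out_test_op.cc'])

    paddle.set_device('cpu')
    x = paddle.ones([4, 8], dtype='float32')
    # Outputs arrive in registration order: Out, Fake_float64, ZFake_int32.
    out, fake_float64, zfake_int32 = multi_out_module.multi_out(x)
    assert np.array_equal(out.numpy(), x.numpy())  # Out copies X
    assert np.all(fake_float64.numpy() == 0.0)     # float64 zeros
    assert np.all(zfake_int32.numpy() == 1)        # int32 ones

The test applies the same checks through `check_multi_outputs` in both static and dynamic modes.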
+file = '{}\\multi_out_jit\\multi_out_jit.pyd'.format(get_build_directory()) +if os.name == 'nt' and os.path.isfile(file): + cmd = 'del {}'.format(file) + run_cmd(cmd, True) # Compile and load custom op Just-In-Time. -custom_module = load( - name='simple_jit_relu2', - sources=['relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'], +multi_out_module = load( + name='multi_out_jit', + sources=['multi_out_test_op.cc'], extra_include_paths=paddle_includes, # add for Coverage CI - extra_cflags=extra_compile_args) # add for Coverage CI - - -class TestJITLoad(unittest.TestCase): - def setUp(self): - self.custom_ops = [custom_module.relu2, custom_module.relu3] - self.dtypes = ['float32', 'float64'] - self.devices = ['cpu', 'gpu'] - - def test_static(self): - for device in self.devices: - for dtype in self.dtypes: - x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) - for custom_op in self.custom_ops: - out = relu2_static(custom_op, device, dtype, x) - pd_out = relu2_static(custom_op, device, dtype, x, False) - self.assertTrue( - np.array_equal(out, pd_out), - "custom op out: {},\n paddle api out: {}".format( - out, pd_out)) - - def test_dynamic(self): - for device in self.devices: - for dtype in self.dtypes: - x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) - for custom_op in self.custom_ops: - out, x_grad = relu2_dynamic(custom_op, device, dtype, x) - pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype, - x, False) - self.assertTrue( - np.array_equal(out, pd_out), - "custom op out: {},\n paddle api out: {}".format( - out, pd_out)) - self.assertTrue( - np.array_equal(x_grad, pd_x_grad), - "custom op x grad: {},\n paddle api x grad: {}".format( - x_grad, pd_x_grad)) + extra_cxx_cflags=extra_compile_args, # add for Coverage CI + verbose=True) class TestMultiOutputDtypes(unittest.TestCase): def setUp(self): - self.custom_op = custom_module.relu2 + self.custom_op = multi_out_module.multi_out self.dtypes = ['float32', 'float64'] - self.devices = ['cpu', 'gpu'] + self.devices = ['cpu'] - def test_static(self): - paddle.enable_static() - for device in self.devices: - for dtype in self.dtypes: - res = self.run_static(device, dtype) - self.check_multi_outputs(res) - paddle.disable_static() + def run_static(self, device, dtype): + paddle.set_device(device) + x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype) - def test_dynamic(self): - for device in self.devices: - for dtype in self.dtypes: - paddle.set_device(device) - x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype) - x = paddle.to_tensor(x_data) + with paddle.static.scope_guard(paddle.static.Scope()): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype) outs = self.custom_op(x) - self.assertTrue(len(outs) == 3) - self.check_multi_outputs(outs, True) + exe = paddle.static.Executor() + exe.run(paddle.static.default_startup_program()) + res = exe.run(paddle.static.default_main_program(), + feed={'X': x_data}, + fetch_list=outs) + + return res def check_multi_outputs(self, outs, is_dynamic=False): out, zero_float64, one_int32 = outs @@ -103,22 +76,24 @@ class TestMultiOutputDtypes(unittest.TestCase): self.assertTrue( np.array_equal(one_int32, np.ones([4, 8]).astype('int32'))) - def run_static(self, device, dtype): - paddle.set_device(device) - x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + def test_static(self): + paddle.enable_static() + for device in self.devices: + for dtype in self.dtypes: + res = self.run_static(device, dtype) + 
self.check_multi_outputs(res) + paddle.disable_static() - with paddle.static.scope_guard(paddle.static.Scope()): - with paddle.static.program_guard(paddle.static.Program()): - x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype) + def test_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + paddle.set_device(device) + x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + x = paddle.to_tensor(x_data) outs = self.custom_op(x) - exe = paddle.static.Executor() - exe.run(paddle.static.default_startup_program()) - res = exe.run(paddle.static.default_main_program(), - feed={'X': x_data}, - fetch_list=outs) - - return res + self.assertTrue(len(outs) == 3) + self.check_multi_outputs(outs, True) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/custom_op/utils.py b/python/paddle/fluid/tests/custom_op/utils.py index f293c751942cda432400ce1786326eb14cf6a9b2..52b294dc72b4ba08ef380f954fd39cc5577918b5 100644 --- a/python/paddle/fluid/tests/custom_op/utils.py +++ b/python/paddle/fluid/tests/custom_op/utils.py @@ -23,8 +23,8 @@ site_packages_path = get_python_lib() # paddle include directory. Because the following path is generated after insalling # PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. paddle_includes = [ - os.path.join(site_packages_path, 'paddle/include'), - os.path.join(site_packages_path, 'paddle/include/third_party') + os.path.join(site_packages_path, 'paddle', 'include'), + os.path.join(site_packages_path, 'paddle', 'include', 'third_party') ] # TODO(Aurelius84): Memory layout is different if build paddle with PADDLE_WITH_MKLDNN=ON, diff --git a/python/paddle/utils/cpp_extension/__init__.py b/python/paddle/utils/cpp_extension/__init__.py index 024fbb6bf7c4ed1a156322c83f34aed59e8da25b..130ab79b3038df026b3eeabcef45eae192aba78c 100644 --- a/python/paddle/utils/cpp_extension/__init__.py +++ b/python/paddle/utils/cpp_extension/__init__.py @@ -25,6 +25,5 @@ from . import cpp_extension from . 
import extension_utils
 
 __all__ = [
-    'CppExtension', 'CUDAExtension', 'BuildExtension', 'load', 'setup',
-    'get_build_directory'
+    'CppExtension', 'CUDAExtension', 'load', 'setup', 'get_build_directory'
 ]
diff --git a/python/paddle/utils/cpp_extension/cpp_extension.py b/python/paddle/utils/cpp_extension/cpp_extension.py
index 121c1626125af9974519e30ac87d8130c7466f25..57bcea658b53c400234afd22d4d5acc77f7f43ce 100644
--- a/python/paddle/utils/cpp_extension/cpp_extension.py
+++ b/python/paddle/utils/cpp_extension/cpp_extension.py
@@ -14,47 +14,124 @@
 
 import os
 import six
-import sys
-import textwrap
 import copy
+import re
 
 import setuptools
 from setuptools.command.easy_install import easy_install
 from setuptools.command.build_ext import build_ext
+from distutils.command.build import build
 
 from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag, bootstrap_context
-from .extension_utils import is_cuda_file, prepare_unix_cflags, add_std_without_repeat, get_build_directory
+from .extension_utils import is_cuda_file, prepare_unix_cudaflags, prepare_win_cudaflags, add_std_without_repeat, get_build_directory
 from .extension_utils import _import_module_from_library, CustomOpInfo, _write_setup_file, _jit_compile, parse_op_name_from
-from .extension_utils import check_abi_compatibility, log_v, IS_WINDOWS
-from .extension_utils import use_new_custom_op_load_method
+from .extension_utils import check_abi_compatibility, log_v, IS_WINDOWS, OS_NAME
+from .extension_utils import use_new_custom_op_load_method, MSVC_COMPILE_FLAGS
+
+# Note(zhouwei): On windows, setuptools exports the function 'PyInit_[name]' by default.
+# The solutions are: 1. the user adds a PyInit_[name] function; 2. disable the export.
+# refer to https://stackoverflow.com/questions/34689210/error-exporting-symbol-when-building-python-c-extension-in-windows
+if IS_WINDOWS and six.PY3:
+    from distutils.command.build_ext import build_ext as _du_build_ext
+    from unittest.mock import Mock
+    _du_build_ext.get_export_symbols = Mock(return_value=None)
 
 CUDA_HOME = find_cuda_home()
 
 
 def setup(**attr):
     """
-    Wrapper setuptools.setup function to valid `build_ext` command and
-    implement paddle api code injection by switching `write_stub`
-    function in bdist_egg with `custom_write_stub`.
-
-    Its usage is almost same as `setuptools.setup` except for `ext_modules`
-    arguments. For compiling multi custom operators, all necessary source files
-    can be include into just one Extension (CppExtension/CUDAExtension).
-    Moreover, only one `name` argument is required in `setup` and no need to spcific
-    `name` in Extension.
-
-    Example:
-
-        >> from paddle.utils.cpp_extension import CUDAExtension, setup
-        >> setup(name='custom_module',
-                 ext_modules=CUDAExtension(
-                 sources=['relu_op.cc', 'relu_op.cu'],
-                 include_dirs=[],       # specific user-defined include dirs
-                 extra_compile_args=[]) # specific user-defined compil arguments.
+    The interface is used to configure the process of compiling customized operators,
+    mainly including how to compile the shared library, automatically generate the python API,
+    and install it into site-packages. It supports using customized operators directly with
+    an ``import`` statement.
+
+    It encapsulates the python built-in ``setuptools.setup`` function and keeps the arguments
+    and usage the same as the native interface. Meanwhile, it hides Paddle inner framework
+    concepts, such as the necessary compiling flags, include paths of header files, and linking
+    flags.
+    It also will automatically search and validate the local environment and the versions of ``cc`` and
+    ``nvcc``, then compile the customized operators to support CPU or GPU devices according to
+    the specified Extension type.
+
+    Moreover, `ABI compatibility `_
+    will be checked to ensure that the compiler version of ``cc``
+    on the local machine is compatible with the pre-installed Paddle whl in python site-packages.
+    For example, if Paddle with CUDA 10.1 is built with GCC 8.2, then the version on the user's
+    local machine should satisfy GCC >= 8.2. Otherwise, a fatal error will occur because of
+    the ABI incompatibility.
+
+    .. note::
+
+        1. Compiler ABI compatibility is forward compatible. On the Linux platform,
+           we recommend using GCC 8.2 as the soft linking candidate of ``/usr/bin/cc`` .
+        2. Use ``which cc`` to check the location of ``cc`` and ``cc --version``
+           to check the linked GCC version on Linux.
+        3. Currently we support the Linux and Windows platforms; macOS support is in progress.
+
+
+    Compared with the Just-In-Time ``load`` interface, it only compiles once, by executing
+    ``python setup.py install`` . Then the customized operators API will be available everywhere
+    after importing it.
+
+    A simple example of ``setup.py`` is as follows:
+
+    .. code-block:: text
+
+        # setup.py
+
+        # Case 1: Compiling customized operators supporting CPU and GPU devices
+        from paddle.utils.cpp_extension import CUDAExtension, setup
+
+        setup(
+            name='custom_op',  # name of package used by "import"
+            ext_modules=CUDAExtension(
+                sources=['relu_op.cc', 'relu_op.cu', 'tanh_op.cc', 'tanh_op.cu']  # Support for compilation of multiple OPs
+            )
+        )
+
+        # Case 2: Compiling customized operators supporting only CPU device
+        from paddle.utils.cpp_extension import CppExtension, setup
+
+        setup(
+            name='custom_op',  # name of package used by "import"
+            ext_modules=CppExtension(
+                sources=['relu_op.cc', 'tanh_op.cc']  # Support for compilation of multiple OPs
+            )
+        )
+
+
+    Apply the compilation and installation by executing ``python setup.py install`` under the source files directory.
+    Then we can use the layer API as follows:
+
+    .. code-block:: text
+
+        import paddle
+        from custom_op import relu, tanh
+
+        x = paddle.randn([4, 10], dtype='float32')
+        relu_out = relu(x)
+        tanh_out = tanh(x)
+
+
+    Args:
+        name(str): Specify the name of the shared library file and the installed python package.
+        ext_modules(Extension): Specify the Extension instance including the customized operator source files, compiling flags, and so on.
+                                If the operator only supports the CPU device, please use ``CppExtension`` ; if it supports
+                                both CPU and GPU devices, please use ``CUDAExtension`` .
+        include_dirs(list[str], optional): Specify the extra include directories to search for header files. The interface will automatically add
+                                           ``site-package/paddle/include`` . Please add the corresponding directory path if including third-party
+                                           header files. Default is None.
+        extra_compile_args(list[str] | dict, optional): Specify the extra compiling flags such as ``-O3`` . If set as ``list[str]`` , all these flags
+                                                        will be applied to both the ``cc`` and ``nvcc`` compiler. Flags applied to only the ``cc`` or ``nvcc``
+                                                        compiler can be specified using a dict of the form ``{'cxx': [...], 'nvcc': [...]}`` . Default is None.
+        **attr(dict, optional): Specify other arguments the same as ``setuptools.setup`` .
+
+    Returns: None
+
+    """
     cmdclass = attr.get('cmdclass', {})
     assert isinstance(cmdclass, dict)
-    # if not specific cmdclass in setup, add it automaticaly.
+    # if no cmdclass is specified in setup, add it automatically.
    if 'build_ext' not in cmdclass:
         cmdclass['build_ext'] = BuildExtension.with_options(
             no_python_abi_suffix=True)
@@ -71,18 +148,22 @@ def setup(**attr):
                   sources=['relu_op.cc', 'relu_op.cu'])
 
         # After running `python setup.py install`
-        from custom_module import relue
+        from custom_module import relu
     """
     # name argument is required
     if 'name' not in attr:
         raise ValueError(error_msg)
 
+    assert not attr['name'].endswith('module'), \
+        "Please don't use 'module' as suffix in `name` argument, " \
+        "it will be stripped in setuptools.bdist_egg and cause import error."
+
     ext_modules = attr.get('ext_modules', [])
     if not isinstance(ext_modules, list):
         ext_modules = [ext_modules]
     assert len(
         ext_modules
-    ) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extenion.".format(
+    ) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extension.".format(
         len(ext_modules))
     # replace Extension.name with attr['name] to keep consistant with Package name.
     for ext_module in ext_modules:
@@ -94,6 +175,13 @@ def setup(**attr):
     assert 'easy_install' not in cmdclass
     cmdclass['easy_install'] = EasyInstallCommand
 
+    # Note(Aurelius84): Add a hook in the build command to rename the build_base directory.
+    # This avoids sharing the same build directory, which could lead to the directory being
+    # removed by mistake while executing setup.py in parallel, for example on CI.
+    assert 'build' not in cmdclass
+    build_base = os.path.join('build', attr['name'])
+    cmdclass['build'] = BuildCommand.with_options(build_base=build_base)
+
     # Always set zip_safe=False to make compatible in PY2 and PY3
     # See http://peak.telecommunity.com/DevCenter/setuptools#setting-the-zip-safe-flag
     attr['zip_safe'] = False
@@ -105,16 +193,41 @@ def setup(**attr):
 
 def CppExtension(sources, *args, **kwargs):
     """
-    Returns setuptools.CppExtension instance for setup.py to make it easy
-    to specify compile flags while building C++ custommed op kernel.
+    The interface is used to configure the source files of customized operators and compiles
+    an Op Kernel that supports only the CPU device. Please use ``CUDAExtension`` if you want to
+    compile an Op Kernel that supports both CPU and GPU devices.
+
+    It further encapsulates the python built-in ``setuptools.Extension`` . The arguments and
+    usage are the same as the native interface, except that there is no need to explicitly specify
+    ``name`` .
+
+    **A simple example:**
+
+    .. code-block:: text
+
+        # setup.py
+
+        # Compiling customized operators supporting only CPU device
+        from paddle.utils.cpp_extension import CppExtension, setup
+
+        setup(
+            name='custom_op',
+            ext_modules=CppExtension(sources=['relu_op.cc'])
+        )
+
+
+    .. note::
+        It is mainly used in ``setup`` and the name of the built shared library keeps the same
+        as the ``name`` argument specified in the ``setup`` interface.
+
+
     Args:
-        sources(list[str]): The C++/CUDA source file names
-        args(list[options]): list of config options used to compile shared library
-        kwargs(dict[option]): dict of config options used to compile shared library
-
-    Returns:
-        Extension: An instance of setuptools.Extension
+        sources(list[str]): Specify the C++/CUDA source files of the customized operators.
+        *args(list[options], optional): Specify other arguments the same as ``setuptools.Extension`` .
+        **kwargs(dict[option], optional): Specify other arguments the same as ``setuptools.Extension`` .
+
+    Returns:
+        setuptools.Extension: An instance of ``setuptools.Extension`` .
     """
     kwargs = normalize_extension_kwargs(kwargs, use_cuda=False)
     # Note(Aurelius84): While using `setup` and `jit`, the Extension `name` will
@@ -130,16 +243,43 @@ def CppExtension(sources, *args, **kwargs):
 
 def CUDAExtension(sources, *args, **kwargs):
     """
-    Returns setuptools.CppExtension instance for setup.py to make it easy
-    to specify compile flags while build CUDA custommed op kernel.
+    The interface is used to configure the source files of customized operators and compiles
+    an Op Kernel that supports both CPU and GPU devices. Please use ``CppExtension`` if you want to
+    compile an Op Kernel that supports only the CPU device.
+
+    It further encapsulates the python built-in ``setuptools.Extension`` . The arguments and
+    usage are the same as the native interface, except that there is no need to explicitly specify
+    ``name`` .
+
+    **A simple example:**
+
+    .. code-block:: text
+
+        # setup.py
+
+        # Compiling customized operators supporting CPU and GPU devices
+        from paddle.utils.cpp_extension import CUDAExtension, setup
+
+        setup(
+            name='custom_op',
+            ext_modules=CUDAExtension(
+                sources=['relu_op.cc', 'relu_op.cu']
+            )
+        )
+
+
+    .. note::
+        It is mainly used in ``setup`` and the name of the built shared library keeps the same
+        as the ``name`` argument specified in the ``setup`` interface.
+
+
     Args:
-        sources(list[str]): The C++/CUDA source file names
-        args(list[options]): list of config options used to compile shared library
-        kwargs(dict[option]): dict of config options used to compile shared library
-
-    Returns:
-        Extension: An instance of setuptools.Extension
+        sources(list[str]): Specify the C++/CUDA source files of the customized operators.
+        *args(list[options], optional): Specify other arguments the same as ``setuptools.Extension`` .
+        **kwargs(dict[option], optional): Specify other arguments the same as ``setuptools.Extension`` .
+
+    Returns:
+        setuptools.Extension: An instance of ``setuptools.Extension`` .
     """
     kwargs = normalize_extension_kwargs(kwargs, use_cuda=True)
     # Note(Aurelius84): While using `setup` and `jit`, the Extension `name` will
@@ -191,20 +331,17 @@ class BuildExtension(build_ext, object):
     def __init__(self, *args, **kwargs):
         """
         Attributes is initialized with following oreder:
-        
+
             1. super(self).__init__()
            2. initialize_options(self)
            3. the reset of current __init__()
            4. finalize_options(self)
-        
+
         So, it is recommended to set attribute value in `finalize_options`.
         """
         super(BuildExtension, self).__init__(*args, **kwargs)
         self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", True)
         self.output_dir = kwargs.get("output_dir", None)
-        # for compatible two custom op define method
-        use_new_custom_op_load_method(
-            kwargs.get("use_new_method", use_new_custom_op_load_method()))
 
     def initialize_options(self):
         super(BuildExtension, self).initialize_options()
@@ -219,20 +356,14 @@ class BuildExtension(build_ext, object):
 
     def build_extensions(self):
         self._check_abi()
-        for extension in self.extensions:
-            # check settings of compiler
-            if isinstance(extension.extra_compile_args, dict):
-                for compiler in ['cxx', 'nvcc']:
-                    if compiler not in extension.extra_compile_args:
-                        extension.extra_compile_args[compiler] = []
-            # add determine compile flags
-            add_compile_flag(extension, '-std=c++11')
 
         # Consider .cu, .cu.cc as valid source extensions.
         self.compiler.src_extensions += ['.cu', '.cu.cc']
         # Save the original _compile method for later.
- if self.compiler.compiler_type == 'msvc' or IS_WINDOWS: - raise NotImplementedError("Not support on MSVC currently.") + if self.compiler.compiler_type == 'msvc': + self.compiler._cpp_extensions += ['.cu', '.cuh'] + original_compile = self.compiler.compile + original_spawn = self.compiler.spawn else: original_compile = self.compiler._compile @@ -255,8 +386,8 @@ class BuildExtension(build_ext, object): # {'nvcc': {}, 'cxx: {}} if isinstance(cflags, dict): cflags = cflags['nvcc'] - else: - cflags = prepare_unix_cflags(cflags) + + cflags = prepare_unix_cudaflags(cflags) # cxx compile Cpp source elif isinstance(cflags, dict): cflags = cflags['cxx'] @@ -268,6 +399,81 @@ class BuildExtension(build_ext, object): # restore original_compiler self.compiler.compiler_so = original_compiler + def win_custom_single_compiler(sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None): + + self.cflags = copy.deepcopy(extra_postargs) + extra_postargs = None + + def win_custom_spawn(cmd): + # Using regex to modify compile options + compile_options = self.compiler.compile_options + for i in range(len(cmd)): + if re.search('/MD', cmd[i]) is not None: + cmd[i] = '/MT' + if re.search('/W[1-4]', cmd[i]) is not None: + cmd[i] = '/W0' + + # Using regex to match src, obj and include files + src_regex = re.compile('/T(p|c)(.*)') + src_list = [ + m.group(2) for m in (src_regex.match(elem) for elem in cmd) + if m + ] + + obj_regex = re.compile('/Fo(.*)') + obj_list = [ + m.group(1) for m in (obj_regex.match(elem) for elem in cmd) + if m + ] + + include_regex = re.compile(r'((\-|\/)I.*)') + include_list = [ + m.group(1) + for m in (include_regex.match(elem) for elem in cmd) if m + ] + + assert len(src_list) == 1 and len(obj_list) == 1 + src = src_list[0] + obj = obj_list[0] + if is_cuda_file(src): + assert CUDA_HOME is not None + nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc') + if isinstance(self.cflags, dict): + cflags = self.cflags['nvcc'] + elif isinstance(self.cflags, list): + cflags = self.cflags + else: + cflags = [] + + cflags = prepare_win_cudaflags(cflags) + ['--use-local-env'] + for flag in MSVC_COMPILE_FLAGS: + cflags = ['-Xcompiler', flag] + cflags + cmd = [nvcc_cmd, '-c', src, '-o', obj + ] + include_list + cflags + elif isinstance(self.cflags, dict): + cflags = MSVC_COMPILE_FLAGS + self.cflags['cxx'] + cmd += cflags + elif isinstance(self.cflags, list): + cflags = MSVC_COMPILE_FLAGS + self.cflags + cmd += cflags + + return original_spawn(cmd) + + try: + self.compiler.spawn = win_custom_spawn + return original_compile(sources, output_dir, macros, + include_dirs, debug, extra_preargs, + extra_postargs, depends) + finally: + self.compiler.spawn = original_spawn + def object_filenames_with_cuda(origina_func, build_directory): """ Decorated the function to add customized naming machanism. @@ -280,10 +486,13 @@ class BuildExtension(build_ext, object): objects = origina_func(source_filenames, strip_dir, output_dir) for i, source in enumerate(source_filenames): - # modify xx.o -> xx.cu.o + # modify xx.o -> xx.cu.o/xx.cu.obj if is_cuda_file(source): old_obj = objects[i] - objects[i] = old_obj[:-1] + 'cu.o' + if self.compiler.compiler_type == 'msvc': + objects[i] = old_obj[:-3] + 'cu.obj' + else: + objects[i] = old_obj[:-1] + 'cu.o' # if user set build_directory, output objects there. 
        if build_directory is not None:
                objects = [
@@ -300,10 +509,13 @@ class BuildExtension(build_ext, object):
             return wrapper
 
         # customized compile process
-        self.compiler._compile = unix_custom_single_compiler
+        if self.compiler.compiler_type == 'msvc':
+            self.compiler.compile = win_custom_single_compiler
+        else:
+            self.compiler._compile = unix_custom_single_compiler
+
         self.compiler.object_filenames = object_filenames_with_cuda(
             self.compiler.object_filenames, self.build_lib)
-
         self._record_op_info()
 
         print("Compiling user custom op, it will cost a few seconds.....")
@@ -333,15 +545,21 @@ class BuildExtension(build_ext, object):
             compiler = self.compiler.compiler_cxx[0]
         elif IS_WINDOWS:
             compiler = os.environ.get('CXX', 'cl')
-            raise NotImplementedError("We don't support Windows Currently.")
         else:
             compiler = os.environ.get('CXX', 'c++')
 
         check_abi_compatibility(compiler)
+        # Warn user if VC env is activated but `DISTUTILS_USE_SDK` is not set.
+        if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
+            msg = (
+                'It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set. '
+                'This may lead to multiple activations of the VC env. '
+                'Please set `DISTUTILS_USE_SDK=1` and try again.')
+            raise UserWarning(msg)
 
     def _record_op_info(self):
         """
-        Record custum op inforomation.
+        Record custom op information.
         """
         # parse shared library abs path
         outputs = self.get_outputs()
@@ -380,16 +598,59 @@ class EasyInstallCommand(easy_install, object):
         # .so shared library to another name.
         for egg_file in self.outputs:
             filename, ext = os.path.splitext(egg_file)
-            if ext == '.so':
+            will_rename = False
+            if OS_NAME.startswith('linux') and ext == '.so':
+                will_rename = True
+            elif IS_WINDOWS and ext == '.pyd':
+                will_rename = True
+
+            if will_rename:
                 new_so_path = filename + "_pd_" + ext
                 if not os.path.exists(new_so_path):
                     os.rename(r'%s' % egg_file, r'%s' % new_so_path)
                     assert os.path.exists(new_so_path)
 
 
+class BuildCommand(build, object):
+    """
+    Extend the build command to control the behavior of specifying the `build_base` root directory.
+
+    NOTE(Aurelius84): This is a hook subclass inheriting from Command, used to specify a customized
+    build_base directory.
+    """
+
+    @classmethod
+    def with_options(cls, **options):
+        """
+        Returns a BuildCommand subclass containing user-defined options.
+        """
+
+        class cls_with_options(cls):
+            def __init__(self, *args, **kwargs):
+                kwargs.update(options)
+                cls.__init__(self, *args, **kwargs)
+
+        return cls_with_options
+
+    def __init__(self, *args, **kwargs):
+        # Note: shall put before super()
+        self._specified_build_base = kwargs.get('build_base', None)
+
+        super(BuildCommand, self).__init__(*args, **kwargs)
+
+    def initialize_options(self):
+        """
+        build_base is the root directory for all sub-commands, such as
+        build_lib and build_temp. See `distutils.command.build` for details.
+        """
+        super(BuildCommand, self).initialize_options()
+        if self._specified_build_base is not None:
+            self.build_base = self._specified_build_base
+
+
 def load(name,
          sources,
-         extra_cflags=None,
+         extra_cxx_cflags=None,
          extra_cuda_cflags=None,
         extra_ldflags=None,
         extra_include_paths=None,
@@ -399,48 +660,86 @@ def load(name,
     """
     An Interface to automatically compile C++/CUDA source files Just-In-Time
     and return callable python function as other Paddle layers API. It will
-    append user defined custom op in background.
+    append user-defined custom operators in the background while building models.
+
+    It will perform the compiling, linking, Python API generation and module loading
+    processes in an individual subprocess. It does not require a CMake or Ninja environment,
+    only ``g++/nvcc`` on Linux and clang++ on macOS. It requires a
+    GCC compiler whose version is greater than 5.4, linked into ``/usr/bin/cc`` .
+    If compiling operators that support the GPU device, please make sure the ``nvcc`` compiler
+    is installed in the local environment.
+
+
+    Moreover, `ABI compatibility `_
+    will be checked to ensure that the compiler version of ``cc``
+    on the local machine is compatible with the pre-installed Paddle whl in python site-packages.
+    For example, if Paddle with CUDA 10.1 is built with GCC 8.2, then the version on the user's
+    local machine should satisfy GCC >= 8.2. Otherwise, a fatal error will occur because of
+    the ABI incompatibility.
+
+    Compared with the ``setup`` interface, it needs no extra ``setup.py`` file and no
+    ``python setup.py install`` command. The interface performs the whole compiling and installing
+    process under the hood.
+
+    .. note::
+
+        1. Compiler ABI compatibility is forward compatible. On the Linux platform,
+           we recommend using GCC 8.2 as the soft linking candidate of ``/usr/bin/cc`` .
+        2. Use ``which cc`` to check the location of ``cc`` and ``cc --version``
+           to check the linked GCC version on Linux.
+        3. Currently we support the Linux and Windows platforms; macOS support is in progress.
+
+
+    **A simple example:**
+
+    .. code-block:: text
+
+        import paddle
+        from paddle.utils.cpp_extension import load
+
+        custom_op_module = load(
+            name="op_shared_library_name",              # name of shared library
+            sources=['relu_op.cc', 'relu_op.cu'],       # source files of customized op
+            extra_cxx_cflags=['-DPADDLE_WITH_MKLDNN'],  # need to specify the flag if pre-installed Paddle supports MKLDNN
+            extra_cuda_cflags=['-DPADDLE_WITH_MKLDNN'], # need to specify the flag if pre-installed Paddle supports MKLDNN
+            interpreter='python3.7',                    # optional, specify another python interpreter
+            verbose=True                                # output log information
+        )
+
+        x = paddle.randn([4, 10], dtype='float32')
+        out = custom_op_module.relu(x)
 
-    This module will perform compiling, linking, api generation and module loading
-    processes for users. It does not require CMake or Ninja environment and only
-    g++/nvcc on Linux and clang++ on MacOS. Moreover, ABI compatibility will be
-    checked to ensure that compiler version on local machine is compatible with
-    pre-installed Paddle whl in python site-packages. For example if Paddle is built
-    with GCC5.4, the version of user's local machine should satisfy GCC >= 5.4.
-    Otherwise, a fatal error will occur because ABI compatibility.
 
     Args:
-        name(str): generated shared library file name.
-        sources(list[str]): custom op source files name with .cc/.cu suffix.
-        extra_cflag(list[str]): additional flags used to compile CPP files. By default
-                               all basic and framework related flags have been included.
-                               If your pre-insall Paddle supported MKLDNN, please add
-                               '-DPADDLE_WITH_MKLDNN'. Default None.
-        extra_cuda_cflags(list[str]): additonal flags used to compile CUDA files. See
-                               https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
-                               for details. Default None.
+        name(str): Specify the name of the generated shared library file, not including the ``.so`` or ``.dll`` suffix.
+        sources(list[str]): Specify the source file names of the customized operators. ``.cc`` and ``.cpp`` are supported
+            for CPP files and ``.cu`` for CUDA files.
+        extra_cxx_cflags(list[str], optional): Specify additional flags used to compile CPP files. By default
+            all basic and framework related flags have been included. If your pre-installed Paddle supports MKLDNN,
+            please add ``-DPADDLE_WITH_MKLDNN`` . Default is None.
-        extra_ldflags(list[str]): additonal flags used to link shared library. See
-                               https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html for details.
-                               Default None.
-        extra_include_paths(list[str]): additional include path used to search header files.
-                               Default None.
-        build_directory(str): specific directory path to put shared library file. If set None,
-                               it will use `PADDLE_EXTENSION_DIR` from os.environ. Use
-                               `paddle.utils.cpp_extension.get_build_directory()` to see the location.
-        interpreter(str): alias or full interpreter path to specific which one to use if have installed multiple.
-                               If set None, will use `python` as default interpreter.
-        verbose(bool): whether to verbose compiled log information
+        extra_cuda_cflags(list[str], optional): Specify additional flags used to compile CUDA files. By default
+            all basic and framework related flags have been included. If your pre-installed Paddle supports MKLDNN,
+            please add ``-DPADDLE_WITH_MKLDNN`` . See `Cuda Compiler Driver NVCC <https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html>`_
+            for details. Default is None.
+        extra_ldflags(list[str], optional): Specify additional flags used to link the shared library. See
+            `GCC Link Options <https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html>`_ for details.
+            Default is None.
+        extra_include_paths(list[str], optional): Specify additional include paths used to search for header files. By default
+            all basic headers are included implicitly from ``site-package/paddle/include`` .
+            Default is None.
+        build_directory(str, optional): Specify the root directory path to put the shared library file. If set None,
+            it will use ``PADDLE_EXTENSION_DIR`` from os.environ. Use
+            ``paddle.utils.cpp_extension.get_build_directory()`` to see the location. Default is None.
+        interpreter(str, optional): Specify the interpreter path, supporting aliases and full paths.
+            If set None, it will use `python` as the default interpreter. If the local environment contains
+            more than one python interpreter and you want to use another interpreter for the compilation,
+            please specify this parameter, such as ``python3.7`` . Default is None.
+        verbose(bool, optional): whether to output compilation log information verbosely. Default is False.
 
     Returns:
-        custom api: A callable python function with same signature as CustomOp Kernel defination.
-
-    Example:
+        Module: A callable python module that contains all custom op layer APIs.
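The validation and compilation code below implements this contract. Condensed to its core, the JIT flow is a three-step sequence; the sketch below uses the private helpers this patch defines in extension_utils, while the wrapper name `jit_load_sketch` is hypothetical:

    import os

    from paddle.utils.cpp_extension.extension_utils import (
        _import_module_from_library, _jit_compile, _write_setup_file)


    def jit_load_sketch(name, sources, build_directory,
                        extra_cxx_cflags=None, extra_cuda_cflags=None):
        # 1. write a transient <name>_setup.py describing a single Extension
        file_path = os.path.join(build_directory, '{}_setup.py'.format(name))
        build_base_dir = os.path.join(build_directory, name)
        _write_setup_file(name, sources, file_path, build_base_dir,
                          None,                      # extra include dirs
                          extra_cxx_cflags or [],
                          extra_cuda_cflags or [],
                          None)                      # extra link flags
        # 2. run `python <name>_setup.py build` in a subprocess
        _jit_compile(file_path)
        # 3. import the built .so/.pyd and return the generated API module
        return _import_module_from_library(name, build_base_dir)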
- >> from paddle.utils.cpp_extension import load - >> relu2 = load(name='relu2', - sources=['relu_op.cc', 'relu_op.cu']) - >> x = paddle.rand([4, 10]], dtype='float32') - >> out = relu2(x) """ if build_directory is None: @@ -448,24 +747,37 @@ def load(name, # ensure to use abs path build_directory = os.path.abspath(build_directory) + # Will load shared library from 'path' on windows + if IS_WINDOWS: + os.environ['path'] = build_directory + ';' + os.environ['path'] + log_v("build_directory: {}".format(build_directory), verbose) - file_path = os.path.join(build_directory, "setup.py") + file_path = os.path.join(build_directory, "{}_setup.py".format(name)) sources = [os.path.abspath(source) for source in sources] - # TODO(Aurelius84): split cflags and cuda_flags - if extra_cflags is None: extra_cflags = [] + if extra_cxx_cflags is None: extra_cxx_cflags = [] if extra_cuda_cflags is None: extra_cuda_cflags = [] - compile_flags = extra_cflags + extra_cuda_cflags - log_v("additonal compile_flags: [{}]".format(' '.join(compile_flags)), - verbose) - - # write setup.py file and compile it - _write_setup_file(name, sources, file_path, extra_include_paths, - compile_flags, extra_ldflags, verbose) + assert isinstance( + extra_cxx_cflags, list + ), "Required type(extra_cxx_cflags) == list[str], but received {}".format( + extra_cxx_cflags) + assert isinstance( + extra_cuda_cflags, list + ), "Required type(extra_cuda_cflags) == list[str], but received {}".format( + extra_cuda_cflags) + + log_v("additional extra_cxx_cflags: [{}], extra_cuda_cflags: [{}]".format( + ' '.join(extra_cxx_cflags), ' '.join(extra_cuda_cflags)), verbose) + + # write setup.py file and compile it + build_base_dir = os.path.join(build_directory, name) + _write_setup_file(name, sources, file_path, build_base_dir, + extra_include_paths, extra_cxx_cflags, extra_cuda_cflags, + extra_ldflags, verbose) _jit_compile(file_path, interpreter, verbose) # import as callable python api - custom_op_api = _import_module_from_library(name, build_directory, verbose) + custom_op_api = _import_module_from_library(name, build_base_dir, verbose) return custom_op_api diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 52c17d77bd4771ce44f2282adfd9a25394ce97ea..db2da5574854c27267dd568d4eed1432acd5353f 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -16,7 +16,6 @@ import os import re import six import sys -import copy import glob import logging import collections @@ -38,11 +37,17 @@ logger = logging.getLogger("utils.cpp_extension") OS_NAME = sys.platform IS_WINDOWS = OS_NAME.startswith('win') -NVCC_COMPILE_FLAGS = [ - '-ccbin', 'cc', '-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU', '-DPADDLE_USE_DSO', - '-Xcompiler', '-fPIC', '-w', '--expt-relaxed-constexpr', '-O3', '-DNVCC' + +MSVC_COMPILE_FLAGS = [ + '/MT', '/wd4819', '/wd4251', '/wd4244', '/wd4267', '/wd4275', '/wd4018', + '/wd4190', '/EHsc', '/w', '/DGOOGLE_GLOG_DLL_DECL', + '/DBOOST_HAS_STATIC_ASSERT', '/DNDEBUG', '/DPADDLE_USE_DSO' ] +MSVC_LINK_FLAGS = ['/MACHINE:X64', 'paddle_framework.lib'] + +COMMON_NVCC_FLAGS = ['-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU', '-O3'] + GCC_MINI_VERSION = (5, 4, 0) # Give warning if using wrong compiler WRONG_COMPILER_WARNING = ''' @@ -80,9 +85,17 @@ information ''' USING_NEW_CUSTOM_OP_LOAD_METHOD = True +DEFAULT_OP_ATTR_NAMES = [ + core.op_proto_and_checker_maker.kOpRoleAttrName(), + 
core.op_proto_and_checker_maker.kOpRoleVarAttrName(), + core.op_proto_and_checker_maker.kOpNameScopeAttrName(), + core.op_proto_and_checker_maker.kOpCreationCallstackAttrName(), + core.op_proto_and_checker_maker.kOpDeviceAttrName() +] -# NOTE(chenweihang): In order to be compatible with -# the two custom op define method, after removing + +# NOTE(chenweihang): In order to be compatible with +# the two custom op define method, after removing # old method, we can remove them together def use_new_custom_op_load_method(*args): global USING_NEW_CUSTOM_OP_LOAD_METHOD @@ -206,11 +219,23 @@ class CustomOpInfo: return next(reversed(self.op_info_map.items())) -def prepare_unix_cflags(cflags): +def prepare_unix_cudaflags(cflags): """ Prepare all necessary compiled flags for nvcc compiling CUDA files. """ - cflags = NVCC_COMPILE_FLAGS + cflags + get_cuda_arch_flags(cflags) + cflags = COMMON_NVCC_FLAGS + [ + '-ccbin', 'cc', '-Xcompiler', '-fPIC', '-w', '--expt-relaxed-constexpr', + '-DNVCC' + ] + cflags + get_cuda_arch_flags(cflags) + + return cflags + + +def prepare_win_cudaflags(cflags): + """ + Prepare all necessary compiled flags for nvcc compiling CUDA files. + """ + cflags = COMMON_NVCC_FLAGS + ['-w'] + cflags + get_cuda_arch_flags(cflags) return cflags @@ -238,13 +263,14 @@ def get_cuda_arch_flags(cflags): def normalize_extension_kwargs(kwargs, use_cuda=False): - """ + """ Normalize include_dirs, library_dir and other attributes in kwargs. """ assert isinstance(kwargs, dict) # append necessary include dir path of paddle include_dirs = kwargs.get('include_dirs', []) include_dirs.extend(find_paddle_includes(use_cuda)) + kwargs['include_dirs'] = include_dirs # append necessary lib path of paddle @@ -252,50 +278,46 @@ def normalize_extension_kwargs(kwargs, use_cuda=False): library_dirs.extend(find_paddle_libraries(use_cuda)) kwargs['library_dirs'] = library_dirs - # add runtime library dirs - runtime_library_dirs = kwargs.get('runtime_library_dirs', []) - runtime_library_dirs.extend(find_paddle_libraries(use_cuda)) - kwargs['runtime_library_dirs'] = runtime_library_dirs - - # append compile flags + # append compile flags and check settings of compiler extra_compile_args = kwargs.get('extra_compile_args', []) - extra_compile_args.extend(['-g', '-w']) # diable warnings - kwargs['extra_compile_args'] = extra_compile_args - - # append link flags - extra_link_args = kwargs.get('extra_link_args', []) - extra_link_args.append('-lpaddle_framework') - if use_cuda: - extra_link_args.append('-lcudart') - - kwargs['extra_link_args'] = extra_link_args - - kwargs['language'] = 'c++' - return kwargs - - -def find_paddle_includes(use_cuda=False): - """ - Return Paddle necessary include dir path. 
-    """
-    # pythonXX/site-packages/paddle/include
-    paddle_include_dir = get_include()
-    third_party_dir = os.path.join(paddle_include_dir, 'third_party')
-
-    include_dirs = [paddle_include_dir, third_party_dir]
+    if isinstance(extra_compile_args, dict):
+        for compiler in ['cxx', 'nvcc']:
+            if compiler not in extra_compile_args:
+                extra_compile_args[compiler] = []
+
+    if IS_WINDOWS:
+        # TODO(zhouwei): may append compile flags in future
+        # append link flags
+        extra_link_args = kwargs.get('extra_link_args', [])
+        extra_link_args.extend(MSVC_LINK_FLAGS)
+        if use_cuda:
+            extra_link_args.extend(['cudadevrt.lib', 'cudart_static.lib'])
+        kwargs['extra_link_args'] = extra_link_args
+    else:
+        # append compile flags
+        add_compile_flag(extra_compile_args, ['-g', '-w'])  # disable warnings
 
-    return include_dirs
+        # append link flags
+        extra_link_args = kwargs.get('extra_link_args', [])
+        if use_new_custom_op_load_method():
+            extra_link_args.append('-lpaddle_custom_op')
+        else:
+            extra_link_args.append('-lpaddle_framework')
+        if use_cuda:
+            extra_link_args.append('-lcudart')
+        kwargs['extra_link_args'] = extra_link_args
 
-def find_cuda_includes():
+        # add runtime library dirs
+        runtime_library_dirs = kwargs.get('runtime_library_dirs', [])
+        runtime_library_dirs.extend(find_paddle_libraries(use_cuda))
+        kwargs['runtime_library_dirs'] = runtime_library_dirs
 
-    cuda_home = find_cuda_home()
-    if cuda_home is None:
-        raise ValueError(
-            "Not found CUDA runtime, please use `export CUDA_HOME=XXX` to specific it."
-        )
+    kwargs['extra_compile_args'] = extra_compile_args
 
-    return [os.path.join(cuda_home, 'lib64')]
+    kwargs['language'] = 'c++'
+    return kwargs
 
 
 def find_cuda_home():
@@ -315,19 +337,22 @@ def find_cuda_home():
             if six.PY3:
                 nvcc_path = nvcc_path.decode()
             nvcc_path = nvcc_path.rstrip('\r\n')
+            # for example: /usr/local/cuda/bin/nvcc
             cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
         except:
             if IS_WINDOWS:
                 # search from default NVIDIA GPU path
                 candidate_paths = glob.glob(
-                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
+                    'C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*.*'
+                )
                 if len(candidate_paths) > 0:
                     cuda_home = candidate_paths[0]
             else:
                cuda_home = "/usr/local/cuda"
     # step 3. check whether path is valid
-    if not os.path.exists(cuda_home) and core.is_compiled_with_cuda():
+    if cuda_home and not os.path.exists(
+            cuda_home) and core.is_compiled_with_cuda():
         cuda_home = None
         warnings.warn(
             "Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it."
@@ -336,27 +361,73 @@ def find_cuda_home():
     return cuda_home
 
 
+def find_cuda_includes():
+    """
+    Use a heuristic method to find the cuda include path.
+    """
+    cuda_home = find_cuda_home()
+    if cuda_home is None:
+        raise ValueError(
+            "CUDA runtime not found, please use `export CUDA_HOME=XXX` to specify it."
+        )
+
+    return [os.path.join(cuda_home, 'include')]
+
+
+def find_paddle_includes(use_cuda=False):
+    """
+    Return Paddle necessary include dir path.
+    """
+    # pythonXX/site-packages/paddle/include
+    paddle_include_dir = get_include()
+    third_party_dir = os.path.join(paddle_include_dir, 'third_party')
+    include_dirs = [paddle_include_dir, third_party_dir]
+
+    if use_cuda:
+        cuda_include_dir = find_cuda_includes()
+        include_dirs.extend(cuda_include_dir)
+
+    return include_dirs
+
+
+def find_cuda_libraries():
+    """
+    Use a heuristic method to find the cuda static lib path.
+    """
+    cuda_home = find_cuda_home()
+    if cuda_home is None:
+        raise ValueError(
+            "CUDA runtime not found, please use `export CUDA_HOME=XXX` to specify it."
+        )
+    if IS_WINDOWS:
+        cuda_lib_dir = [os.path.join(cuda_home, 'lib', 'x64')]
+    else:
+        cuda_lib_dir = [os.path.join(cuda_home, 'lib64')]
+
+    return cuda_lib_dir
+
+
 def find_paddle_libraries(use_cuda=False):
     """
     Return Paddle necessary library dir path.
     """
     # pythonXX/site-packages/paddle/libs
     paddle_lib_dirs = [get_lib()]
+
     if use_cuda:
-        cuda_dirs = find_cuda_includes()
-        paddle_lib_dirs.extend(cuda_dirs)
+        cuda_lib_dir = find_cuda_libraries()
+        paddle_lib_dirs.extend(cuda_lib_dir)
+
     return paddle_lib_dirs
 
 
-def add_compile_flag(extension, flag):
-    extra_compile_args = copy.deepcopy(extension.extra_compile_args)
+def add_compile_flag(extra_compile_args, flags):
+    assert isinstance(flags, list)
     if isinstance(extra_compile_args, dict):
         for args in extra_compile_args.values():
-            args.append(flag)
+            args.extend(flags)
     else:
-        extra_compile_args.append(flag)
-
-    extension.extra_compile_args = extra_compile_args
+        extra_compile_args.extend(flags)
 
 
 def is_cuda_file(path):
@@ -369,17 +440,34 @@ def is_cuda_file(path):
 
 def get_build_directory(verbose=False):
     """
-    Return paddle extension root directory, default specific by `PADDLE_EXTENSION_DIR`
+    Return the paddle extension root directory used to put the shared library. It can be specified by
+    ``export PADDLE_EXTENSION_DIR=XXX`` . If not set, ``~/.cache/paddle_extension`` will be used
+    by default.
+
+    Returns:
+        The root directory used for compiling customized operators.
+
+    Examples:
+
+        .. code-block:: python
+
+            from paddle.utils.cpp_extension import get_build_directory
+
+            build_dir = get_build_directory()
+            print(build_dir)
+
     """
     root_extensions_directory = os.environ.get('PADDLE_EXTENSION_DIR')
     if root_extensions_directory is None:
         dir_name = "paddle_extensions"
-        if OS_NAME.startswith('linux'):
-            root_extensions_directory = os.path.join(
-                os.path.expanduser('~/.cache'), dir_name)
-        else:
-            # TODO(Aurelius84): consider wind32/macOs
-            raise NotImplementedError("Only support Linux now.")
+        root_extensions_directory = os.path.join(
+            os.path.expanduser('~/.cache'), dir_name)
+        if IS_WINDOWS:
+            root_extensions_directory = os.path.normpath(
+                root_extensions_directory)
+        elif OS_NAME.startswith('darwin'):
+            # TODO(Aurelius84): consider macOS
+            raise NotImplementedError("Mac is not supported yet.")
 
         log_v("$PADDLE_EXTENSION_DIR is not set, using path: {} by default.".
               format(root_extensions_directory), verbose)
@@ -404,16 +492,22 @@ def parse_op_info(op_name):
 
     in_names = [x.name for x in op_proto.inputs]
     out_names = [x.name for x in op_proto.outputs]
+    attr_names = [
+        x.name for x in op_proto.attrs if x.name not in DEFAULT_OP_ATTR_NAMES
+    ]
 
-    return in_names, out_names
+    return in_names, out_names, attr_names
 
 
 def _import_module_from_library(module_name, build_directory, verbose=False):
     """
-    Load .so shared library and import it as callable python module.
+    Load the shared library and import it as a callable python module.
     """
-    # TODO(Aurelius84): Consider file suffix is .dll on Windows Platform.
@@ -369,17 +440,34 @@ def is_cuda_file(path):

 def get_build_directory(verbose=False):
     """
-    Return paddle extension root directory, default specific by `PADDLE_EXTENSION_DIR`
+    Return the paddle extension root directory, where compiled shared libraries are placed.
+    It can be specified by ``export PADDLE_EXTENSION_DIR=XXX``. If not set,
+    ``~/.cache/paddle_extensions`` will be used by default.
+
+    Returns:
+        The root directory used for compiling customized operators.
+
+    Examples:
+
+    .. code-block:: python
+
+        from paddle.utils.cpp_extension import get_build_directory
+
+        build_dir = get_build_directory()
+        print(build_dir)
+
     """
     root_extensions_directory = os.environ.get('PADDLE_EXTENSION_DIR')
     if root_extensions_directory is None:
         dir_name = "paddle_extensions"
-        if OS_NAME.startswith('linux'):
-            root_extensions_directory = os.path.join(
-                os.path.expanduser('~/.cache'), dir_name)
-        else:
-            # TODO(Aurelius84): consider wind32/macOs
-            raise NotImplementedError("Only support Linux now.")
+        root_extensions_directory = os.path.join(
+            os.path.expanduser('~/.cache'), dir_name)
+        if IS_WINDOWS:
+            root_extensions_directory = os.path.normpath(
+                root_extensions_directory)
+        elif OS_NAME.startswith('darwin'):
+            # TODO(Aurelius84): consider macOS
+            raise NotImplementedError("macOS is not supported yet.")

         log_v("$PADDLE_EXTENSION_DIR is not set, using path: {} by default.".
               format(root_extensions_directory), verbose)
@@ -404,16 +492,22 @@ def parse_op_info(op_name):

     in_names = [x.name for x in op_proto.inputs]
     out_names = [x.name for x in op_proto.outputs]
+    attr_names = [
+        x.name for x in op_proto.attrs if x.name not in DEFAULT_OP_ATTR_NAMES
+    ]

-    return in_names, out_names
+    return in_names, out_names, attr_names


 def _import_module_from_library(module_name, build_directory, verbose=False):
     """
-    Load .so shared library and import it as callable python module.
+    Load the shared library and import it as a callable Python module.
     """
-    # TODO(Aurelius84): Consider file suffix is .dll on Windows Platform.
-    ext_path = os.path.join(build_directory, module_name + '.so')
+    if IS_WINDOWS:
+        dynamic_suffix = '.pyd'
+    else:
+        dynamic_suffix = '.so'
+    ext_path = os.path.join(build_directory, module_name + dynamic_suffix)
     if not os.path.exists(ext_path):
         raise FileNotFoundError("Extension path: {} does not exist.".format(
             ext_path))
@@ -448,7 +542,7 @@ def _generate_python_module(module_name,


 def _custom_api_content(op_name):
-    params_str, ins_str, outs_str = _get_api_inputs_str(op_name)
+    params_str, ins_str, attrs_str, outs_str = _get_api_inputs_str(op_name)

     API_TEMPLATE = textwrap.dedent("""
         from paddle.fluid.layer_helper import LayerHelper
@@ -456,8 +550,9 @@ def _custom_api_content(op_name):
         def {op_name}({inputs}):
             helper = LayerHelper("{op_name}", **locals())

-            # prepare inputs and output
+            # prepare inputs and outputs
             ins = {ins}
+            attrs = {attrs}
             outs = {{}}
             out_names = {out_names}
             for out_name in out_names:
@@ -465,7 +560,7 @@ def _custom_api_content(op_name):
                 # in runtime.
                 outs[out_name] = helper.create_variable(dtype='float32')

-            helper.append_op(type="{op_name}", inputs=ins, outputs=outs)
+            helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs)

             res = [outs[out_name] for out_name in out_names]

@@ -474,7 +569,11 @@ def _custom_api_content(op_name):

     # generate python api file
     api_content = API_TEMPLATE.format(
-        op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str)
+        op_name=op_name,
+        inputs=params_str,
+        ins=ins_str,
+        attrs=attrs_str,
+        out_names=outs_str)

     return api_content

@@ -505,22 +604,30 @@ def _get_api_inputs_str(op_name):
     """
     Returns string of api parameters and inputs dict.
     """
-    in_names, out_names = parse_op_info(op_name)
+    in_names, out_names, attr_names = parse_op_info(op_name)
     # e.g: x, y, z
-    params_str = ','.join([p.lower() for p in in_names])
+    param_names = in_names + attr_names
+    params_str = ','.join([p.lower() for p in param_names])
     # e.g: {'X': x, 'Y': y, 'Z': z}
     ins_str = "{%s}" % ','.join(
         ["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names])
+    # e.g: {'num': n}
+    attrs_str = "{%s}" % ",".join([
+        "'{}' : {}".format(attr_name, attr_name.lower())
+        for attr_name in attr_names
+    ])
    # e.g: ['Out', 'Index']
     outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names])

-    return params_str, ins_str, outs_str
+    return params_str, ins_str, attrs_str, outs_str

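The string building in _get_api_inputs_str can be sanity-checked in isolation. A standalone sketch with hypothetical op metadata (inputs ['X'], attributes ['num'], outputs ['Out']), reusing the exact expressions from the hunk above:

    in_names, attr_names, out_names = ['X'], ['num'], ['Out']

    params_str = ','.join([p.lower() for p in in_names + attr_names])
    ins_str = "{%s}" % ','.join(
        ["'{}' : {}".format(n, n.lower()) for n in in_names])
    attrs_str = "{%s}" % ",".join(
        ["'{}' : {}".format(a, a.lower()) for a in attr_names])
    outs_str = "[%s]" % ','.join(["'{}'".format(n) for n in out_names])

    assert params_str == 'x,num'
    assert ins_str == "{'X' : x}"
    assert attrs_str == "{'num' : num}"
    assert outs_str == "['Out']"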
 def _write_setup_file(name,
                       sources,
                       file_path,
+                      build_dir,
                       include_dirs,
-                      compile_flags,
+                      extra_cxx_cflags,
+                      extra_cuda_cflags,
                       link_args,
                       verbose=False):
     """
@@ -530,18 +637,21 @@ def _write_setup_file(name,
         import os
         from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup
         from paddle.utils.cpp_extension import get_build_directory
+        from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
+
+        use_new_custom_op_load_method({use_new_method})
+
        setup(
             name='{name}',
             ext_modules=[
                 {prefix}Extension(
                     sources={sources},
                     include_dirs={include_dirs},
-                    extra_compile_args={extra_compile_args},
+                    extra_compile_args={{'cxx':{extra_cxx_cflags}, 'nvcc':{extra_cuda_cflags}}},
                     extra_link_args={extra_link_args})],
             cmdclass={{"build_ext" : BuildExtension.with_options(
-                output_dir=get_build_directory(),
-                no_python_abi_suffix=True,
-                use_new_method={use_new_method})
+                output_dir=r'{build_dir}',
+                no_python_abi_suffix=True)
             }})""").lstrip()

     with_cuda = False
@@ -554,8 +664,10 @@ def _write_setup_file(name,
         prefix='CUDA' if with_cuda else 'Cpp',
         sources=list2str(sources),
         include_dirs=list2str(include_dirs),
-        extra_compile_args=list2str(compile_flags),
+        extra_cxx_cflags=list2str(extra_cxx_cflags),
+        extra_cuda_cflags=list2str(extra_cuda_cflags),
         extra_link_args=list2str(link_args),
+        build_dir=build_dir,
         use_new_method=use_new_custom_op_load_method())

     log_v('write setup.py into {}'.format(file_path), verbose)
@@ -565,12 +677,12 @@ def _write_setup_file(name,

 def list2str(args):
     """
-    Convert list[str] into string. For example: [x, y] -> "['x', 'y']"
+    Convert list[str] into string. For example: ['x', 'y'] -> "['x', 'y']"
     """
     if args is None: return '[]'
     assert isinstance(args, (list, tuple))
-    args = ["'{}'".format(arg) for arg in args]
-    return '[' + ','.join(args) + ']'
+    args = ["{}".format(arg) for arg in args]
+    return repr(args)


 def _jit_compile(file_path, interpreter=None, verbose=False):
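The list2str change simply delegates quoting to repr(): each element is stringified, and repr() of the resulting list supplies the brackets and quotes. A quick standalone sketch of the new behavior (not part of the patch):

    def list2str(args):
        if args is None: return '[]'
        assert isinstance(args, (list, tuple))
        args = ["{}".format(arg) for arg in args]
        return repr(args)  # repr() adds the quotes and brackets

    assert list2str(['x', 'y']) == "['x', 'y']"
    assert list2str(None) == '[]'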
""" diff --git a/python/requirements.txt b/python/requirements.txt index e2a3a652c7f5c9036c0160dabfc553bcbb11861b..e89b3ede94fd4a624b3ddc335f5d2ea6e7b20b8a 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -3,7 +3,8 @@ numpy>=1.13, <=1.16.4 ; python_version<"3.5" numpy>=1.13 ; python_version>="3.5" and platform_system != "Windows" numpy>=1.13, <=1.19.3 ; python_version>="3.5" and platform_system == "Windows" protobuf>=3.1.0 -gast==0.3.3 +gast>=0.3.3 ; platform_system != "Windows" +gast==0.3.3 ; platform_system == "Windows" Pillow six decorator diff --git a/python/setup.py.in b/python/setup.py.in index f662e21a7be2dc6b1f8faa6cf1f08c49c14ee095..aa18c053858321db363adadb0c98b42add8b8a8c 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -334,11 +334,21 @@ if '${WITH_XPU_BKCL}' == 'ON': shutil.copy('${XPU_BKCL_LIB}', libs_path) package_data['paddle.libs']+=['${XPU_BKCL_LIB_NAME}'] -# copy libfuild_framework.so to libs -if os.name != 'nt' and sys.platform != 'darwin': - paddle_framework_lib='${FLUID_FRAMEWORK_SHARED_LIB}' - shutil.copy(paddle_framework_lib, libs_path) - package_data['paddle.libs'] += [('libpaddle_framework' if os.name != 'nt' else 'paddle_framework') + ext_name] +# copy libpaddle_framework.so to libs on linux +if sys.platform.startswith('linux'): + shutil.copy('${FLUID_FRAMEWORK_SHARED_LIB}', libs_path) + package_data['paddle.libs'] += ['libpaddle_framework.so'] + +# copy libpaddle_custom_op.so to libs on linux +if sys.platform.startswith('linux'): + shutil.copy('${PADDLE_CUSTOM_OP_SHARED_LIB}', libs_path) + package_data['paddle.libs'] += ['libpaddle_custom_op.so'] + +# copy paddle_framework.lib/paddle_framework.dll to libs on windows +if os.name == 'nt': + shutil.copy('${FLUID_FRAMEWORK_IMPORT_LIB}', libs_path) + shutil.copy('${FLUID_FRAMEWORK_SHARED_LIB}', libs_path) + package_data['paddle.libs'] += ['paddle_framework.lib', 'paddle_framework.dll'] # remove unused paddle/libs/__init__.py if os.path.isfile(libs_path+'/__init__.py'): @@ -409,9 +419,9 @@ if '${WITH_GPU}' == 'ON': class InstallCommand(InstallCommandBase): def finalize_options(self): ret = InstallCommandBase.finalize_options(self) - self.install_headers = os.path.join(self.install_purelib, 'paddle', - 'include') self.install_lib = self.install_platlib + self.install_headers = os.path.join(self.install_platlib, 'paddle', + 'include') return ret @@ -462,11 +472,6 @@ class InstallHeaders(Command): return self.copy_file(header, install_dir) def run(self): - # only copy third_party/cudaErrorMessage.pb for cudaErrorMessage on mac or windows - if os.name == 'nt' or sys.platform == 'darwin': - if '${WITH_GPU}' == 'ON': - self.mkdir_and_copy_file('${cudaerror_INCLUDE_DIR}/cudaErrorMessage.pb') - return hdrs = self.distribution.headers if not hdrs: return