Unverified · Commit 04a49b09 · Authored by Zhou Wei · Committed by GitHub

[Custom OP]Remove old custom OP and reduce whl package volume (#31813)

* Remove old custom OP to reduce whl package volume

* [Custom OP]Remove old custom OP to reduce whl package volume
Parent fe284868
@@ -360,46 +360,11 @@ set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_prot
cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})
# Old custom op extension mechanism related, will be removed in 2.1.0
cc_library(paddle_framework_shared
SHARED SRCS executor.cc operator.cc
${CMAKE_CURRENT_SOURCE_DIR}/c/c_api.cc
${CMAKE_SOURCE_DIR}/paddle/fluid/imperative/layer.cc
DEPS ${FLUID_FRAMEWORK_MODULES})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
set_target_properties(paddle_framework_shared PROPERTIES OUTPUT_NAME paddle_framework)
target_link_libraries(paddle_framework_shared ${os_dependency_modules})
if (LINUX)
set(FLUID_FRAMEWORK_SHARED_LIB
${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_framework.so
CACHE INTERNAL "Fluid framework lib")
endif()
if (WIN32)
if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
set(paddle_framework_lib_path ${CMAKE_CURRENT_BINARY_DIR})
else()
set(paddle_framework_lib_path ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE})
endif()
set(FLUID_FRAMEWORK_IMPORT_LIB
${paddle_framework_lib_path}/paddle_framework.lib
CACHE INTERNAL "Fluid framework lib")
set(FLUID_FRAMEWORK_SHARED_LIB
${paddle_framework_lib_path}/paddle_framework.dll
CACHE INTERNAL "Fluid framework dll")
endif()
if(APPLE)
set(FLUID_FRAMEWORK_SHARED_LIB
${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_framework.dylib
CACHE INTERNAL "Fluid framework lib")
endif()
if(WITH_TESTING AND TEST selected_rows_test)
set_tests_properties(selected_rows_test PROPERTIES TIMEOUT 120)
endif()
# New custom op extension mechanism related
##### 2.0 New custom op extension mechanism related #####
# if `layer` is not listed as a dependency, linking fails with: undefined symbol: _ZN6paddle10imperative7VarBase9name_set_
set(PADDLE_CUSTOM_OP_MODULES custom_tensor op_meta_info custom_operator layer)
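For reference, the 2.0 mechanism that these modules back defines operators against the public extension headers instead of linking the removed paddle_framework shared library. The sketch below is a minimal, hedged example of such an op, assuming the paddle/extension.h interface of this release series; the names custom_relu and ReluCpuForward are illustrative and are not part of this commit.

// Minimal sketch of a 2.0-style custom op (assumes paddle/extension.h;
// custom_relu / ReluCpuForward are illustrative names, not from this commit).
#include <algorithm>
#include <vector>
#include "paddle/extension.h"

std::vector<paddle::Tensor> ReluCpuForward(const paddle::Tensor& x) {
  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
  out.reshape(x.shape());
  auto* x_data = x.data<float>();
  auto* out_data = out.mutable_data<float>(x.place());
  for (int64_t i = 0; i < x.size(); ++i) {
    out_data[i] = std::max(x_data[i], 0.0f);
  }
  return {out};
}

std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
  return {x_shape};
}

std::vector<paddle::DataType> ReluInferDtype(paddle::DataType x_dtype) {
  return {x_dtype};
}

// PD_BUILD_OP is the marker that parse_op_name_from() greps for after this commit.
PD_BUILD_OP(custom_relu)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluCpuForward))
    .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDtype));

Such a source is then built and loaded through paddle.utils.cpp_extension (setup or load), as exercised by the remaining test_custom_relu_op_* tests in this diff.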
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/c/c_api.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
extern "C" {
paddle::framework::OpInfoMap &PD_GetOpInfoMap() {
return paddle::framework::OpInfoMap::Instance();
}
void PD_InitDevicesPool(paddle::platform::DeviceContextPool *pool) {
paddle::platform::DeviceContextPool::SetPool(pool);
}
std::vector<std::string> PD_GetGradOpDescStrs(
const paddle::framework::OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> *grad_to_var,
const std::vector<paddle::framework::BlockDesc *> &grad_block) {
auto &op_info = PD_GetOpInfoMap().Get(op_desc.Type());
std::vector<std::string> ret;
if (op_info.grad_op_maker_) {
auto grad_op_descs =
op_info.grad_op_maker_(op_desc, no_grad_set, grad_to_var, grad_block);
size_t op_num = grad_op_descs.size();
ret.resize(op_num);
for (size_t i = 0; i < op_num; ++i) {
PADDLE_ENFORCE_EQ(
grad_op_descs[i]->Proto()->SerializePartialToString(&ret[i]), true,
paddle::platform::errors::Unavailable(
"Cannot serialize operator desc message."));
}
}
return ret;
}
} // end extern "C"
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
class OpInfoMap;
} // namespace framework
namespace platform {
class DeviceContextPool;
} // namespace platform
} // namespace paddle
#ifdef __cplusplus
extern "C" {
#endif
// C-API to get global OpInfo map.
paddle::framework::OpInfoMap &PD_GetOpInfoMap();
// C-API to init global DeviceContextPool from outside.
void PD_InitDevicesPool(paddle::platform::DeviceContextPool *pool);
// C-API to serialize the grad op protocol messages to binary strings.
std::vector<std::string> PD_GetGradOpDescStrs(
const paddle::framework::OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> *grad_to_var,
const std::vector<paddle::framework::BlockDesc *> &grad_block);
#ifdef __cplusplus
}
#endif
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace framework {
template <typename T>
T *DynLoad(void *handle, std::string name) {
T *func = reinterpret_cast<T *>(dlsym(handle, name.c_str()));
#if !defined(_WIN32)
auto errorno = dlerror();
#else
auto errorno = GetLastError();
#endif // !_WIN32
PADDLE_ENFORCE_NOT_NULL(
func,
platform::errors::NotFound(
"Failed to load dynamic operator library, error code(%s).", errorno));
return func;
}
void LoadOpLib(const std::string &dso_name) {
void *handle = paddle::platform::dynload::GetOpDsoHandle(dso_name);
typedef OpInfoMap &get_op_info_t();
get_op_info_t *get_op_info =
DynLoad<get_op_info_t>(handle, "PD_GetOpInfoMap");
auto &op_info = get_op_info();
auto *dyn_info_map = op_info.mutable_map();
typedef std::vector<std::string> grad_op_desc_maker_t(
const OpDesc &, const std::unordered_set<std::string> &,
std::unordered_map<std::string, std::string> *,
const std::vector<BlockDesc *> &);
grad_op_desc_maker_t *grad_op_desc_maker =
DynLoad<grad_op_desc_maker_t>(handle, "PD_GetGradOpDescStrs");
auto &info_map = OpInfoMap::Instance();
for (const auto &n : *(dyn_info_map)) {
auto type = n.first;
if (type == "recurrent" || type == "recurrent_grad" ||
type == "conditional_block" || type == "conditional_block_grad") {
continue;
}
PADDLE_ENFORCE_NE(info_map.Has(n.first), true,
platform::errors::AlreadyExists(
"Operator (%s) has been registered.", type));
OpInfo info;
info.creator_ = n.second.creator_;
// If the protocol buffer is taken from the dynamic library directly,
// destruction fails with:
// ** Error in `python`: free(): invalid pointer:
// ... paddle::framework::proto::OpDesc::SharedDtor()
// This appears to be a protobuf bug, see
// https://github.com/protocolbuffers/protobuf/issues/435
// So we get the serialized binary string from the dynamic library
// and then deserialize it into a protocol buffer here.
info.grad_op_maker_ = [grad_op_desc_maker](
const OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> *grad_to_var,
const std::vector<BlockDesc *> &grad_block) {
std::vector<std::string> strs =
grad_op_desc_maker(op_desc, no_grad_set, grad_to_var, grad_block);
std::vector<std::unique_ptr<OpDesc>> ret;
for (auto &str : strs) {
proto::OpDesc proto_desc;
PADDLE_ENFORCE_EQ(proto_desc.ParseFromString(str), true,
platform::errors::InvalidArgument(
"Failed to parse OpDesc from string."));
ret.emplace_back(new OpDesc(proto_desc, nullptr));
}
return ret;
};
info.proto_ = n.second.proto_;
info.checker_ = n.second.checker_;
info.infer_var_type_ = n.second.infer_var_type_;
info.infer_shape_ = n.second.infer_shape_;
info.infer_inplace_ = n.second.infer_inplace_;
info.infer_no_need_buffer_vars_ = n.second.infer_no_need_buffer_vars_;
info.use_default_grad_op_desc_maker_ =
n.second.use_default_grad_op_desc_maker_;
info.use_empty_grad_op_desc_maker_ = n.second.use_empty_grad_op_desc_maker_;
info_map.Insert(type, info);
}
typedef void init_device_t(platform::DeviceContextPool *);
init_device_t *init_dev =
DynLoad<init_device_t>(handle, "PD_InitDevicesPool");
init_dev(&(platform::DeviceContextPool::Instance()));
}
} // namespace framework
} // namespace paddle
@@ -33,7 +33,6 @@ limitations under the License. */
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/load_op_lib.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
@@ -1752,7 +1751,6 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG);
m.def("load_op_library", framework::LoadOpLib);
m.def("load_op_meta_info_and_register_op",
framework::LoadOpMetaInfoAndRegisterOp);
m.def("init_devices", []() { framework::InitDevices(); });
@@ -53,7 +53,6 @@ __all__ = [
'is_compiled_with_cuda',
'is_compiled_with_xpu',
'Variable',
'load_op_library',
'require_version',
'device_guard',
'set_flags',
@@ -5771,33 +5770,6 @@ def _dygraph_place_guard(place):
_set_dygraph_tracer_expected_place(tmp_place)
def load_op_library(lib_filename):
"""
:api_attr: Static Graph
Load a dynamic library containing custom operators and kernels.
Once the library is loaded, the ops and kernels registered in it
become available in the PaddlePaddle main process.
Note that a custom operator's type must not duplicate the type of
an existing operator in the framework.
Args:
lib_filename (str): name of the dynamic library.
Returns:
list[str]: names of the newly registered custom ops.
Examples:
.. code-block:: python
import paddle.fluid as fluid
#fluid.load_op_library('custom_op.so')
"""
core.load_op_library(lib_filename)
return OpProtoHolder.instance().update_op_proto()
def switch_device(device):
global _current_device
pre_device = _current_device
@@ -9,7 +9,8 @@ endforeach()
add_subdirectory(unittests)
add_subdirectory(book)
# TODO: support New Custom OP on Mac
# 2.0 new custom OP is supported on Windows/Linux now
# TODO: support 2.0 New Custom OP on Mac
if(NOT APPLE)
add_subdirectory(custom_op)
endif()
# New custom OP is supported on Windows/Linux now
if(WITH_GPU)
# 'test_custom_relu_op_setup/jit' compile .cc and .cu file
# GPU custom op tests: compile both .cc and .cu files
py_test(test_custom_relu_op_setup SRCS test_custom_relu_op_setup.py)
py_test(test_custom_relu_op_jit SRCS test_custom_relu_op_jit.py)
py_test(test_custom_relu_model SRCS test_custom_relu_model.py)
@@ -11,8 +11,6 @@ if(WITH_GPU)
set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180)
endif()
py_test(test_sysconfig SRCS test_sysconfig.py)
# CPU custom op tests: compile only the .cc file
py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
@@ -21,41 +19,6 @@ py_test(test_custom_concat SRCS test_custom_concat.py)
py_test(test_custom_conj SRCS test_custom_conj.py)
# other tests
py_test(test_sysconfig SRCS test_sysconfig.py)
py_test(test_check_abi SRCS test_check_abi.py)
cc_test(test_check_error SRCS test_check_error.cc DEPS gtest)
if(NOT LINUX)
return()
endif()
# Old custom OP only supports Linux, so run these tests on Linux only
py_test(test_custom_op SRCS test_custom_op.py)
py_test(test_jit_load SRCS test_jit_load.py)
py_test(test_setup_install SRCS test_setup_install.py)
py_test(test_setup_build SRCS test_setup_build.py)
set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180)
set_tests_properties(test_setup_install PROPERTIES TIMEOUT 250)
set_tests_properties(test_setup_build PROPERTIES TIMEOUT 180)
if(WITH_ROCM)
hip_library(relu_op_shared SHARED SRCS relu_op.cc relu_op.cu DEPS paddle_framework_shared)
elseif(WITH_GPU)
nv_library(relu_op_shared SHARED SRCS relu_op.cc relu_op.cu DEPS paddle_framework_shared)
else()
cc_library(relu_op_shared SHARED SRCS relu_op.cc DEPS paddle_framework_shared)
endif()
set_target_properties(relu_op_shared PROPERTIES OUTPUT_NAME relu2_op)
target_link_libraries(relu_op_shared ${FLUID_FRAMEWORK_SHARED_LIB})
# remove the linked glog and gflags when compiling relu_op_shared,
# otherwise the following runtime error occurs:
# ERROR: something wrong with flag 'logtostderr' in file
# 'third_party/glog/src/extern_glog/src/logging.cc'.
# One possibility: file 'third_party/glog/src/extern_glog/src/logging.cc'
# is being linked both statically and dynamically into this executable.
get_target_property(TARGET_LIBRARIES relu_op_shared LINK_LIBRARIES)
LIST(REMOVE_ITEM TARGET_LIBRARIES glog)
LIST(REMOVE_ITEM TARGET_LIBRARIES gflags)
set_property(TARGET relu_op_shared PROPERTY LINK_LIBRARIES ${TARGET_LIBRARIES} )
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class Relu2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims);
}
};
class Relu2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op");
AddComment(R"DOC(
Relu2 Operator.
)DOC");
}
};
class Relu2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
};
template <typename T>
class Relu2GradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType("relu2_grad");
op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class Relu2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("X");
auto* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < in_t->numel(); ++i) {
y[i] = std::max(static_cast<T>(0.), x[i]);
}
}
};
template <typename DeviceContext, typename T>
class Relu2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = dy_t->data<T>();
auto y = y_t->data<T>();
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < y_t->numel(); ++i) {
dx[i] = dy[i] * (y[i] > static_cast<T>(0) ? 1. : 0.);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(relu2,
ops::Relu2Op,
ops::Relu2OpMaker,
ops::Relu2GradMaker<paddle::framework::OpDesc>,
ops::Relu2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(relu2_grad, ops::Relu2GradOp);
REGISTER_OP_CPU_KERNEL(relu2,
ops::Relu2Kernel<CPU, float>,
ops::Relu2Kernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(relu2_grad,
ops::Relu2GradKernel<CPU, float>,
ops::Relu2GradKernel<CPU, double>);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void KeRelu2(const T* x, const int num, T* y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = max(x[i], static_cast<T>(0.));
}
}
template <typename DeviceContext, typename T>
class Relu2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("X");
auto* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int num = in_t->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeRelu2<T><<<grid, block, 0, dev_ctx.stream()>>>(x, num, y);
}
};
template <typename T>
__global__ void KeRelu2Grad(const T* y, const T* dy, const int num, T* dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
}
}
template <typename DeviceContext, typename T>
class Relu2GradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = dy_t->data<T>();
auto y = y_t->data<T>();
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int num = dy_t->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeRelu2Grad<T><<<grid, block, 0, dev_ctx.stream()>>>(y, dy, num, dx);
}
};
} // namespace operators
} // namespace paddle
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(relu2,
paddle::operators::Relu2CUDAKernel<CUDA, float>,
paddle::operators::Relu2CUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(relu2_grad,
paddle::operators::Relu2GradCUDAKernel<CUDA, float>,
paddle::operators::Relu2GradCUDAKernel<CUDA, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class Relu3Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims);
}
};
class Relu3OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op");
AddComment(R"DOC(
Relu3 Operator.
)DOC");
}
};
class Relu3GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
};
template <typename T>
class Relu3GradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType("relu3_grad");
op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class Relu3Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("X");
auto* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < in_t->numel(); ++i) {
y[i] = std::max(static_cast<T>(0.), x[i]);
}
}
};
template <typename DeviceContext, typename T>
class Relu3GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = dy_t->data<T>();
auto y = y_t->data<T>();
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < y_t->numel(); ++i) {
dx[i] = dy[i] * (y[i] > static_cast<T>(0) ? 1. : 0.);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(relu3,
ops::Relu3Op,
ops::Relu3OpMaker,
ops::Relu3GradMaker<paddle::framework::OpDesc>,
ops::Relu3GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(relu3_grad, ops::Relu3GradOp);
REGISTER_OP_CPU_KERNEL(relu3,
ops::Relu3Kernel<CPU, float>,
ops::Relu3Kernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(relu3_grad,
ops::Relu3GradKernel<CPU, float>,
ops::Relu3GradKernel<CPU, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void KeRelu3(const T* x, const int num, T* y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = max(x[i], static_cast<T>(0.));
}
}
template <typename DeviceContext, typename T>
class Relu3CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("X");
auto* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int num = in_t->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeRelu3<T><<<grid, block, 0, dev_ctx.stream()>>>(x, num, y);
}
};
template <typename T>
__global__ void KeRelu3Grad(const T* y, const T* dy, const int num, T* dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
}
}
template <typename DeviceContext, typename T>
class Relu3GradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = dy_t->data<T>();
auto y = y_t->data<T>();
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int num = dy_t->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeRelu3Grad<T><<<grid, block, 0, dev_ctx.stream()>>>(y, dy, num, dx);
}
};
} // namespace operators
} // namespace paddle
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(relu3,
paddle::operators::Relu3CUDAKernel<CUDA, float>,
paddle::operators::Relu3CUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(relu3_grad,
paddle::operators::Relu3GradCUDAKernel<CUDA, float>,
paddle::operators::Relu3GradCUDAKernel<CUDA, double>);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
file_dir = os.path.dirname(os.path.abspath(__file__))
setup(
name='librelu2_op_from_setup',
ext_modules=[
CUDAExtension(
sources=['relu_op3.cc', 'relu_op3.cu', 'relu_op.cc',
'relu_op.cu'], # test for multi ops
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
],
cmdclass={
'build_ext': BuildExtension.with_options(
no_python_abi_suffix=True, output_dir=file_dir) # for unittest
})
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension import CUDAExtension, setup
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
setup(
name='custom_relu2',
ext_modules=CUDAExtension( # test the case where no name is specified here.
sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc',
'relu_op3.cu'], # test for multi ops
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import unittest
import contextlib
import paddle
import paddle.fluid as fluid
paddle.enable_static()
def load_so(so_name):
"""
Load .so file and parse custom op into OpInfoMap.
"""
file_dir = os.path.dirname(os.path.abspath(__file__))
fluid.load_op_library(os.path.join(file_dir, so_name))
from paddle.fluid.layer_helper import LayerHelper
def relu2(x, name=None):
helper = LayerHelper("relu2", **locals())
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False)
helper.append_op(type="relu2", inputs={"X": x}, outputs={"Y": out})
return out
@contextlib.contextmanager
def scope_prog_guard():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
def linear_fc(data, label, use_custom_relu):
hidden = fluid.layers.fc(data, size=128)
hidden = relu2(hidden) if use_custom_relu else fluid.layers.relu(hidden)
hidden = fluid.layers.fc(hidden, size=128)
hidden = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=hidden, label=label)
loss = fluid.layers.mean(loss)
return loss
def custom_op_test(use_gpu=True, use_custom_relu=True):
with scope_prog_guard():
np.random.seed(0)
fluid.default_startup_program().random_seed = 10
fluid.default_main_program().random_seed = 10
data = fluid.layers.data(
name='data', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
loss = linear_fc(data, label, use_custom_relu)
optimizer = fluid.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
compile_program = fluid.compiler.CompiledProgram(
fluid.default_main_program()).with_data_parallel(
loss_name=loss.name)
reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
num = 4
for i, data in enumerate(reader()):
outs, = exe.run(compile_program,
feed=feeder.feed(data),
fetch_list=[loss])
if i == num:
break
return outs
class CustomOpTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(2)
def test_cpu(self):
actual = custom_op_test(False, True)
expect = custom_op_test(False, False)
self.assertEqual(actual.all(), expect.all())
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
actual = custom_op_test(True, True)
expect = custom_op_test(True, False)
self.assertEqual(actual.all(), expect.all())
if __name__ == '__main__':
load_so(so_name='librelu2_op.so')
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
import numpy as np
from paddle.utils.cpp_extension import load
from utils import paddle_includes, extra_cc_args, extra_nvcc_args
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
# Compile and load custom op Just-In-Time.
custom_module = load(
name='custom_relu2',
sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cc flags
extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags
verbose=True # add for unittest
)
class TestJITLoad(unittest.TestCase):
def test_api(self):
raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
gt_data = np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')
x = paddle.to_tensor(raw_data, dtype='float32')
# use custom api
out = custom_module.relu2(x)
out3 = custom_module.relu3(x)
self.assertTrue(np.array_equal(out.numpy(), gt_data))
self.assertTrue(np.array_equal(out3.numpy(), gt_data))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from test_custom_op import CustomOpTest, load_so
import paddle
from paddle.utils.cpp_extension.extension_utils import run_cmd
from paddle.fluid.layer_helper import LayerHelper
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
def compile_so():
"""
Compile .so file by running setup.py config.
"""
# build .so with setup.py
file_dir = os.path.dirname(os.path.abspath(__file__))
cmd = 'cd {} && python setup_build.py build'.format(file_dir)
run_cmd(cmd)
# `setup.py build` only produces a .so file containing multiple operators.
# The Python interface must be added manually; the `relu2` api is defined in `test_custom_op.py`.
def relu3(x, name=None):
helper = LayerHelper("relu3", **locals())
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False)
helper.append_op(type="relu3", inputs={"X": x}, outputs={"Y": out})
return out
class TestCompileMultiOp(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_relu3(self):
raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
x = paddle.to_tensor(raw_data, dtype='float32')
# use custom api
out = relu3(x)
self.assertTrue(
np.array_equal(out.numpy(),
np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
def tearDown(self):
paddle.enable_static()
if __name__ == '__main__':
compile_so()
load_so(so_name='librelu2_op_from_setup.so')
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import site
import unittest
import paddle
import subprocess
import numpy as np
from paddle.utils.cpp_extension.extension_utils import run_cmd
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
# switch to old custom op method
use_new_custom_op_load_method(False)
class TestSetUpInstall(unittest.TestCase):
def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
# compile and install the custom op egg into site-packages in the background
cmd = 'cd {} && python setup_install.py install'.format(cur_dir)
run_cmd(cmd)
# NOTE(Aurelius84): Normally users do not need the following code.
# But here we simulate `pip install` in the current process, and the
# interpreter does not notice that sys.path has been updated, so we update it manually.
# See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
site_dir = site.getsitepackages()[0]
custom_egg_path = [
x for x in os.listdir(site_dir) if 'custom_relu2' in x
]
assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
custom_egg_path)
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
def test_api(self):
# usage: import the package directly
import custom_relu2
raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
gt_data = np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')
x = paddle.to_tensor(raw_data, dtype='float32')
# use custom api
out = custom_relu2.relu2(x)
out3 = custom_relu2.relu3(x)
self.assertTrue(np.array_equal(out.numpy(), gt_data))
self.assertTrue(np.array_equal(out3.numpy(), gt_data))
if __name__ == '__main__':
unittest.main()
@@ -14,7 +14,6 @@
from . import optimizer
from ..fluid.contrib import reader
from ..fluid import load_op_library
from ..fluid.layer_helper import LayerHelper
__all__ = []
@@ -20,7 +20,6 @@ from .lazy_import import try_import
from .op_version import OpLastCheckpointChecker
from .install_check import run_check
from ..fluid.framework import unique_name
from ..fluid.framework import load_op_library
from ..fluid.framework import require_version
from . import download
@@ -30,4 +29,4 @@ from . import cpp_extension
__all__ = ['dump_config', 'deprecated', 'download', 'run_check']
#TODO: define new api under this directory
__all__ += ['unique_name', 'load_op_library', 'require_version']
__all__ += ['unique_name', 'require_version']
@@ -26,7 +26,7 @@ from .extension_utils import find_cuda_home, find_rocm_home, normalize_extension
from .extension_utils import is_cuda_file, prepare_unix_cudaflags, prepare_win_cudaflags
from .extension_utils import _import_module_from_library, _write_setup_file, _jit_compile
from .extension_utils import check_abi_compatibility, log_v, CustomOpInfo, parse_op_name_from
from .extension_utils import use_new_custom_op_load_method, clean_object_if_change_cflags
from .extension_utils import clean_object_if_change_cflags
from .extension_utils import bootstrap_context, get_build_directory, add_std_without_repeat
from .extension_utils import IS_WINDOWS, OS_NAME, MSVC_COMPILE_FLAGS, MSVC_COMPILE_FLAGS
@@ -28,7 +28,6 @@ import subprocess
from contextlib import contextmanager
from setuptools.command import bdist_egg
from .. import load_op_library
from ...fluid import core
from ...fluid.framework import OpProtoHolder
from ...sysconfig import get_include, get_lib
@@ -86,7 +85,6 @@ information
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
'''
USING_NEW_CUSTOM_OP_LOAD_METHOD = True
DEFAULT_OP_ATTR_NAMES = [
core.op_proto_and_checker_maker.kOpRoleAttrName(),
@@ -97,18 +95,6 @@ DEFAULT_OP_ATTR_NAMES = [
]
# NOTE(chenweihang): kept for compatibility with the two custom op
# definition methods; once the old method is removed, these can be
# removed together
def use_new_custom_op_load_method(*args):
global USING_NEW_CUSTOM_OP_LOAD_METHOD
if len(args) == 0:
return USING_NEW_CUSTOM_OP_LOAD_METHOD
else:
assert len(args) == 1 and isinstance(args[0], bool)
USING_NEW_CUSTOM_OP_LOAD_METHOD = args[0]
@contextmanager
def bootstrap_context():
"""
@@ -122,10 +108,7 @@ def bootstrap_context():
def load_op_meta_info_and_register_op(lib_filename):
if USING_NEW_CUSTOM_OP_LOAD_METHOD:
core.load_op_meta_info_and_register_op(lib_filename)
else:
core.load_op_library(lib_filename)
core.load_op_meta_info_and_register_op(lib_filename)
return OpProtoHolder.instance().update_op_proto()
@@ -406,10 +389,7 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
# append link flags
extra_link_args = kwargs.get('extra_link_args', [])
if use_new_custom_op_load_method():
extra_link_args.append('-lpaddle_custom_op')
else:
extra_link_args.append('-lpaddle_framework')
extra_link_args.append('-lpaddle_custom_op')
if use_cuda:
extra_link_args.append('-lcudart')
@@ -811,9 +791,7 @@ def _write_setup_file(name,
import os
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup
from paddle.utils.cpp_extension import get_build_directory
from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
use_new_custom_op_load_method({use_new_method})
setup(
name='{name}',
@@ -841,8 +819,7 @@ extra_cxx_cflags=list2str(extra_cxx_cflags),
extra_cxx_cflags=list2str(extra_cxx_cflags),
extra_cuda_cflags=list2str(extra_cuda_cflags),
extra_link_args=list2str(link_args),
build_dir=build_dir,
use_new_method=use_new_custom_op_load_method())
build_dir=build_dir)
log_v('write setup.py into {}'.format(file_path), verbose)
with open(file_path, 'w') as f:
@@ -898,11 +875,7 @@ def parse_op_name_from(sources):
"""
def regex(content):
if USING_NEW_CUSTOM_OP_LOAD_METHOD:
pattern = re.compile(r'PD_BUILD_OP\(([^,\)]+)\)')
else:
pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),')
pattern = re.compile(r'PD_BUILD_OP\(([^,\)]+)\)')
content = re.sub(r'\s|\t|\n', '', content)
op_name = pattern.findall(content)
op_name = set([re.sub('_grad', '', name) for name in op_name])
@@ -347,11 +347,6 @@ if '${WITH_XPU}' == 'OFF' and '${XPU_SDK_ROOT}' != '':
shutil.copy(xpu_rt_lib, libs_path)
package_data['paddle.libs']+=['libxpurt.so']
### Old custom op extension mechanism related, will be removed in 2.1.0 ###
# copy libpaddle_framework.so to libs on linux
if sys.platform.startswith('linux'):
shutil.copy('${FLUID_FRAMEWORK_SHARED_LIB}', libs_path)
package_data['paddle.libs'] += ['libpaddle_framework.so']
### New custom op extension mechanism related ###
# copy libpaddle_custom_op.so to libs on linux
@@ -405,25 +400,8 @@ def find_files(pattern, root):
headers = (
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/extension')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/framework')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/imperative')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/memory')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/platform')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/string')) +
list(find_files('*.pb.h', '${PADDLE_BINARY_DIR}/paddle/fluid/platform')) +
list(find_files('*.pb.h', '${PADDLE_BINARY_DIR}/paddle/fluid/framework')) +
list(find_files('*.pb', '${cudaerror_INCLUDE_DIR}')) + # errorMessage.pb for errormessage
['${EIGEN_INCLUDE_DIR}/Eigen/Core'] + # eigen
list(find_files('*', '${EIGEN_INCLUDE_DIR}/Eigen/src')) + # eigen
list(find_files('*', '${EIGEN_INCLUDE_DIR}/unsupported/Eigen')) + # eigen
list(find_files('*', '${GFLAGS_INSTALL_DIR}/include')) + # gflags
list(find_files('*', '${GLOG_INSTALL_DIR}/include')) + # glog
list(find_files('*', '${BOOST_INCLUDE_DIR}/boost')) + # boost
list(find_files('*', '${XXHASH_INSTALL_DIR}/include')) + # xxhash
list(find_files('*', '${PROTOBUF_INCLUDE_DIR}')) + # protobuf
list(find_files('*', '${DLPACK_INCLUDE_DIR}')) + # dlpack
list(find_files('*.h', '${THREADPOOL_INCLUDE_DIR}'))) # threadpool
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/extension')) + # extension
list(find_files('*', '${BOOST_INCLUDE_DIR}/boost'))) # boost
if '${WITH_MKLDNN}' == 'ON':
headers += list(find_files('*', '${MKLDNN_INSTALL_DIR}/include')) # mkldnn
@@ -463,17 +441,18 @@ class InstallHeaders(Command):
('install_headers', 'install_dir'),
('force', 'force'))
def copy_data_type_headers(self, header):
if os.name == 'nt':
data_type_headers = ['platform\\complex64.h', 'platform\\complex128.h', 'platform\\float16.h']
else:
data_type_headers = ['platform/complex64.h', 'platform/complex128.h', 'platform/float16.h']
for dtype_header in data_type_headers:
if dtype_header in header:
install_dir = os.path.join(self.install_dir, "paddle/fluid/extension/include")
if not os.path.exists(install_dir):
self.mkpath(install_dir)
return self.copy_file(header, install_dir)
def copy_data_type_headers(self):
# For paddle new custom op, only copy data type headers from `paddle/fluid/platform`
# to `extension/include`,
data_type_headers = (['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex64.h'] +
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/complex128.h'] +
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/float16.h'])
install_dir = os.path.join(self.install_dir, "paddle/fluid/extension/include")
if not os.path.exists(install_dir):
self.mkpath(install_dir)
for header in data_type_headers:
self.copy_file(header, install_dir)
def mkdir_and_copy_file(self, header):
if 'pb.h' in header:
@@ -481,9 +460,6 @@
elif 'third_party' not in header:
# paddle headers
install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header)
# For paddle data type headers, we also need to copy them to `extension/include`,
# which is used for the new custom operator
self.copy_data_type_headers(header)
else:
# third_party
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
@@ -509,6 +485,7 @@
for header in hdrs:
(out, _) = self.mkdir_and_copy_file(header)
self.outfiles.append(out)
self.copy_data_type_headers()
def get_inputs(self):
return self.distribution.headers or []