未验证 提交 1a3eef02 编写于 作者: Q qingqing01 提交者: GitHub

Enable users to create custom cpp op outside framework. (#19256)

* How to write custom op needs to follow framework OP spec.
* Package fluid_framework.so and headers into whl.
* Add paddle.sysconfig.get_include() and paddle.sysconfig.get_lib() to get include dir and lib dir.
* Export some C-APIs to merge OpInfo between core.so and custom_op.so.
* Add unit testing.
* Update API.spec.
上级 f1eebf75
......@@ -26,6 +26,7 @@ paddle.fluid.Variable.gradient (ArgSpec(args=['self'], varargs=None, keywords=No
paddle.fluid.Variable.numpy (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '7536e8feb56d827875943e7f01d406fc'))
paddle.fluid.Variable.set_value (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', 'c424b9e763ff51c38a6917f98026fe7d'))
paddle.fluid.Variable.to_string (ArgSpec(args=['self', 'throw_on_error', 'with_details'], varargs=None, keywords=None, defaults=(False,)), ('document', '31f359a2c074f26dc0ffff296fc3983f'))
paddle.fluid.load_op_library (ArgSpec(args=['lib_filename'], varargs=None, keywords=None, defaults=None), ('document', 'c009b2ea5fb6520f2d2f53aafec788e0'))
paddle.fluid.Executor ('paddle.fluid.executor.Executor', ('document', '34e8c1769313fbeff7817212dda6259e'))
paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '3a584496aa1343f36eebf3c46b323a74'))
......
......@@ -246,3 +246,35 @@ message(STATUS "commit: ${PADDLE_COMMIT}")
message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file(commit.h.in commit.h)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto)
if(WIN32)
sep_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})
else(WIN32)
cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})
endif(WIN32)
cc_library(paddle_framework_shared
SHARED SRCS executor.cc operator.cc ${CMAKE_CURRENT_SOURCE_DIR}/c/c_api.cc
DEPS ${FLUID_FRAMEWORK_MODULES})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
set_target_properties(paddle_framework_shared PROPERTIES OUTPUT_NAME paddle_framework)
target_link_libraries(paddle_framework_shared ${os_dependency_modules})
if (LINUX)
set(FLUID_FRAMEWORK_SHARED_LIB
${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_framework.so
CACHE INTERNAL "Fluid framework lib")
endif()
if (WIN32)
set(FLUID_FRAMEWORK_SHARED_LIB
${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_framework.dll
CACHE INTERNAL "Fluid framework lib")
endif()
if(APPLE)
set(FLUID_FRAMEWORK_SHARED_LIB
${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_framework.dylib
CACHE INTERNAL "Fluid framework lib")
endif()
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/c/c_api.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/init.h"
extern "C" {
paddle::framework::OpInfoMap &PD_GetOpInfoMap() {
return paddle::framework::OpInfoMap::Instance();
}
void PD_InitDevicesPool(paddle::platform::DeviceContextPool *pool) {
paddle::platform::DeviceContextPool::SetPool(pool);
}
std::vector<std::string> PD_GetGradOpDescStrs(
const paddle::framework::OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> *grad_to_var,
const std::vector<paddle::framework::BlockDesc *> &grad_block) {
auto &op_info = PD_GetOpInfoMap().Get(op_desc.Type());
std::vector<std::string> ret;
if (op_info.grad_op_maker_) {
auto grad_op_descs =
op_info.grad_op_maker_(op_desc, no_grad_set, grad_to_var, grad_block);
size_t op_num = grad_op_descs.size();
ret.resize(op_num);
for (size_t i = 0; i < op_num; ++i) {
PADDLE_ENFORCE_EQ(
grad_op_descs[i]->Proto()->SerializePartialToString(&ret[i]), true,
"Cannot serialize message.");
}
}
return ret;
}
} // end extern "C"
/* copyright (c) 2019 paddlepaddle authors. all rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef __cplusplus
extern "C" {
#endif
// C-API to get global OpInfo map.
paddle::framework::OpInfoMap &PD_GetOpInfoMap();
// C-API to init global DeviceContextPool from outside.
void PD_InitDevicesPool(paddle::platform::DeviceContextPool *pool);
// C-API to serialize the grad op protocol message to a binary string.
std::vector<std::string> PD_GetGradOpDescStrs(
const paddle::framework::OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> *grad_to_var,
const std::vector<paddle::framework::BlockDesc *> &grad_block);
#ifdef __cplusplus
}
#endif
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace framework {
template <typename T>
T *DynLoad(void *handle, std::string name) {
T *func = reinterpret_cast<T *>(dlsym(handle, name.c_str()));
#if !defined(_WIN32)
auto errorno = dlerror();
#else
auto errorno = GetLastError();
#endif // !_WIN32
PADDLE_ENFORCE_NOT_NULL(func, errorno);
return func;
}
void LoadOpLib(const std::string &dso_name) {
void *handle = paddle::platform::dynload::GetOpDsoHandle(dso_name);
typedef OpInfoMap &get_op_info_t();
get_op_info_t *get_op_info =
DynLoad<get_op_info_t>(handle, "PD_GetOpInfoMap");
auto &op_info = get_op_info();
auto *dyn_info_map = op_info.mutable_map();
typedef std::vector<std::string> grad_op_desc_maker_t(
const OpDesc &, const std::unordered_set<std::string> &,
std::unordered_map<std::string, std::string> *,
const std::vector<BlockDesc *> &);
grad_op_desc_maker_t *grad_op_desc_maker =
DynLoad<grad_op_desc_maker_t>(handle, "PD_GetGradOpDescStrs");
auto &info_map = OpInfoMap::Instance();
for (const auto &n : *(dyn_info_map)) {
auto type = n.first;
if (type == "recurrent" || type == "recurrent_grad" ||
type == "conditional_block" || type == "conditional_block_grad") {
continue;
}
if (info_map.Has(n.first)) {
PADDLE_THROW("Op %s has been registered.");
}
OpInfo info;
info.creator_ = n.second.creator_;
// If get the protocol buffer from dynamic library directly, there
// will be deconstruction error
// ** Error in `python`: free(): invalid pointer:
// ... paddle::framework::proto::OpDesc::SharedDtor()
// It seems a bug in protobuf, see
// https://github.com/protocolbuffers/protobuf/issues/435
// So, get the serialized binary string from dynamic library,
// then deserialize to protocol buffer.
info.grad_op_maker_ = [grad_op_desc_maker](
const OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> *grad_to_var,
const std::vector<BlockDesc *> &grad_block) {
std::vector<std::string> strs =
grad_op_desc_maker(op_desc, no_grad_set, grad_to_var, grad_block);
std::vector<std::unique_ptr<OpDesc>> ret;
for (auto &str : strs) {
proto::OpDesc proto_desc;
PADDLE_ENFORCE_EQ(proto_desc.ParseFromString(str), true,
"Failed to parse OpDesc from string");
ret.emplace_back(new OpDesc(proto_desc, nullptr));
}
return ret;
};
info.proto_ = n.second.proto_;
info.checker_ = n.second.checker_;
info.infer_var_type_ = n.second.infer_var_type_;
info.infer_shape_ = n.second.infer_shape_;
info.infer_inplace_ = n.second.infer_inplace_;
info.infer_no_need_buffer_vars_ = n.second.infer_no_need_buffer_vars_;
info.use_default_grad_op_desc_maker_ =
n.second.use_default_grad_op_desc_maker_;
info_map.Insert(type, info);
}
typedef void init_device_t(platform::DeviceContextPool *);
init_device_t *init_dev =
DynLoad<init_device_t>(handle, "PD_InitDevicesPool");
init_dev(&(platform::DeviceContextPool::Instance()));
}
} // namespace framework
} // namespace paddle
......@@ -331,6 +331,8 @@ class DeviceContextPool {
return *pool;
}
static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; }
/*! \brief Return handle of single device context. */
platform::DeviceContext* Get(const platform::Place& place);
......
......@@ -46,6 +46,8 @@ DEFINE_string(
DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so.");
DEFINE_string(op_dir, "", "Specify path for loading user-defined op library.");
namespace paddle {
namespace platform {
namespace dynload {
......@@ -280,6 +282,16 @@ void* GetMKLMLDsoHandle() {
#endif
}
void* GetOpDsoHandle(const std::string& dso_name) {
#if defined(__APPLE__) || defined(__OSX__)
PADDLE_THROW("Do not support Apple.");
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
PADDLE_THROW("Do not support Windows.");
#else
return GetDsoHandleFromSearchPath(FLAGS_op_dir, dso_name);
#endif
}
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -35,6 +35,7 @@ void* GetWarpCTCDsoHandle();
void* GetNCCLDsoHandle();
void* GetTensorRtDsoHandle();
void* GetMKLMLDsoHandle();
void* GetOpDsoHandle(const std::string& dso_name);
void SetPaddleLibPath(const std::string&);
} // namespace dynload
......
......@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/load_op_lib.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
......@@ -1053,6 +1054,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG);
m.def("init_dgc", framework::InitDGC);
m.def("load_op_library", framework::LoadOpLib);
m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); });
......
......@@ -32,3 +32,4 @@ import paddle.batch
import paddle.compat
import paddle.distributed
batch = batch.batch
import paddle.sysconfig
......@@ -47,6 +47,7 @@ __all__ = [
'in_dygraph_mode',
'is_compiled_with_cuda',
'Variable',
'load_op_library',
]
EMPTY_VAR_NAME = core.kEmptyVarName()
......@@ -1300,6 +1301,12 @@ class OpProtoHolder(object):
raise ValueError("Operator \"%s\" has not been registered." % type)
return self.op_proto_map[type]
def update_op_proto(self):
op_protos = get_all_op_protos()
for proto in op_protos:
if proto.type not in self.op_proto_map:
self.op_proto_map[proto.type] = proto
@staticmethod
def generated_op_attr_names():
return {
......@@ -4327,3 +4334,25 @@ def _dygraph_place_guard(place):
yield
_dygraph_current_expected_place_ = tmp_place
def load_op_library(lib_filename):
"""
Load a dynamic library, including custom operators and kernels.
When library is loaded, ops and kernels registered in the library
will be available in PaddlePaddle main process.
Please note, the type of custom operators cann't have the same type
with the existing operators in the framework.
Args:
lib_filename (str): name of dynamic library.
Examples:
.. code-block:: python
import paddle.fluid as fluid
#fluid.load_op_library('custom_op.so')
"""
core.load_op_library(lib_filename)
OpProtoHolder.instance().update_op_proto()
......@@ -17,7 +17,7 @@ from __future__ import print_function
import copy
import six
from .framework import Parameter, dtype_is_floating, in_dygraph_mode
from .framework import Parameter, dtype_is_floating, in_dygraph_mode, OpProtoHolder
from . import unique_name
from paddle.fluid.initializer import Constant, Xavier
from .param_attr import ParamAttr
......
......@@ -11,3 +11,7 @@ endforeach()
add_subdirectory(unittests)
add_subdirectory(book)
if(NOT APPLE AND NOT WIN32)
add_subdirectory(custom_op)
endif()
if (WITH_GPU)
nv_library(relu_op_shared SHARED SRCS relu_op.cc relu_op.cu DEPS paddle_framework_shared)
else()
cc_library(relu_op_shared SHARED SRCS relu_op.cc DEPS paddle_framework_shared)
endif()
set_target_properties(relu_op_shared PROPERTIES OUTPUT_NAME relu2_op)
target_link_libraries(relu_op_shared ${FLUID_FRAMEWORK_SHARED_LIB})
# remove the linked glog and gflags when compling relu_op_shared
# otherwise, there is running error:
# ERROR: something wrong with flag 'logtostderr' in file
# 'third_party/glog/src/extern_glog/src/logging.cc'.
# One possibility: file 'third_party/glog/src/extern_glog/src/logging.cc'
# is being linked both statically and dynamically into this executable.
get_target_property(TARGET_LIBRARIES relu_op_shared LINK_LIBRARIES)
LIST(REMOVE_ITEM TARGET_LIBRARIES glog)
LIST(REMOVE_ITEM TARGET_LIBRARIES gflags)
set_property(TARGET relu_op_shared PROPERTY LINK_LIBRARIES ${TARGET_LIBRARIES} )
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class Relu2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims);
}
};
class Relu2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op");
AddComment(R"DOC(
Relu2 Operator.
)DOC");
}
};
class Relu2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
};
class Relu2GradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("relu2_grad");
op->SetInput("Y", Output("Y"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class Relu2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("X");
auto* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < in_t->numel(); ++i) {
y[i] = std::max(static_cast<T>(0.), x[i]);
}
}
};
template <typename DeviceContext, typename T>
class Relu2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = dy_t->data<T>();
auto y = y_t->data<T>();
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < y_t->numel(); ++i) {
dx[i] = dy[i] * (y[i] > static_cast<T>(0) ? 1. : 0.);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(relu2, ops::Relu2Op, ops::Relu2OpMaker, ops::Relu2GradMaker);
REGISTER_OPERATOR(relu2_grad, ops::Relu2GradOp);
REGISTER_OP_CPU_KERNEL(relu2,
ops::Relu2Kernel<CPU, float>,
ops::Relu2Kernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(relu2_grad,
ops::Relu2GradKernel<CPU, float>,
ops::Relu2GradKernel<CPU, double>);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void KeRelu2(const T* x, const int num, T* y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = max(x[i], static_cast<T>(0.));
}
}
template <typename DeviceContext, typename T>
class Relu2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("X");
auto* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int num = in_t->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeRelu2<T><<<grid, block, 0, dev_ctx.stream()>>>(x, num, y);
}
};
template <typename T>
__global__ void KeRelu2Grad(const T* y, const T* dy, const int num, T* dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
}
}
template <typename DeviceContext, typename T>
class Relu2GradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = dy_t->data<T>();
auto y = y_t->data<T>();
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int num = dy_t->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeRelu2Grad<T><<<grid, block, 0, dev_ctx.stream()>>>(y, dy, num, dx);
}
};
} // namespace operators
} // namespace paddle
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(relu2,
paddle::operators::Relu2CUDAKernel<CUDA, float>,
paddle::operators::Relu2CUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(relu2_grad,
paddle::operators::Relu2GradCUDAKernel<CUDA, float>,
paddle::operators::Relu2GradCUDAKernel<CUDA, double>);
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import unittest
import contextlib
import paddle
import paddle.fluid as fluid
file_dir = os.path.dirname(os.path.abspath(__file__))
fluid.load_op_library(os.path.join(file_dir, 'librelu2_op.so'))
from paddle.fluid.layer_helper import LayerHelper
def relu2(x, name=None):
helper = LayerHelper("relu2", **locals())
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False)
helper.append_op(type="relu2", inputs={"X": x}, outputs={"Y": out})
return out
@contextlib.contextmanager
def scope_prog_guard():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
def linear_fc(data, label, use_custom_relu):
hidden = fluid.layers.fc(data, size=128)
hidden = relu2(hidden) if use_custom_relu else fluid.layers.relu(hidden)
hidden = fluid.layers.fc(hidden, size=128)
hidden = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=hidden, label=label)
loss = fluid.layers.mean(loss)
return loss
def custom_op_test(use_gpu=True, use_custom_relu=True):
with scope_prog_guard():
np.random.seed(0)
fluid.default_startup_program().random_seed = 10
fluid.default_main_program().random_seed = 10
data = fluid.layers.data(
name='data', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
loss = linear_fc(data, label, use_custom_relu)
optimizer = fluid.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
compile_program = fluid.compiler.CompiledProgram(
fluid.default_main_program()).with_data_parallel(
loss_name=loss.name)
reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
num = 4
for i, data in enumerate(reader()):
outs, = exe.run(compile_program,
feed=feeder.feed(data),
fetch_list=[loss])
if i == num:
break
return outs
class CustomOpTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(2)
def test_cpu(self):
actual = custom_op_test(False, True)
expect = custom_op_test(False, False)
self.assertEqual(actual.all(), expect.all())
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
actual = custom_op_test(True, True)
expect = custom_op_test(True, False)
self.assertEqual(actual.all(), expect.all())
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
class SysConfigTest(unittest.TestCase):
def test_include(self):
inc_dir = paddle.sysconfig.get_include()
inc_dirs = inc_dir.split(os.sep)
self.assertEqual(inc_dirs[-1], 'include')
self.assertEqual(inc_dirs[-2], 'paddle')
def test_libs(self):
lib_dir = paddle.sysconfig.get_lib()
lib_dirs = lib_dir.split(os.sep)
self.assertEqual(lib_dirs[-1], 'libs')
self.assertEqual(lib_dirs[-2], 'paddle')
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
__all__ = ['get_include', 'get_lib']
def get_include():
"""
Get the directory containing the PaddlePaddle C++ header files.
Returns:
The directory as string.
Examples:
.. code-block:: python
import paddle
include_dir = paddle.sysconfig.get_include()
"""
import paddle
return os.path.join(os.path.dirname(paddle.__file__), 'include')
def get_lib():
"""
Get the directory containing the libpaddle_framework.
Returns:
The directory as string.
Examples:
.. code-block:: python
import paddle
include_dir = paddle.sysconfig.get_lib()
"""
import paddle
return os.path.join(os.path.dirname(paddle.__file__), 'libs')
from setuptools import setup, Distribution, Extension
import subprocess
import os
import re
import shutil
import sys
import fnmatch
from setuptools import Command
from setuptools import setup, Distribution, Extension
from setuptools.command.install import install as InstallCommandBase
class BinaryDistribution(Distribution):
def has_ext_modules(foo):
return True
......@@ -220,11 +226,19 @@ if '${WITH_NGRAPH}' == 'ON':
package_data['paddle.libs']+=['${NGRAPH_SHARED_LIB_NAME}',
'${NGRAPH_CPU_LIB_NAME}',
'${NGRAPH_TBB_LIB_NAME}']
# copy libfuild_framework.so to libs
if os.name != 'nt' and sys.platform != 'darwin':
paddle_framework_lib='${FLUID_FRAMEWORK_SHARED_LIB}'
shutil.copy(paddle_framework_lib, libs_path)
package_data['paddle.libs'] += [('libpaddle_framework' if os.name != 'nt' else 'paddle_framework') + ext_name]
# remove unused paddle/libs/__init__.py
if os.path.isfile(libs_path+'/__init__.py'):
os.remove(libs_path+'/__init__.py')
package_dir['paddle.libs']=libs_path
# change rpath of ${FLUID_CORE_NAME}.ext, add $ORIGIN/../libs/ to it.
# The reason is that libwarpctc.ext, libiomp5.ext etc are in paddle.libs, and
# ${FLUID_CORE_NAME}.ext is in paddle.fluid, thus paddle/fluid/../libs will pointer to above libraries.
......@@ -250,6 +264,93 @@ if os.name == 'nt':
elif sys.platform == 'darwin':
ext_modules = []
def find_files(pattern, root):
for dirpath, _, files in os.walk(root):
for filename in fnmatch.filter(files, pattern):
yield os.path.join(dirpath, filename)
headers = (
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/framework')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/memory')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/platform')) +
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/fluid/string')) +
list(find_files('*.pb.h', '${PADDLE_BINARY_DIR}/paddle/fluid/framework')) +
['${EIGEN_INCLUDE_DIR}/Eigen/Core'] + # eigen
list(find_files('*', '${EIGEN_INCLUDE_DIR}/Eigen/src')) + # eigen
list(find_files('*', '${EIGEN_INCLUDE_DIR}/unsupported/Eigen')) + # eigen
list(find_files('*', '${GFLAGS_INSTALL_DIR}/include')) + # gflags
list(find_files('*', '${GLOG_INSTALL_DIR}/include')) + # glog
list(find_files('*', '${BOOST_INCLUDE_DIR}/boost')) + # boost
list(find_files('*', '${XXHASH_INSTALL_DIR}/include')) + # xxhash
list(find_files('*', '${PROTOBUF_INCLUDE_DIR}')) + # protobuf
list(find_files('*.h', '${THREADPOOL_INCLUDE_DIR}'))) # threadpool
class InstallCommand(InstallCommandBase):
def finalize_options(self):
ret = InstallCommandBase.finalize_options(self)
self.install_headers = os.path.join(self.install_purelib, 'paddle',
'include')
self.install_lib = self.install_platlib
return ret
class InstallHeaders(Command):
"""Override how headers are copied.
"""
description = 'install C/C++ header files'
user_options = [('install-dir=', 'd',
'directory to install header files to'),
('force', 'f',
'force installation (overwrite existing files)'),
]
boolean_options = ['force']
def initialize_options(self):
self.install_dir = None
self.force = 0
self.outfiles = []
def finalize_options(self):
self.set_undefined_options('install',
('install_headers', 'install_dir'),
('force', 'force'))
def mkdir_and_copy_file(self, header):
if 'pb.h' in header:
install_dir = re.sub('${PADDLE_BINARY_DIR}/', '', header)
elif 'third_party' not in header:
# framework
install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header)
else:
# third_party
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
install_dir = re.sub('src/extern_eigen3/', '', install_dir)
install_dir = re.sub('src/extern_boost/', '', install_dir)
install_dir = os.path.join(self.install_dir, os.path.dirname(install_dir))
if not os.path.exists(install_dir):
self.mkpath(install_dir)
return self.copy_file(header, install_dir)
def run(self):
if os.name == 'nt' or sys.platform == 'darwin':
return
hdrs = self.distribution.headers
if not hdrs:
return
self.mkpath(self.install_dir)
for header in hdrs:
(out, _) = self.mkdir_and_copy_file(header)
self.outfiles.append(out)
def get_inputs(self):
return self.distribution.headers or []
def get_outputs(self):
return self.outfiles
setup(name='${PACKAGE_NAME}',
version='${PADDLE_VERSION}',
description='Parallel Distributed Deep Learning',
......@@ -259,5 +360,10 @@ setup(name='${PACKAGE_NAME}',
package_data=package_data,
package_dir=package_dir,
scripts=paddle_bins,
distclass=BinaryDistribution
distclass=BinaryDistribution,
headers=headers,
cmdclass={
'install_headers': InstallHeaders,
'install': InstallCommand,
}
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册