diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt
index 3c5a8a9f4a7cb6bb11d2af28466abb68373bed04..1d6304cd6409d3c35c09d0430bbd585f9049fbe3 100644
--- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt
+++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt
@@ -27,9 +27,9 @@ foreach(src ${TEST_OPS})
 endforeach()
 
 # Compiling .so will cost some time, but running process is very fast.
-set_tests_properties(test_custom_op_with_setup PROPERTIES TIMEOUT 180)
 set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180)
 set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180)
+set_tests_properties(test_setup_build PROPERTIES TIMEOUT 180)
 
 set_tests_properties(test_simple_custom_op_setup PROPERTIES TIMEOUT 250)
 set_tests_properties(test_simple_custom_op_jit PROPERTIES TIMEOUT 180)
diff --git a/python/paddle/fluid/tests/custom_op/relu_op3.cc b/python/paddle/fluid/tests/custom_op/relu_op3.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ace9598c586866edd4be8cf99e4d3a783e18788b
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_op/relu_op3.cc
@@ -0,0 +1,115 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class Relu3Op : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto in_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Y", in_dims);
+  }
+};
+
+class Relu3OpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input tensor.");
+    AddOutput("Y", "Output of relu_op");
+    AddComment(R"DOC(
+Relu3 Operator.
+)DOC");
+  }
+};
+
+class Relu3GradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
+    ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
+  }
+};
+
+template <typename T>
+class Relu3GradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("relu3_grad");
+    op->SetInput("Y", this->Output("Y"));
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class Relu3Kernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in_t = ctx.Input<Tensor>("X");
+    auto* out_t = ctx.Output<Tensor>("Y");
+    auto x = in_t->data<T>();
+    auto y = out_t->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < in_t->numel(); ++i) {
+      y[i] = std::max(static_cast<T>(0.), x[i]);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Relu3GradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* y_t = ctx.Input<Tensor>("Y");
+    auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    auto dy = dy_t->data<T>();
+    auto y = y_t->data<T>();
+    auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
+
+    for (int i = 0; i < y_t->numel(); ++i) {
+      dx[i] = dy[i] * (y[i] > static_cast<T>(0) ? 1. : 0.);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPU = paddle::platform::CPUDeviceContext;
+REGISTER_OPERATOR(relu3,
+                  ops::Relu3Op,
+                  ops::Relu3OpMaker,
+                  ops::Relu3GradMaker<paddle::framework::OpDesc>,
+                  ops::Relu3GradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(relu3_grad, ops::Relu3GradOp);
+REGISTER_OP_CPU_KERNEL(relu3,
+                       ops::Relu3Kernel<CPU, float>,
+                       ops::Relu3Kernel<CPU, double>);
+REGISTER_OP_CPU_KERNEL(relu3_grad,
+                       ops::Relu3GradKernel<CPU, float>,
+                       ops::Relu3GradKernel<CPU, double>);
+ +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__global__ void KeRelu3(const T* x, const int num, T* y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = max(x[i], static_cast(0.)); + } +} + +template +class Relu3CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_t = ctx.Input("X"); + auto* out_t = ctx.Output("Y"); + auto x = in_t->data(); + auto y = out_t->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + int num = in_t->numel(); + int block = 512; + int grid = (num + block - 1) / block; + KeRelu3<<>>(x, num, y); + } +}; + +template +__global__ void KeRelu3Grad(const T* y, const T* dy, const int num, T* dx) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.); + } +} + +template +class Relu3GradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dy_t = ctx.Input(framework::GradVarName("Y")); + auto* y_t = ctx.Input("Y"); + auto* dx_t = ctx.Output(framework::GradVarName("X")); + + auto dy = dy_t->data(); + auto y = y_t->data(); + auto dx = dx_t->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + int num = dy_t->numel(); + int block = 512; + int grid = (num + block - 1) / block; + KeRelu3Grad<<>>(y, dy, num, dx); + } +}; + +} // namespace operators +} // namespace paddle + +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(relu3, + paddle::operators::Relu3CUDAKernel, + paddle::operators::Relu3CUDAKernel); + +REGISTER_OP_CUDA_KERNEL(relu3_grad, + paddle::operators::Relu3GradCUDAKernel, + paddle::operators::Relu3GradCUDAKernel); diff --git a/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc b/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc new file mode 100644 index 0000000000000000000000000000000000000000..9a72db10069a00aec76063de8e4399587ca146af --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/extension.h" + +std::vector relu_cuda_forward(const paddle::Tensor& x); +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector ReluForward(const paddle::Tensor& x); + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector> ReluInferShape(std::vector x_shape); + +std::vector ReluInferDType(paddle::DataType x_dtype); + +// Reuse codes in `relu_op_simple.cc/cu` to register another custom operator +// to test jointly compile multi operators at same time. +PD_BUILD_OPERATOR("relu3") + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)) + .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType)) + .SetBackwardOp("relu3_grad") + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)); diff --git a/python/paddle/fluid/tests/custom_op/setup_build.py b/python/paddle/fluid/tests/custom_op/setup_build.py index 5993ef1a124b7708c345b37b081a97b95d0c2495..408738170c0e265a3acc9968d1eb5347d49804ec 100644 --- a/python/paddle/fluid/tests/custom_op/setup_build.py +++ b/python/paddle/fluid/tests/custom_op/setup_build.py @@ -27,7 +27,8 @@ setup( ext_modules=[ CUDAExtension( name='librelu2_op_from_setup', - sources=['relu_op.cc', 'relu_op.cu'], + sources=['relu_op3.cc', 'relu_op3.cu', 'relu_op.cc', + 'relu_op.cu'], # test for multi ops include_dirs=paddle_includes, extra_compile_args=extra_compile_args) ], diff --git a/python/paddle/fluid/tests/custom_op/setup_install.py b/python/paddle/fluid/tests/custom_op/setup_install.py index 80477bfbea8bc53c5925732aef5f32f5d5a7985a..f8fadbeee54a22ba986b549473249f98065d5b4f 100644 --- a/python/paddle/fluid/tests/custom_op/setup_install.py +++ b/python/paddle/fluid/tests/custom_op/setup_install.py @@ -25,7 +25,8 @@ setup( ext_modules=[ CUDAExtension( name='custom_relu2', - sources=['relu_op.cc', 'relu_op.cu'], + sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', + 'relu_op3.cu'], # test for multi ops include_dirs=paddle_includes, extra_compile_args=extra_compile_args) ]) diff --git a/python/paddle/fluid/tests/custom_op/setup_install_simple.py b/python/paddle/fluid/tests/custom_op/setup_install_simple.py index f8eba6b3ad634e9861fbb2a351d071a1e6008670..2aebbc299a606ba2a7a318a5887e3a85f89c426a 100644 --- a/python/paddle/fluid/tests/custom_op/setup_install_simple.py +++ b/python/paddle/fluid/tests/custom_op/setup_install_simple.py @@ -22,7 +22,9 @@ setup( ext_modules=[ CUDAExtension( name='simple_setup_relu2', - sources=['relu_op_simple.cc', 'relu_op_simple.cu'], + sources=[ + 'relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc' + ], # test for multi ops include_dirs=paddle_includes, extra_compile_args=extra_compile_args) ]) diff --git a/python/paddle/fluid/tests/custom_op/test_jit_load.py b/python/paddle/fluid/tests/custom_op/test_jit_load.py index 084c91673890a8ac8d38db96f60348b349914f85..222c69f5edcc56747bf3819ac4bfd5b1915e3ded 100644 --- a/python/paddle/fluid/tests/custom_op/test_jit_load.py +++ b/python/paddle/fluid/tests/custom_op/test_jit_load.py @@ -24,9 +24,9 @@ from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_me use_new_custom_op_load_method(False) # Compile and load custom op Just-In-Time. 
diff --git a/python/paddle/fluid/tests/custom_op/test_jit_load.py b/python/paddle/fluid/tests/custom_op/test_jit_load.py
index 084c91673890a8ac8d38db96f60348b349914f85..222c69f5edcc56747bf3819ac4bfd5b1915e3ded 100644
--- a/python/paddle/fluid/tests/custom_op/test_jit_load.py
+++ b/python/paddle/fluid/tests/custom_op/test_jit_load.py
@@ -24,9 +24,9 @@ from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_me
 use_new_custom_op_load_method(False)
 
 # Compile and load custom op Just-In-Time.
-relu2 = load(
-    name='relu2',
-    sources=['relu_op.cc', 'relu_op.cu'],
+custom_module = load(
+    name='custom_relu2',
+    sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'],
     interpreter='python',  # add for unittest
     extra_include_paths=paddle_includes,  # add for Coverage CI
     extra_cflags=extra_compile_args,  # add for Coverage CI
@@ -37,12 +37,14 @@ relu2 = load(
 class TestJITLoad(unittest.TestCase):
     def test_api(self):
         raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
+        gt_data = np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')
         x = paddle.to_tensor(raw_data, dtype='float32')
         # use custom api
-        out = relu2(x)
-        self.assertTrue(
-            np.array_equal(out.numpy(),
-                           np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
+        out = custom_module.relu2(x)
+        out3 = custom_module.relu3(x)
+
+        self.assertTrue(np.array_equal(out.numpy(), gt_data))
+        self.assertTrue(np.array_equal(out3.numpy(), gt_data))
 
 
 if __name__ == '__main__':
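Note the API change this test exercises: `load` previously returned a single callable, and now returns a module-like object exposing one api per op found in `sources`. A sketch of the new call shape (argument values taken from the test above):

    from paddle.utils.cpp_extension import load

    custom_module = load(
        name='custom_relu2',
        sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'])
    out = custom_module.relu2(x)   # was: relu2 = load(...); relu2(x)
    out3 = custom_module.relu3(x)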
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_op_with_setup.py b/python/paddle/fluid/tests/custom_op/test_setup_build.py
similarity index 56%
rename from python/paddle/fluid/tests/custom_op/test_custom_op_with_setup.py
rename to python/paddle/fluid/tests/custom_op/test_setup_build.py
index d7bf687b2f1e2a49023ab3dc5a03f0b6d0281bd4..1ef14c2e3aaa3c9412660b7afcc14e22aa6402ea 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_op_with_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_setup_build.py
@@ -14,8 +14,11 @@
 
 import os
 import unittest
+import numpy as np
 from test_custom_op import CustomOpTest, load_so
+import paddle
 from paddle.utils.cpp_extension.extension_utils import run_cmd
+from paddle.fluid.layer_helper import LayerHelper
 from paddle.utils.cpp_extension.extension_utils import use_new_custom_op_load_method
 
 # switch to old custom op method
@@ -32,6 +35,34 @@ def compile_so():
     run_cmd(cmd)
 
 
+# `setup.py build` only produces a .so file containing multiple operators;
+# the Python interface must be added manually. The `relu2` api is defined
+# in `test_custom_op.py`.
+def relu3(x, name=None):
+    helper = LayerHelper("relu3", **locals())
+    out = helper.create_variable(
+        type=x.type, name=name, dtype=x.dtype, persistable=False)
+    helper.append_op(type="relu3", inputs={"X": x}, outputs={"Y": out})
+    return out
+
+
+class TestCompileMultiOp(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+
+    def test_relu3(self):
+        raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
+        x = paddle.to_tensor(raw_data, dtype='float32')
+        # use custom api
+        out = relu3(x)
+
+        self.assertTrue(
+            np.array_equal(out.numpy(),
+                           np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
+
+    def tearDown(self):
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     compile_so()
     load_so(so_name='librelu2_op_from_setup.so')
diff --git a/python/paddle/fluid/tests/custom_op/test_setup_install.py b/python/paddle/fluid/tests/custom_op/test_setup_install.py
index bc49b26c45caec5c1b5c59fd26a705dd61c35523..1fd7b8a06f952336274a78b253d096644fbdf08f 100644
--- a/python/paddle/fluid/tests/custom_op/test_setup_install.py
+++ b/python/paddle/fluid/tests/custom_op/test_setup_install.py
@@ -51,13 +51,14 @@ class TestSetUpInstall(unittest.TestCase):
         import custom_relu2
 
         raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
+        gt_data = np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')
         x = paddle.to_tensor(raw_data, dtype='float32')
         # use custom api
         out = custom_relu2.relu2(x)
+        out3 = custom_relu2.relu3(x)
 
-        self.assertTrue(
-            np.array_equal(out.numpy(),
-                           np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
+        self.assertTrue(np.array_equal(out.numpy(), gt_data))
+        self.assertTrue(np.array_equal(out3.numpy(), gt_data))
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py
index 43f2abd93f5a0893995c3063cde806c061f461e2..926ab4064a42ca48441d3992d247f4bf0639bc34 100644
--- a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py
@@ -21,16 +21,16 @@ from utils import paddle_includes, extra_compile_args
 from test_simple_custom_op_setup import relu2_dynamic, relu2_static
 
 # Compile and load custom op Just-In-Time.
-simple_relu2 = load(
+custom_module = load(
     name='simple_jit_relu2',
-    sources=['relu_op_simple.cc', 'relu_op_simple.cu'],
+    sources=['relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'],
     extra_include_paths=paddle_includes,  # add for Coverage CI
     extra_cflags=extra_compile_args)  # add for Coverage CI
 
 
 class TestJITLoad(unittest.TestCase):
     def setUp(self):
-        self.custom_op = simple_relu2
+        self.custom_ops = [custom_module.relu2, custom_module.relu3]
         self.dtypes = ['float32', 'float64']
         self.devices = ['cpu', 'gpu']
 
@@ -38,28 +38,30 @@ class TestJITLoad(unittest.TestCase):
     def test_static(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                out = relu2_static(self.custom_op, device, dtype, x)
-                pd_out = relu2_static(self.custom_op, device, dtype, x, False)
-                self.assertTrue(
-                    np.array_equal(out, pd_out),
-                    "custom op out: {},\n paddle api out: {}".format(out,
-                                                                     pd_out))
+                for custom_op in self.custom_ops:
+                    out = relu2_static(custom_op, device, dtype, x)
+                    pd_out = relu2_static(custom_op, device, dtype, x, False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
 
     def test_dynamic(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                out, x_grad = relu2_dynamic(self.custom_op, device, dtype, x)
-                pd_out, pd_x_grad = relu2_dynamic(self.custom_op, device, dtype,
-                                                  x, False)
-                self.assertTrue(
-                    np.array_equal(out, pd_out),
-                    "custom op out: {},\n paddle api out: {}".format(out,
-                                                                     pd_out))
-                self.assertTrue(
-                    np.array_equal(x_grad, pd_x_grad),
-                    "custom op x grad: {},\n paddle api x grad: {}".format(
-                        x_grad, pd_x_grad))
+                for custom_op in self.custom_ops:
+                    out, x_grad = relu2_dynamic(custom_op, device, dtype, x)
+                    pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype,
+                                                      x, False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
+                    self.assertTrue(
+                        np.array_equal(x_grad, pd_x_grad),
+                        "custom op x grad: {},\n paddle api x grad: {}".format(
+                            x_grad, pd_x_grad))
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py
index 7d9fb678c46230b363ca8f280315ab3e0121391e..dd69aef86ab99f80b74337878bc48501300e8c3c 100644
--- a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py
@@ -107,7 +107,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
         # usage: import the package directly
         import simple_setup_relu2
 
-        self.custom_op = simple_setup_relu2.relu2
+        self.custom_ops = [simple_setup_relu2.relu2, simple_setup_relu2.relu3]
 
         self.dtypes = ['float32', 'float64']
         self.devices = ['cpu', 'gpu']
@@ -116,40 +116,42 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
     def test_static(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                out = relu2_static(self.custom_op, device, dtype, x)
-                pd_out = relu2_static(self.custom_op, device, dtype, x, False)
-                self.assertTrue(
-                    np.array_equal(out, pd_out),
-                    "custom op out: {},\n paddle api out: {}".format(out,
-                                                                     pd_out))
+                for custom_op in self.custom_ops:
+                    out = relu2_static(custom_op, device, dtype, x)
+                    pd_out = relu2_static(custom_op, device, dtype, x, False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
 
     def test_static_pe(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                out = relu2_static_pe(self.custom_op, device, dtype, x)
-                pd_out = relu2_static_pe(self.custom_op, device, dtype, x,
-                                         False)
-                self.assertTrue(
-                    np.array_equal(out, pd_out),
-                    "custom op out: {},\n paddle api out: {}".format(out,
-                                                                     pd_out))
+                for custom_op in self.custom_ops:
+                    out = relu2_static_pe(custom_op, device, dtype, x)
+                    pd_out = relu2_static_pe(custom_op, device, dtype, x, False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
 
     def test_dynamic(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                out, x_grad = relu2_dynamic(self.custom_op, device, dtype, x)
-                pd_out, pd_x_grad = relu2_dynamic(self.custom_op, device, dtype,
-                                                  x, False)
-                self.assertTrue(
-                    np.array_equal(out, pd_out),
-                    "custom op out: {},\n paddle api out: {}".format(out,
-                                                                     pd_out))
-                self.assertTrue(
-                    np.array_equal(x_grad, pd_x_grad),
-                    "custom op x grad: {},\n paddle api x grad: {}".format(
-                        x_grad, pd_x_grad))
+                for custom_op in self.custom_ops:
+                    out, x_grad = relu2_dynamic(custom_op, device, dtype, x)
+                    pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype,
+                                                      x, False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
+                    self.assertTrue(
+                        np.array_equal(x_grad, pd_x_grad),
+                        "custom op x grad: {},\n paddle api x grad: {}".format(
+                            x_grad, pd_x_grad))
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/utils/cpp_extension/cpp_extension.py b/python/paddle/utils/cpp_extension/cpp_extension.py
index 6975b884e9c523a279844ad8e28426d7a6cdd90a..93be1ec8dbe0b690c92ad4efae06c3b15c798ab6 100644
--- a/python/paddle/utils/cpp_extension/cpp_extension.py
+++ b/python/paddle/utils/cpp_extension/cpp_extension.py
@@ -271,23 +271,21 @@ class BuildExtension(build_ext, object):
         """
         Record custum op inforomation.
""" - # parse op name - sources = [] - for extension in self.extensions: - sources.extend(extension.sources) - - sources = [os.path.abspath(s) for s in sources] - op_name = parse_op_name_from(sources) - # parse shared library abs path outputs = self.get_outputs() assert len(outputs) == 1 - - build_directory = os.path.abspath(outputs[0]) - so_name = os.path.basename(build_directory) - CustomOpInfo.instance().add(op_name, - so_name=so_name, - build_directory=build_directory) + # multi operators built into same one .so file + so_path = os.path.abspath(outputs[0]) + so_name = os.path.basename(so_path) + + for i, extension in enumerate(self.extensions): + sources = [os.path.abspath(s) for s in extension.sources] + op_names = parse_op_name_from(sources) + + for op_name in op_names: + CustomOpInfo.instance().add(op_name, + so_name=so_name, + build_directory=so_path) class EasyInstallCommand(easy_install, object): diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index a9558850680d4701491d240cb4a93840aab0f250..f4c83998626e69430196bf166e00a597d35a2e49 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -47,7 +47,7 @@ GCC_MINI_VERSION = (5, 4, 0) # Give warning if using wrong compiler WRONG_COMPILER_WARNING = ''' ************************************* - * Compiler Compatibility WARNING * + * Compiler Compatibility WARNING * ************************************* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -70,7 +70,7 @@ ABI_INCOMPATIBILITY_WARNING = ''' !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -Found that your compiler ({user_compiler} == {version}) may be ABI-incompatible with pre-insalled Paddle! +Found that your compiler ({user_compiler} == {version}) may be ABI-incompatible with pre-installed Paddle! Please use compiler that is ABI-compatible with GCC >= 5.4 (Recommended 8.2). 
diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py
index a9558850680d4701491d240cb4a93840aab0f250..f4c83998626e69430196bf166e00a597d35a2e49 100644
--- a/python/paddle/utils/cpp_extension/extension_utils.py
+++ b/python/paddle/utils/cpp_extension/extension_utils.py
@@ -47,7 +47,7 @@ GCC_MINI_VERSION = (5, 4, 0)
 # Give warning if using wrong compiler
 WRONG_COMPILER_WARNING = '''
                             *************************************
-                            * Compiler Compatibility WARNING *
+                            *  Compiler Compatibility WARNING  *
                             *************************************
 
 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -70,7 +70,7 @@
 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 
-Found that your compiler ({user_compiler} == {version}) may be ABI-incompatible with pre-insalled Paddle!
+Found that your compiler ({user_compiler} == {version}) may be ABI-incompatible with pre-installed Paddle!
 Please use compiler that is ABI-compatible with GCC >= 5.4 (Recommended 8.2).
 
 See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html for ABI Compatibility
@@ -125,12 +125,13 @@ def custom_write_stub(resource, pyfile):
     import types
     import paddle
 
-    def inject_ext_module(module_name, api_name):
+    def inject_ext_module(module_name, api_names):
         if module_name in sys.modules:
             return sys.modules[module_name]
 
         new_module = types.ModuleType(module_name)
-        setattr(new_module, api_name, eval(api_name))
+        for api_name in api_names:
+            setattr(new_module, api_name, eval(api_name))
 
         return new_module
 
@@ -141,9 +142,8 @@ def custom_write_stub(resource, pyfile):
     assert os.path.exists(so_path)
 
     # load custom op shared library with abs path
-    new_custom_op = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)
-    assert len(new_custom_op) == 1
-    m = inject_ext_module(__name__, new_custom_op[0])
+    new_custom_ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)
+    m = inject_ext_module(__name__, new_custom_ops)
 
     __bootstrap__()
 
@@ -154,21 +154,25 @@ def custom_write_stub(resource, pyfile):
     _, op_info = CustomOpInfo.instance().last()
     so_path = op_info.build_directory
 
-    new_custom_op = load_op_meta_info_and_register_op(so_path)
-    assert len(new_custom_op
-               ) == 1, "The number of loaded costom operators is %d" % len(
-                   new_custom_op)
+    new_custom_ops = load_op_meta_info_and_register_op(so_path)
+    assert len(
+        new_custom_ops
+    ) > 0, "Requires at least one custom operator, but received len(new_custom_ops) = %d" % len(
+        new_custom_ops)
 
     # NOTE: To avoid importing .so file instead of python file because they have same name,
     # we rename .so shared library to another name, see EasyInstallCommand.
     filename, ext = os.path.splitext(resource)
     resource = filename + "_pd_" + ext
 
+    api_content = []
+    for op_name in new_custom_ops:
+        api_content.append(_custom_api_content(op_name))
+
     with open(pyfile, 'w') as f:
         f.write(
             _stub_template.format(
-                resource=resource,
-                custom_api=_custom_api_content(new_custom_op[0])))
+                resource=resource, custom_api='\n\n'.join(api_content)))
 
 
 OpInfo = collections.namedtuple('OpInfo',
@@ -406,11 +410,12 @@ def parse_op_info(op_name):
     return in_names, out_infos
 
 
-def _import_module_from_library(name, build_directory, verbose=False):
+def _import_module_from_library(module_name, build_directory, verbose=False):
     """
     Load .so shared library and import it as callable python module.
     """
-    ext_path = os.path.join(build_directory, name + '.so')
+    # TODO(Aurelius84): Consider that the file suffix is .dll on the Windows platform.
+    ext_path = os.path.join(build_directory, module_name + '.so')
     if not os.path.exists(ext_path):
         raise FileNotFoundError("Extension path: {} does not exist.".format(
             ext_path))
@@ -418,27 +423,30 @@ def _import_module_from_library(name, build_directory, verbose=False):
     # load custom op_info and kernels from .so shared library
     log_v('loading shared library from: {}'.format(ext_path), verbose)
     op_names = load_op_meta_info_and_register_op(ext_path)
-    assert len(op_names) == 1
 
     # generate Python api in ext_path
-    return _generate_python_module(op_names[0], build_directory, verbose)
+    return _generate_python_module(module_name, op_names, build_directory,
+                                   verbose)
 
 
-def _generate_python_module(op_name, build_directory, verbose=False):
+def _generate_python_module(module_name,
+                            op_names,
+                            build_directory,
+                            verbose=False):
     """
     Automatically generate python file to allow import or load into as module
     """
-    api_file = os.path.join(build_directory, op_name + '.py')
+    api_file = os.path.join(build_directory, module_name + '.py')
     log_v("generate api file: {}".format(api_file), verbose)
 
     # write into .py file
-    api_content = _custom_api_content(op_name)
+    api_content = [_custom_api_content(op_name) for op_name in op_names]
     with open(api_file, 'w') as f:
-        f.write(api_content)
+        f.write('\n\n'.join(api_content))
 
     # load module
-    custom_api = _load_module_from_file(op_name, api_file, verbose)
-    return custom_api
+    custom_module = _load_module_from_file(api_file, verbose)
+    return custom_module
 
 
 def _custom_api_content(op_name):
@@ -475,7 +483,7 @@ def _custom_api_content(op_name):
     return api_content
 
 
-def _load_module_from_file(op_name, api_file_path, verbose=False):
+def _load_module_from_file(api_file_path, verbose=False):
     """
     Load module from python file.
     """
@@ -494,8 +502,7 @@ def _load_module_from_file(op_name, api_file_path, verbose=False):
     loader = machinery.SourceFileLoader(ext_name, api_file_path)
     module = loader.load_module()
 
-    assert hasattr(module, op_name)
-    return getattr(module, op_name)
+    return module
 
 
 def _get_api_inputs_str(op_name):
@@ -621,11 +628,7 @@ def parse_op_name_from(sources):
             content = f.read()
             op_names |= regex(content)
 
-    # TODO(Aurelius84): Support register more customs op at once
-    assert len(
-        op_names) == 1, "The number of registered costom operators is %d" % len(
-            op_names)
-    return list(op_names)[0]
+    return list(op_names)
 
 
 def run_cmd(command, verbose=False):
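After these changes, the generated api file (and the stub written by `custom_write_stub`) is simply the per-op stubs joined by blank lines, and `_load_module_from_file` returns the whole module instead of a single attribute. Roughly, for the multi-op library above, the generated file looks like this (a sketch; the real function bodies come from `_custom_api_content`, which is not shown in this patch):

    # <module_name>.py, written by _generate_python_module
    def relu2(x, name=None):
        # generated boilerplate that appends the registered 'relu2' op
        ...

    def relu3(x, name=None):
        # generated boilerplate that appends the registered 'relu3' op
        ...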