From 5ecd0ad52ab22f0f6c461c70d9bb714278e2e6f5 Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Fri, 9 Dec 2022 17:56:36 +0800
Subject: [PATCH] [Custom XPU Support] Custom extension support xpu backend
 (#48733)

* support custom_xpu

* update cmake to test xpu

* support custom_xpu, verify mechanism

* fix test_custom_relu_op_xpu_setup.py, test=kunlun

* fix FLAGS_init_allocated_mem

* cancel TIMEOUT property

* reset FLAGS_init_allocated_mem property
---
 paddle/fluid/framework/custom_operator.cc     |   6 +-
 .../fluid/tests/custom_op/CMakeLists.txt      |   6 +
 .../tests/custom_op/custom_relu_op_xpu.cc     |  66 +++++++++
 .../tests/custom_op/custom_relu_xpu_setup.py  |  27 ++++
 .../test_custom_relu_op_xpu_setup.py          | 136 ++++++++++++++++++
 .../custom_op/test_custom_simple_slice.py     |   2 +-
 6 files changed, 241 insertions(+), 2 deletions(-)
 create mode 100644 python/paddle/fluid/tests/custom_op/custom_relu_op_xpu.cc
 create mode 100644 python/paddle/fluid/tests/custom_op/custom_relu_xpu_setup.py
 create mode 100644 python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py

diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index 1ca2f4e56dd..87201c93c75 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -284,7 +284,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
       auto* true_out = true_out_ptrs.at(i);
       auto calc_out =
           std::dynamic_pointer_cast<phi::DenseTensor>(calc_outs->at(i).impl());
-      // assgin meta info
+      // assign meta info
       auto* true_out_meta = phi::DenseTensorUtils::GetMutableMeta(true_out);
       true_out_meta->dims = calc_out->dims();
       true_out_meta->dtype = calc_out->dtype();
@@ -708,6 +708,10 @@ static void RegisterOperatorKernel(const std::string& name,
   RegisterOperatorKernelWithPlace(
       name, op_kernel_func, proto::VarType::RAW, platform::CUDAPlace());
 #endif
+#if defined(PADDLE_WITH_XPU)
+  RegisterOperatorKernelWithPlace(
+      name, op_kernel_func, proto::VarType::RAW, platform::XPUPlace());
+#endif
 }
 
 void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt
index 2addead40fc..1b2e8b6f868 100644
--- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt
+++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt
@@ -21,6 +21,12 @@ if(WITH_GPU OR APPLE)
   endif()
 endif()
 
+if(WITH_XPU)
+  set(CUSTOM_XPU_ENVS FLAGS_init_allocated_mem=0)
+  py_test(test_custom_relu_op_xpu_setup SRCS test_custom_relu_op_xpu_setup.py
+          ENVS ${CUSTOM_XPU_ENVS})
+endif()
+
 py_test(test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py)
 set_tests_properties(test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180)
 
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op_xpu.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op_xpu.cc
new file mode 100644
index 00000000000..8d9e2e2af49
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op_xpu.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <vector>
+
+#include "paddle/extension.h"
+
+#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
+#define CHECK_XPU_INPUT(x) PD_CHECK(x.is_xpu(), #x " must be an XPU Tensor.")
+
+template <typename data_t>
+void relu_cpu_forward_kernel(const data_t* x_data,
+                             data_t* out_data,
+                             int64_t x_numel) {
+  PD_CHECK(x_data != nullptr, "x_data is nullptr.");
+  PD_CHECK(out_data != nullptr, "out_data is nullptr.");
+  for (int64_t i = 0; i < x_numel; ++i) {
+    out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
+  }
+}
+
+std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
+  CHECK_CPU_INPUT(x);
+  auto out = paddle::empty_like(x);
+
+  PD_DISPATCH_FLOATING_TYPES(
+      x.type(), "relu_cpu_forward", ([&] {
+        relu_cpu_forward_kernel<data_t>(
+            x.data<data_t>(), out.data<data_t>(), x.numel());
+      }));
+
+  return {out};
+}
+
+std::vector<paddle::Tensor> relu_xpu_forward(const paddle::Tensor& x) {
+  CHECK_XPU_INPUT(x);
+  auto out = paddle::relu(x);
+  return {out};
+}
+
+std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
+  if (x.is_cpu()) {
+    return relu_cpu_forward(x);
+  } else if (x.is_xpu()) {
+    return relu_xpu_forward(x);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+PD_BUILD_OP(custom_relu)
+    .Inputs({"X"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(ReluForward));
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_xpu_setup.py b/python/paddle/fluid/tests/custom_op/custom_relu_xpu_setup.py
new file mode 100644
index 00000000000..b592cebbadf
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_xpu_setup.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from utils import extra_compile_args, paddle_includes
+
+from paddle.utils.cpp_extension import CppExtension, setup
+
+setup(
+    name='custom_relu_xpu_module_setup',
+    ext_modules=CppExtension(  # CppExtension: the XPU kernel needs no CUDA compilation
+        sources=['custom_relu_op_xpu.cc'],
+        include_dirs=paddle_includes,
+        extra_compile_args=extra_compile_args,
+        verbose=True,
+    ),
+)
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
new file mode 100644
index 00000000000..d1500863588
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import site
+import sys
+import unittest
+
+import numpy as np
+
+import paddle
+import paddle.static as static
+from paddle.fluid.framework import _test_eager_guard
+from paddle.utils.cpp_extension.extension_utils import run_cmd
+
+
+def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
+    paddle.set_device(device)
+
+    t = paddle.to_tensor(np_x, dtype=dtype)
+    out = func(t) if use_func else paddle.nn.functional.relu(t)
+
+    return out.numpy()
+
+
+def custom_relu_static(
+    func, device, dtype, np_x, use_func=True, test_infer=False
+):
+    paddle.enable_static()
+    paddle.set_device(device)
+
+    with static.scope_guard(static.Scope()):
+        with static.program_guard(static.Program()):
+            x = static.data(name='X', shape=[None, 8], dtype=dtype)
+            out = func(x) if use_func else paddle.nn.functional.relu(x)
+
+            exe = static.Executor()
+            exe.run(static.default_startup_program())
+            # in static mode, x data has been overwritten by out
+            out_v = exe.run(
+                static.default_main_program(),
+                feed={'X': np_x},
+                fetch_list=[out.name],
+            )
+
+    paddle.disable_static()
+    return out_v
+
+
+class TestNewCustomOpSetUpInstall(unittest.TestCase):
+    def setUp(self):
+        cur_dir = os.path.dirname(os.path.abspath(__file__))
+        # compile and install the custom op egg into site-packages in the background
+        # custom XPU ops are currently not supported on Windows
+        if os.name == 'nt':
+            return
+        cmd = 'cd {} && {} custom_relu_xpu_setup.py install'.format(
+            cur_dir, sys.executable
+        )
+        run_cmd(cmd)
+
+        site_dir = site.getsitepackages()[0]
+        custom_egg_path = [
+            x
+            for x in os.listdir(site_dir)
+            if 'custom_relu_xpu_module_setup' in x
+        ]
+        assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
+            custom_egg_path
+        )
+        sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
+
+        # usage: import the package directly
+        import custom_relu_xpu_module_setup
+
+        self.custom_op = custom_relu_xpu_module_setup.custom_relu
+
+        self.dtypes = ['float32', 'float64']
+        self.devices = ['xpu']
+
+        # config seed
+        SEED = 2021
+        paddle.seed(SEED)
+        paddle.framework.random._manual_program_seed(SEED)
+
+    def test_static(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                out = custom_relu_static(self.custom_op, device, dtype, x)
+                pd_out = custom_relu_static(
+                    self.custom_op, device, dtype, x, False
+                )
+                np.testing.assert_array_equal(
+                    out,
+                    pd_out,
+                    err_msg='custom op out: {},\n paddle api out: {}'.format(
+                        out, pd_out
+                    ),
+                )
+
+    def func_dynamic(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                out = custom_relu_dynamic(self.custom_op, device, dtype, x)
+                pd_out = custom_relu_dynamic(
+                    self.custom_op, device, dtype, x, False
+                )
+                np.testing.assert_array_equal(
+                    out,
+                    pd_out,
+                    err_msg='custom op out: {},\n paddle api out: {}'.format(
+                        out, pd_out
+                    ),
+                )
+
+    def test_dynamic(self):
+        with _test_eager_guard():
+            self.func_dynamic()
+        self.func_dynamic()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_simple_slice.py b/python/paddle/fluid/tests/custom_op/test_custom_simple_slice.py
index 1c54568e69c..4113e1c650d 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_simple_slice.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_simple_slice.py
@@ -2,7 +2,7 @@
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
-# You may obtaina copy of the License at
+# You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
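
A minimal usage sketch of the new XPU custom-op path (an illustration, not part of the patch): it assumes a Paddle build with WITH_XPU=ON and an available XPU device, and it uses the JIT `load` API instead of the `setup.py install` flow exercised by test_custom_relu_op_xpu_setup.py; the module name `custom_relu_xpu_jit` is hypothetical.

    # Sketch: JIT-compile custom_relu_op_xpu.cc and check it against the
    # built-in relu on an XPU device. Assumes Paddle was built with WITH_XPU=ON.
    import numpy as np

    import paddle
    from paddle.utils.cpp_extension import load

    # Compile the C++ source added by this patch into an importable module.
    custom_module = load(
        name='custom_relu_xpu_jit',  # hypothetical module name
        sources=['custom_relu_op_xpu.cc'],
        verbose=True,
    )

    paddle.set_device('xpu')  # dispatch goes through the new XPUPlace kernel registration
    x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))
    out = custom_module.custom_relu(x)

    # The custom op should produce the same result as the built-in relu.
    np.testing.assert_array_equal(
        out.numpy(), paddle.nn.functional.relu(x).numpy()
    )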