Unverified commit 5ecd0ad5, authored by HongyuJia, committed by GitHub

[Custom XPU Support] Custom extension supports XPU backend (#48733)

* support custom_xpu

* update cmake to test xpu

* support custom_xpu, verify mechanism

* fix test_custom_relu_op_xpu_setup.py, test=kunlun

* fix FLAGS_init_allocated_mem

* cancel TIMEOUT property

* reset FLAGS_init_allocated_mem property
Parent 5b6767ac
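For orientation, a minimal sketch of what this change enables: JIT-compiling the custom operator source added in this commit and running the op on an XPU device. The module name is hypothetical; the snippet assumes a Paddle build with PADDLE_WITH_XPU and an attached XPU device.

import paddle
from paddle.utils.cpp_extension import load

# Build the op from the C++ source added in this commit (CppExtension path).
custom_ops = load(name='custom_relu_xpu_jit', sources=['custom_relu_op_xpu.cc'])

paddle.set_device('xpu')  # requires an XPU build and device
x = paddle.randn([4, 8], dtype='float32')
out = custom_ops.custom_relu(x)  # dispatched to the XPU branch registered below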
@@ -284,7 +284,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
auto* true_out = true_out_ptrs.at(i);
auto calc_out =
std::dynamic_pointer_cast<phi::DenseTensor>(calc_outs->at(i).impl());
-    // assgin meta info
+    // assign meta info
auto* true_out_meta = phi::DenseTensorUtils::GetMutableMeta(true_out);
true_out_meta->dims = calc_out->dims();
true_out_meta->dtype = calc_out->dtype();
@@ -708,6 +708,10 @@ static void RegisterOperatorKernel(const std::string& name,
RegisterOperatorKernelWithPlace(
name, op_kernel_func, proto::VarType::RAW, platform::CUDAPlace());
#endif
+#if defined(PADDLE_WITH_XPU)
+  RegisterOperatorKernelWithPlace(
+      name, op_kernel_func, proto::VarType::RAW, platform::XPUPlace());
+#endif
}
void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
......
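With the block above, the RAW-typed kernel of a custom operator is also registered for platform::XPUPlace whenever Paddle is compiled with PADDLE_WITH_XPU, alongside the existing CPU and CUDA registrations. A sketch of guarding the XPU path from Python with the existing paddle.is_compiled_with_xpu() API:

import paddle

if paddle.is_compiled_with_xpu():
    paddle.set_device('xpu')  # the kernel registered for XPUPlace is reachable
else:
    paddle.set_device('cpu')  # fall back to the CPU registration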
@@ -21,6 +21,12 @@ if(WITH_GPU OR APPLE)
endif()
endif()
+if(WITH_XPU)
+  set(CUSTOM_XPU_ENVS FLAGS_init_allocated_mem=0)
+  py_test(test_custom_relu_op_xpu_setup SRCS test_custom_relu_op_xpu_setup.py
+          ENVS ${CUSTOM_XPU_ENVS})
+endif()
py_test(test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py)
set_tests_properties(test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180)
......
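The ENVS clause pins FLAGS_init_allocated_mem=0 for this test; per the commit message, this flag (which, when enabled, fills newly allocated memory with a debug pattern to expose uninitialized reads) had to be reset for the XPU test to pass. A minimal sketch for reproducing the same environment when running the test by hand:

import os

# Mirror the CMake ENVS setting; the flag is read during paddle's
# initialization, so it must be set before the import.
os.environ['FLAGS_init_allocated_mem'] = '0'

import paddle  # noqa: E402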
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
#define CHECK_XPU_INPUT(x) PD_CHECK(x.is_xpu(), #x " must be an XPU Tensor.")
template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel) {
PD_CHECK(x_data != nullptr, "x_data is nullptr.");
PD_CHECK(out_data != nullptr, "out_data is nullptr.");
for (int64_t i = 0; i < x_numel; ++i) {
out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
}
}
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
CHECK_CPU_INPUT(x);
auto out = paddle::empty_like(x);
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.data<data_t>(), x.numel());
}));
return {out};
}
std::vector<paddle::Tensor> relu_xpu_forward(const paddle::Tensor& x) {
CHECK_XPU_INPUT(x);
auto out = paddle::relu(x);
return {out};
}
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
if (x.is_cpu()) {
return relu_cpu_forward(x);
} else if (x.is_xpu()) {
return relu_xpu_forward(x);
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_relu)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward));
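ReluForward dispatches on the input tensor's place, so the single Python-visible op covers both backends. A usage sketch, assuming the module has been installed by the setup script below and an XPU device is present:

import paddle
from custom_relu_xpu_module_setup import custom_relu

paddle.set_device('cpu')
y_cpu = custom_relu(paddle.randn([4, 8]))  # takes the relu_cpu_forward branch

paddle.set_device('xpu')
y_xpu = custom_relu(paddle.randn([4, 8]))  # takes the relu_xpu_forward branch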
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils import extra_compile_args, paddle_includes
from paddle.utils.cpp_extension import CppExtension, setup
setup(
name='custom_relu_xpu_module_setup',
ext_modules=CppExtension(  # XPU doesn't support GPU, so build with CppExtension
sources=['custom_relu_op_xpu.cc'],
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args,
verbose=True,
),
)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import site
import sys
import unittest
import numpy as np
import paddle
import paddle.static as static
from paddle.fluid.framework import _test_eager_guard
from paddle.utils.cpp_extension.extension_utils import run_cmd
def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)
t = paddle.to_tensor(np_x, dtype=dtype)
out = func(t) if use_func else paddle.nn.functional.relu(t)
return out.numpy()
def custom_relu_static(
func, device, dtype, np_x, use_func=True, test_infer=False
):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
out = func(x) if use_func else paddle.nn.functional.relu(x)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static mode, x's data may be overwritten by out
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)
paddle.disable_static()
return out_v
class TestNewCustomOpSetUpInstall(unittest.TestCase):
def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
# compile and install the custom op egg into site-packages in the background
# Currently the custom XPU op does not support Windows
if os.name == 'nt':
return
cmd = 'cd {} && {} custom_relu_xpu_setup.py install'.format(
cur_dir, sys.executable
)
run_cmd(cmd)
site_dir = site.getsitepackages()[0]
custom_egg_path = [
x
for x in os.listdir(site_dir)
if 'custom_relu_xpu_module_setup' in x
]
assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
custom_egg_path
)
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
# usage: import the package directly
import custom_relu_xpu_module_setup
self.custom_op = custom_relu_xpu_module_setup.custom_relu
self.dtypes = ['float32', 'float64']
self.devices = ['xpu']
# config seed
SEED = 2021
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
def test_static(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = custom_relu_static(self.custom_op, device, dtype, x)
pd_out = custom_relu_static(
self.custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
def func_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = custom_relu_dynamic(self.custom_op, device, dtype, x)
pd_out = custom_relu_dynamic(
self.custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)
def test_dynamic(self):
with _test_eager_guard():
self.func_dynamic()
self.func_dynamic()
if __name__ == '__main__':
unittest.main()
@@ -2,7 +2,7 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
-# You may obtaina copy of the License at
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
......