From 5c3873f632041a879950ffc9363989d53c34e0e0 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 8 Feb 2022 00:05:34 +0800 Subject: [PATCH] Add __PD_DEFINE_RAW_OP_KERNEL_FUNC for registering custom op kernel with ExecutionContext (#39352) * hack custom op * add ut * skip windows ci --- paddle/fluid/framework/custom_operator.cc | 84 +++++++++------- paddle/fluid/framework/custom_operator.h | 5 +- .../framework/custom_raw_op_kernel_func.h | 27 ++++++ paddle/fluid/pybind/pybind.cc | 9 ++ .../fluid/tests/custom_op/CMakeLists.txt | 3 + .../custom_op/custom_raw_op_kernel_op.cc | 52 ++++++++++ .../custom_op/custom_raw_op_kernel_op.cu | 21 ++++ .../tests/custom_op/custom_raw_op_kernel_op.h | 84 ++++++++++++++++ .../custom_raw_op_kernel_op_setup.py | 50 ++++++++++ .../custom_op/test_custom_raw_op_kernel_op.py | 97 +++++++++++++++++++ 10 files changed, 395 insertions(+), 37 deletions(-) create mode 100644 paddle/fluid/framework/custom_raw_op_kernel_func.h create mode 100644 python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc create mode 100644 python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu create mode 100644 python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h create mode 100644 python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py create mode 100644 python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 68445e7976..31243bad30 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -61,27 +61,27 @@ static T* DynLoad(void* handle, std::string name) { return func; } -inline bool IsGradVar(const std::string& var_name) { +inline static bool IsGradVar(const std::string& var_name) { std::string suffix = kGradVarSuffix; return var_name.rfind(suffix) != std::string::npos; } -inline bool IsDuplicableVar(const std::string& var_name) { +inline static bool IsDuplicableVar(const std::string& var_name) { std::string suffix = kTensorVectorSuffix; return var_name.rfind(suffix) != std::string::npos; } -inline std::string NoGrad(const std::string& var_name) { +inline static std::string NoGrad(const std::string& var_name) { std::string suffix = kGradVarSuffix; return var_name.substr(0, var_name.size() - kGradVarSuffixSize); } -inline bool IsMemberOf(const std::vector& vec, - const std::string& name) { +inline static bool IsMemberOf(const std::vector& vec, + const std::string& name) { return std::find(vec.cbegin(), vec.cend(), name) != vec.cend(); } -std::vector ParseAttrStr(const std::string& attr) { +static std::vector ParseAttrStr(const std::string& attr) { auto split_pos = attr.find_first_of(":"); PADDLE_ENFORCE_NE(split_pos, std::string::npos, platform::errors::InvalidArgument( @@ -602,44 +602,57 @@ class CustomGradOpMaker //////////// Operator and Kernel Register ////////////// -void RegisterOperatorKernelWithPlace(const std::string& name, - const paddle::KernelFunc& kernel_func, - const proto::VarType::Type type, - const PlaceType& place, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& attrs) { +static void RegisterOperatorKernelWithPlace( + const std::string& name, + const OperatorWithKernel::OpKernelFunc& op_kernel_func, + const proto::VarType::Type type, const PlaceType& place) { OpKernelType key(type, experimental::ConvertExtPlaceToInnerPlace(place)); VLOG(3) << "Custom Operator: op kernel key: " << key; - OperatorWithKernel::AllOpKernels()[name][key] = - [kernel_func, inputs, outputs, - attrs](const framework::ExecutionContext& ctx) { - VLOG(3) << "Custom Operator: run custom kernel func in lambda."; - RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs); - }; + OperatorWithKernel::AllOpKernels()[name][key] = op_kernel_func; } -void RegisterOperatorKernel(const std::string& name, - const paddle::KernelFunc& kernel_func, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& attrs) { +static void RegisterOperatorKernel(const std::string& name, + const paddle::KernelFunc& kernel_func, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& attrs, + void* dso_handle) { VLOG(3) << "Custom Operator: op name in kernel: " << name; // NOTE [ Dummy Op Kernel Key ] // TODO(chenweihang): Because execute engine need get device context based // op_kernel_key.place_, so we should register kernel for each // device. But this is not entirely correct, if user only give a cpu kernel, // but call api in gpu device, it will cause error. - RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, - PlaceType::kCPU, inputs, outputs, attrs); + OperatorWithKernel::OpKernelFunc op_kernel_func; + if (kernel_func) { + VLOG(3) << "Register custom operator " << name << " with kernel func"; + op_kernel_func = [kernel_func, inputs, outputs, + attrs](const framework::ExecutionContext& ctx) { + VLOG(3) << "Custom Operator: run custom kernel func in lambda."; + RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs); + }; + } else { + VLOG(3) << "Register custom operator " << name + << " with raw op kernel func"; + PADDLE_ENFORCE_NOT_NULL( + dso_handle, + platform::errors::InvalidArgument( + "The dso handle must be provided if kernel_func is nullptr.")); + using OpKernelFuncPtr = void(const framework::ExecutionContext&); + auto symbol_name = "PD_" + name + "_raw_op_kernel_func"; + auto* func = detail::DynLoad(dso_handle, symbol_name); + op_kernel_func = func; + } + RegisterOperatorKernelWithPlace(name, op_kernel_func, proto::VarType::RAW, + PlaceType::kCPU); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, - PlaceType::kGPU, inputs, outputs, attrs); + RegisterOperatorKernelWithPlace(name, op_kernel_func, proto::VarType::RAW, + PlaceType::kGPU); #endif } -void RegisterOperatorWithMetaInfo( - const std::vector& op_meta_infos) { +void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, + void* dso_handle) { /* Op register */ OpInfo info; @@ -792,7 +805,8 @@ void RegisterOperatorWithMetaInfo( } // Kernel func - RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs); + RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs, + dso_handle); // If grad op or double grad op exists std::string cur_op_name = op_name; @@ -900,7 +914,7 @@ void RegisterOperatorWithMetaInfo( // Kernel func RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs, - grad_op_outputs, grad_op_attrs); + grad_op_outputs, grad_op_attrs, dso_handle); // update current info OpInfoMap::Instance().Insert(cur_op_name, info); @@ -912,14 +926,14 @@ void RegisterOperatorWithMetaInfo( } void RegisterOperatorWithMetaInfoMap( - const paddle::OpMetaInfoMap& op_meta_info_map) { + const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle) { auto& meta_info_map = op_meta_info_map.GetMap(); VLOG(3) << "Custom Operator: size of op meta info map - " << meta_info_map.size(); // pair: {op_type, OpMetaInfo} for (auto& pair : meta_info_map) { VLOG(3) << "Custom Operator: pair first -> op name: " << pair.first; - RegisterOperatorWithMetaInfo(pair.second); + RegisterOperatorWithMetaInfo(pair.second, dso_handle); } } @@ -934,7 +948,7 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { detail::DynLoad(handle, "PD_GetOpMetaInfoMap"); auto& op_meta_info_map = get_op_meta_info_map(); - RegisterOperatorWithMetaInfoMap(op_meta_info_map); + RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle); } } // namespace framework diff --git a/paddle/fluid/framework/custom_operator.h b/paddle/fluid/framework/custom_operator.h index d8712e60d2..576237dfbe 100644 --- a/paddle/fluid/framework/custom_operator.h +++ b/paddle/fluid/framework/custom_operator.h @@ -26,10 +26,11 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name); // Register custom op api: register op directly void RegisterOperatorWithMetaInfoMap( - const paddle::OpMetaInfoMap& op_meta_info_map); + const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle = nullptr); // Interface for selective register custom op. -void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos); +void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, + void* dso_handle = nullptr); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/custom_raw_op_kernel_func.h b/paddle/fluid/framework/custom_raw_op_kernel_func.h new file mode 100644 index 0000000000..3087f5375f --- /dev/null +++ b/paddle/fluid/framework/custom_raw_op_kernel_func.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/operator.h" +#include "paddle/pten/api/ext/op_meta_info.h" + +// NOTE(zengjinle): this macro is only for internal usage. Commonly, users +// should not use this macro. +#define __PD_DEFINE_RAW_OP_KERNEL_FUNC(op_name, ctx) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_raw_op_kernel_func__##op_name, \ + "__PD_DEFINE_RAW_KERNEL_FUNC must be called in global namespace."); \ + extern "C" void PD_##op_name##_raw_op_kernel_func( \ + const ::paddle::framework::ExecutionContext& ctx) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a5c4bb1a80..e31935848a 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -185,6 +185,14 @@ bool IsCompiledWithCUDA() { #endif } +bool IsCompiledWithNCCL() { +#ifdef PADDLE_WITH_NCCL + return true; +#else + return false; +#endif +} + bool IsCompiledWithROCM() { #ifndef PADDLE_WITH_HIP return false; @@ -2433,6 +2441,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compiled_with_ipu", IsCompiledWithIPU); m.def("is_compiled_with_xpu", IsCompiledWithXPU); m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); + m.def("is_compiled_with_nccl", IsCompiledWithNCCL); m.def("is_compiled_with_cinn", IsCompiledWithCINN); m.def("is_compiled_with_mlu", IsCompiledWithMLU); m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS); diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index 68b7904135..42aed28074 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -10,6 +10,9 @@ if(WITH_GPU OR APPLE) set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180) endif() +py_test(test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py) +set_tests_properties(test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180) + # CPU custom op tests: only compile .cc file py_test(test_dispatch_jit SRCS test_dispatch_jit.py) py_test(test_multi_out_jit SRCS test_multi_out_jit.py) diff --git a/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc b/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc new file mode 100644 index 0000000000..c9a3f7a907 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "custom_raw_op_kernel_op.h" // NOLINT +#include "paddle/fluid/framework/custom_raw_op_kernel_func.h" +#include "paddle/fluid/platform/enforce.h" + +void ReluCPUForward(const paddle::framework::Tensor &x, + paddle::framework::Tensor *y) { + custom_raw_op::ReluForward(x, y); +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +void ReluGPUForward(const paddle::framework::Tensor &x, + paddle::framework::Tensor *y); +#else +void ReluGPUForward(const paddle::framework::Tensor &x, + paddle::framework::Tensor *y) { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "ReluGPUForward is not supported when not compiled with GPU.")); +} +#endif + +__PD_DEFINE_RAW_OP_KERNEL_FUNC(custom_raw_relu, ctx) { + namespace f = paddle::framework; + const auto *x = ctx.Input("X"); + auto *y = ctx.Output("Y"); + PADDLE_ENFORCE_NOT_NULL(x, + paddle::platform::errors::InvalidArgument( + "Input(X) should not be nullptr.")); + PADDLE_ENFORCE_NOT_NULL(y, + paddle::platform::errors::InvalidArgument( + "Input(X) should not be nullptr.")); + if (paddle::platform::is_gpu_place(x->place())) { + ReluGPUForward(*x, y); + } else { + ReluCPUForward(*x, y); + } +} + +PD_BUILD_OP(custom_raw_relu).Inputs({"X"}).Outputs({"Y"}); diff --git a/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu b/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu new file mode 100644 index 0000000000..72cab225d1 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "custom_raw_op_kernel_op.h" // NOLINT + +void ReluGPUForward(const paddle::framework::Tensor &x, + paddle::framework::Tensor *y) { + custom_raw_op::ReluForward(x, y); +} diff --git a/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h b/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h new file mode 100644 index 0000000000..f919340303 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h @@ -0,0 +1,84 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" + +namespace custom_raw_op { + +struct ReluFunctor { + explicit ReluFunctor(const paddle::framework::Tensor &x, + paddle::framework::Tensor *y) + : x_(x), y_(y) {} + + template + struct Impl { + Impl(const U *x, U *y) : x_(x), y_(y) {} + + HOSTDEVICE void operator()(size_t i) const { + y_[i] = (x_[i] > static_cast(0) ? x_[i] : static_cast(0)); + } + + private: + const U *x_; + U *y_; + }; + + template + void apply() { + auto n = x_.numel(); + auto place = x_.place(); + const auto *x_data = x_.data(); + + y_->Resize(x_.dims()); + auto *y_data = y_->mutable_data(place); + + const auto &dev_ctx = + *paddle::platform::DeviceContextPool::Instance().Get(place); + +#define LAUNCH_RELU_KERNEL(DevCtxT) \ + do { \ + auto &__dev_ctx = dynamic_cast(dev_ctx); \ + paddle::platform::ForRange for_range(__dev_ctx, n); \ + Impl functor(x_data, y_data); \ + for_range(functor); \ + } while (0) + +#if defined(__NVCC__) || defined(__HIPCC__) + if (paddle::platform::is_gpu_place(place)) { + LAUNCH_RELU_KERNEL(paddle::platform::CUDADeviceContext); + return; + } +#endif + LAUNCH_RELU_KERNEL(paddle::platform::CPUDeviceContext); + +#undef LAUNCH_RELU_KERNEL + } + + private: + const paddle::framework::Tensor &x_; + paddle::framework::Tensor *y_; +}; + +inline void ReluForward(const paddle::framework::Tensor &x, + paddle::framework::Tensor *y) { + custom_raw_op::ReluFunctor functor(x, y); + paddle::framework::VisitDataType(x.type(), functor); +} + +} // namespace custom_raw_op diff --git a/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py b/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py new file mode 100644 index 0000000000..8889a56ad2 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import paddle +import paddle.fluid.core as core +from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup +from utils import paddle_includes, extra_compile_args + +if paddle.is_compiled_with_cuda(): + sources = ['custom_raw_op_kernel_op.cc', 'custom_raw_op_kernel_op.cu'] + extension = CUDAExtension +else: + sources = ['custom_raw_op_kernel_op.cc'] + extension = CppExtension + +cwd = os.path.dirname(os.path.abspath(__file__)) +os.chdir(cwd) + +if os.name == 'nt': + compile_dir = os.path.join(os.environ['work_dir'], os.environ['BUILD_DIR']) +else: + compile_dir = os.path.join(os.environ['PADDLE_ROOT'], 'build') + +macros = [] +if core.is_compiled_with_mkldnn(): + macros.append(("PADDLE_WITH_MKLDNN", None)) +if core.is_compiled_with_nccl(): + macros.append(("PADDLE_WITH_NCCL", None)) + +include_dirs = list(paddle_includes) + [cwd] +setup( + name=os.getenv("MODULE_NAME", "custom_raw_op_kernel_op_setup"), + ext_modules=extension( + sources=sources, + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, + _compile_dir=compile_dir, + define_macros=macros)) diff --git a/python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py b/python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py new file mode 100644 index 0000000000..207ea87974 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py @@ -0,0 +1,97 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import paddle +import shlex +import site +import sys +import importlib +import unittest +import numpy as np + +MODULE_NAME = "custom_raw_op_kernel_op_lib" + + +def prepare_module_path(): + # NOTE(Aurelius84): Normally, it's no need to add following codes for users. + # But we simulate to pip install in current process, so interpreter don't snap + # sys.path has been updated. So we update it manually. + + # See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3 + if os.name == 'nt': + # NOTE(zhouwei25): getsitepackages on windows will return a list: [python install dir, site packages dir] + site_dir = site.getsitepackages()[1] + else: + site_dir = site.getsitepackages()[0] + custom_egg_path = [x for x in os.listdir(site_dir) if MODULE_NAME in x] + assert len(custom_egg_path) == 1, "Matched egg number is %d." % len( + custom_egg_path) + sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + +# FIXME(zengjinle): do not know how to get the _compile_dir argument +# on Windows CI when compiling the custom op. Skip it on Windows CI +# temporarily. +@unittest.skipIf(os.name == "nt", "Windows does not support yet.") +class TestCustomRawReluOp(unittest.TestCase): + @classmethod + def setUpClass(cls): + path = os.path.dirname(os.path.abspath(__file__)) + path = os.path.join(path, "custom_raw_op_kernel_op_setup.py") + cmd = [sys.executable, path, "install", "--force"] + cmd = " ".join([shlex.quote(c) for c in cmd]) + os.environ['MODULE_NAME'] = MODULE_NAME + assert os.system(cmd) == 0 + prepare_module_path() + + @classmethod + def tearDownClass(cls): + cmd = [sys.executable, "-m", "pip", "uninstall", "-y", MODULE_NAME] + cmd = " ".join([shlex.quote(c) for c in cmd]) + assert os.system(cmd) == 0 + + def custom_raw_relu(self, x): + module = importlib.import_module(MODULE_NAME) + custom_raw_relu_op = getattr(module, "custom_raw_relu") + self.assertTrue(custom_raw_relu_op is not None) + return custom_raw_relu_op(x) + + def test_dygraph(self): + x = paddle.to_tensor(np.random.uniform(low=-1.0, high=1.0, size=[2, 3])) + y1 = self.custom_raw_relu(x) + y2 = paddle.nn.ReLU()(x) + self.assertTrue(np.array_equal(y1.numpy(), y2.numpy())) + + def test_static(self): + paddle.enable_static() + shape = [2, 3] + x = paddle.static.data(name="x", dtype='float32', shape=shape) + y1 = self.custom_raw_relu(x) + y2 = paddle.nn.ReLU()(x) + + exe = paddle.static.Executor() + exe.run(paddle.static.default_startup_program()) + x_np = np.random.uniform( + low=-1.0, high=1.0, size=[2, 3]).astype('float32') + y1_value, y2_value = exe.run(paddle.static.default_main_program(), + feed={x.name: x_np}, + fetch_list=[y1, y2]) + self.assertTrue(np.array_equal(y1_value, y2_value)) + + paddle.disable_static() + + +if __name__ == "__main__": + unittest.main() -- GitLab