Unverified commit 5c3873f6, authored by sneaxiy, committed by GitHub

Add __PD_DEFINE_RAW_OP_KERNEL_FUNC for registering custom op kernel with ExecutionContext (#39352)

* hack custom op

* add ut

* skip windows ci
Parent: fee4316d
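
In short, this change lets a custom operator ship a kernel that receives the framework-level ExecutionContext directly, instead of the unpacked paddle::Tensor arguments of an ordinary custom op. A minimal sketch of the intended usage, modeled on the test op added later in this commit (the op name my_copy, the TensorCopy body, and the exact include set are illustrative assumptions, not part of the change):

#include "paddle/fluid/framework/custom_raw_op_kernel_func.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/pten/api/ext/op_meta_info.h"

// Defines extern "C" void PD_my_copy_raw_op_kernel_func(ctx); the framework
// resolves this symbol from the DSO because no ordinary SetKernelFn(...) is
// attached to the op below.
__PD_DEFINE_RAW_OP_KERNEL_FUNC(my_copy, ctx) {
  namespace f = paddle::framework;
  const auto *x = ctx.Input<f::Tensor>("X");
  auto *y = ctx.Output<f::Tensor>("Y");
  f::TensorCopy(*x, x->place(), y);  // identity op, for illustration only
}

// Register the op itself; with no kernel func set, registration falls back
// to the raw symbol defined above.
PD_BUILD_OP(my_copy).Inputs({"X"}).Outputs({"Y"});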
@@ -61,27 +61,27 @@ static T* DynLoad(void* handle, std::string name) {
   return func;
 }
 
-inline bool IsGradVar(const std::string& var_name) {
+inline static bool IsGradVar(const std::string& var_name) {
   std::string suffix = kGradVarSuffix;
   return var_name.rfind(suffix) != std::string::npos;
 }
 
-inline bool IsDuplicableVar(const std::string& var_name) {
+inline static bool IsDuplicableVar(const std::string& var_name) {
   std::string suffix = kTensorVectorSuffix;
   return var_name.rfind(suffix) != std::string::npos;
 }
 
-inline std::string NoGrad(const std::string& var_name) {
+inline static std::string NoGrad(const std::string& var_name) {
   std::string suffix = kGradVarSuffix;
   return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
 }
 
-inline bool IsMemberOf(const std::vector<std::string>& vec,
-                       const std::string& name) {
+inline static bool IsMemberOf(const std::vector<std::string>& vec,
+                              const std::string& name) {
   return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
 }
 
-std::vector<std::string> ParseAttrStr(const std::string& attr) {
+static std::vector<std::string> ParseAttrStr(const std::string& attr) {
   auto split_pos = attr.find_first_of(":");
   PADDLE_ENFORCE_NE(split_pos, std::string::npos,
                     platform::errors::InvalidArgument(
@@ -602,44 +602,57 @@ class CustomGradOpMaker<imperative::OpBase>
 
 //////////// Operator and Kernel Register //////////////
 
-void RegisterOperatorKernelWithPlace(const std::string& name,
-                                     const paddle::KernelFunc& kernel_func,
-                                     const proto::VarType::Type type,
-                                     const PlaceType& place,
-                                     const std::vector<std::string>& inputs,
-                                     const std::vector<std::string>& outputs,
-                                     const std::vector<std::string>& attrs) {
+static void RegisterOperatorKernelWithPlace(
+    const std::string& name,
+    const OperatorWithKernel::OpKernelFunc& op_kernel_func,
+    const proto::VarType::Type type, const PlaceType& place) {
   OpKernelType key(type, experimental::ConvertExtPlaceToInnerPlace(place));
   VLOG(3) << "Custom Operator: op kernel key: " << key;
-  OperatorWithKernel::AllOpKernels()[name][key] =
-      [kernel_func, inputs, outputs,
-       attrs](const framework::ExecutionContext& ctx) {
-        VLOG(3) << "Custom Operator: run custom kernel func in lambda.";
-        RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
-      };
+  OperatorWithKernel::AllOpKernels()[name][key] = op_kernel_func;
 }
 
-void RegisterOperatorKernel(const std::string& name,
-                            const paddle::KernelFunc& kernel_func,
-                            const std::vector<std::string>& inputs,
-                            const std::vector<std::string>& outputs,
-                            const std::vector<std::string>& attrs) {
+static void RegisterOperatorKernel(const std::string& name,
+                                   const paddle::KernelFunc& kernel_func,
+                                   const std::vector<std::string>& inputs,
+                                   const std::vector<std::string>& outputs,
+                                   const std::vector<std::string>& attrs,
+                                   void* dso_handle) {
   VLOG(3) << "Custom Operator: op name in kernel: " << name;
   // NOTE [ Dummy Op Kernel Key ]
   // TODO(chenweihang): Because execute engine need get device context based
   // op_kernel_key.place_, so we should register kernel for each
   // device. But this is not entirely correct, if user only give a cpu kernel,
   // but call api in gpu device, it will cause error.
-  RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
-                                  PlaceType::kCPU, inputs, outputs, attrs);
+  OperatorWithKernel::OpKernelFunc op_kernel_func;
+  if (kernel_func) {
+    VLOG(3) << "Register custom operator " << name << " with kernel func";
+    op_kernel_func = [kernel_func, inputs, outputs,
+                      attrs](const framework::ExecutionContext& ctx) {
+      VLOG(3) << "Custom Operator: run custom kernel func in lambda.";
+      RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
+    };
+  } else {
+    VLOG(3) << "Register custom operator " << name
+            << " with raw op kernel func";
+    PADDLE_ENFORCE_NOT_NULL(
+        dso_handle,
+        platform::errors::InvalidArgument(
+            "The dso handle must be provided if kernel_func is nullptr."));
+    using OpKernelFuncPtr = void(const framework::ExecutionContext&);
+    auto symbol_name = "PD_" + name + "_raw_op_kernel_func";
+    auto* func = detail::DynLoad<OpKernelFuncPtr>(dso_handle, symbol_name);
+    op_kernel_func = func;
+  }
+  RegisterOperatorKernelWithPlace(name, op_kernel_func, proto::VarType::RAW,
+                                  PlaceType::kCPU);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
-                                  PlaceType::kGPU, inputs, outputs, attrs);
+  RegisterOperatorKernelWithPlace(name, op_kernel_func, proto::VarType::RAW,
+                                  PlaceType::kGPU);
 #endif
 }
 
-void RegisterOperatorWithMetaInfo(
-    const std::vector<OpMetaInfo>& op_meta_infos) {
+void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
+                                  void* dso_handle) {
   /* Op register */
   OpInfo info;
@@ -792,7 +805,8 @@ void RegisterOperatorWithMetaInfo(
   }
 
   // Kernel func
-  RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);
+  RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs,
+                         dso_handle);
 
   // If grad op or double grad op exists
   std::string cur_op_name = op_name;
@@ -900,7 +914,7 @@ void RegisterOperatorWithMetaInfo(
     // Kernel func
     RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs,
-                           grad_op_outputs, grad_op_attrs);
+                           grad_op_outputs, grad_op_attrs, dso_handle);
 
     // update current info
     OpInfoMap::Instance().Insert(cur_op_name, info);
@@ -912,14 +926,14 @@ void RegisterOperatorWithMetaInfo(
 }
 
 void RegisterOperatorWithMetaInfoMap(
-    const paddle::OpMetaInfoMap& op_meta_info_map) {
+    const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle) {
   auto& meta_info_map = op_meta_info_map.GetMap();
   VLOG(3) << "Custom Operator: size of op meta info map - "
           << meta_info_map.size();
   // pair: {op_type, OpMetaInfo}
   for (auto& pair : meta_info_map) {
     VLOG(3) << "Custom Operator: pair first -> op name: " << pair.first;
-    RegisterOperatorWithMetaInfo(pair.second);
+    RegisterOperatorWithMetaInfo(pair.second, dso_handle);
   }
 }
@@ -934,7 +948,7 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
       detail::DynLoad<get_op_meta_info_map_t>(handle, "PD_GetOpMetaInfoMap");
   auto& op_meta_info_map = get_op_meta_info_map();
 
-  RegisterOperatorWithMetaInfoMap(op_meta_info_map);
+  RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle);
 }
 
 }  // namespace framework
@@ -26,10 +26,11 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
 
 // Register custom op api: register op directly
 void RegisterOperatorWithMetaInfoMap(
-    const paddle::OpMetaInfoMap& op_meta_info_map);
+    const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle = nullptr);
 
 // Interface for selective register custom op.
-void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos);
+void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
+                                  void* dso_handle = nullptr);
 
 }  // namespace framework
 }  // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/pten/api/ext/op_meta_info.h"
// NOTE(zengjinle): this macro is only for internal usage. In general,
// users should not use this macro.
#define __PD_DEFINE_RAW_OP_KERNEL_FUNC(op_name, ctx)                         \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                            \
      __reg_raw_op_kernel_func__##op_name,                                   \
      "__PD_DEFINE_RAW_OP_KERNEL_FUNC must be called in global namespace."); \
  extern "C" void PD_##op_name##_raw_op_kernel_func(                         \
      const ::paddle::framework::ExecutionContext& ctx)
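
For reference, manually expanding this macro for the custom_raw_relu test op added below makes the contract visible: the generated extern "C" symbol must match the "PD_" + name + "_raw_op_kernel_func" lookup in RegisterOperatorKernel above. A hand-written expansion, for illustration only:

// Expansion of __PD_DEFINE_RAW_OP_KERNEL_FUNC(custom_raw_relu, ctx):
STATIC_ASSERT_GLOBAL_NAMESPACE(
    __reg_raw_op_kernel_func__custom_raw_relu,
    "__PD_DEFINE_RAW_OP_KERNEL_FUNC must be called in global namespace.");
extern "C" void PD_custom_raw_relu_raw_op_kernel_func(
    const ::paddle::framework::ExecutionContext& ctx)
// ... followed by the user-supplied function body { ... }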
@@ -185,6 +185,14 @@ bool IsCompiledWithCUDA() {
 #endif
 }
 
+bool IsCompiledWithNCCL() {
+#ifdef PADDLE_WITH_NCCL
+  return true;
+#else
+  return false;
+#endif
+}
+
 bool IsCompiledWithROCM() {
 #ifndef PADDLE_WITH_HIP
   return false;
@@ -2433,6 +2441,7 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("is_compiled_with_ipu", IsCompiledWithIPU);
   m.def("is_compiled_with_xpu", IsCompiledWithXPU);
   m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
+  m.def("is_compiled_with_nccl", IsCompiledWithNCCL);
   m.def("is_compiled_with_cinn", IsCompiledWithCINN);
   m.def("is_compiled_with_mlu", IsCompiledWithMLU);
   m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS);
@@ -10,6 +10,9 @@ if(WITH_GPU OR APPLE)
   set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180)
 endif()
 
+py_test(test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py)
+set_tests_properties(test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180)
+
 # CPU custom op tests: only compile .cc file
 py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
 py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "custom_raw_op_kernel_op.h" // NOLINT
#include "paddle/fluid/framework/custom_raw_op_kernel_func.h"
#include "paddle/fluid/platform/enforce.h"
void ReluCPUForward(const paddle::framework::Tensor &x,
paddle::framework::Tensor *y) {
custom_raw_op::ReluForward(x, y);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void ReluGPUForward(const paddle::framework::Tensor &x,
paddle::framework::Tensor *y);
#else
void ReluGPUForward(const paddle::framework::Tensor &x,
paddle::framework::Tensor *y) {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"ReluGPUForward is not supported when not compiled with GPU."));
}
#endif
__PD_DEFINE_RAW_OP_KERNEL_FUNC(custom_raw_relu, ctx) {
namespace f = paddle::framework;
const auto *x = ctx.Input<f::Tensor>("X");
auto *y = ctx.Output<f::Tensor>("Y");
  PADDLE_ENFORCE_NOT_NULL(x,
                          paddle::platform::errors::InvalidArgument(
                              "Input(X) should not be nullptr."));
  PADDLE_ENFORCE_NOT_NULL(y,
                          paddle::platform::errors::InvalidArgument(
                              "Output(Y) should not be nullptr."));
if (paddle::platform::is_gpu_place(x->place())) {
ReluGPUForward(*x, y);
} else {
ReluCPUForward(*x, y);
}
}
PD_BUILD_OP(custom_raw_relu).Inputs({"X"}).Outputs({"Y"});
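
For contrast, a conventional custom op attaches its kernel with SetKernelFn, in which case kernel_func is non-null and RegisterOperatorKernel never takes the raw-symbol path. A minimal sketch, not part of this commit (plain_relu and PlainRelu are hypothetical names; the kernel body is omitted):

#include "paddle/extension.h"  // assumed entry header for the custom op API

// Ordinary custom op kernel working on unpacked paddle::Tensor arguments;
// definition omitted in this sketch.
std::vector<paddle::Tensor> PlainRelu(const paddle::Tensor& x);

// With SetKernelFn present, RegisterOperatorKernel wraps PlainRelu in a
// lambda instead of dlsym-ing PD_plain_relu_raw_op_kernel_func.
PD_BUILD_OP(plain_relu)
    .Inputs({"X"})
    .Outputs({"Y"})
    .SetKernelFn(PD_KERNEL(PlainRelu));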
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "custom_raw_op_kernel_op.h" // NOLINT
void ReluGPUForward(const paddle::framework::Tensor &x,
paddle::framework::Tensor *y) {
custom_raw_op::ReluForward(x, y);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
namespace custom_raw_op {
struct ReluFunctor {
explicit ReluFunctor(const paddle::framework::Tensor &x,
paddle::framework::Tensor *y)
: x_(x), y_(y) {}
template <typename U>
struct Impl {
Impl(const U *x, U *y) : x_(x), y_(y) {}
HOSTDEVICE void operator()(size_t i) const {
y_[i] = (x_[i] > static_cast<U>(0) ? x_[i] : static_cast<U>(0));
}
private:
const U *x_;
U *y_;
};
template <typename T>
void apply() {
auto n = x_.numel();
auto place = x_.place();
const auto *x_data = x_.data<T>();
y_->Resize(x_.dims());
auto *y_data = y_->mutable_data<T>(place);
const auto &dev_ctx =
*paddle::platform::DeviceContextPool::Instance().Get(place);
#define LAUNCH_RELU_KERNEL(DevCtxT) \
do { \
auto &__dev_ctx = dynamic_cast<const DevCtxT &>(dev_ctx); \
paddle::platform::ForRange<DevCtxT> for_range(__dev_ctx, n); \
Impl<T> functor(x_data, y_data); \
for_range(functor); \
} while (0)
#if defined(__NVCC__) || defined(__HIPCC__)
if (paddle::platform::is_gpu_place(place)) {
LAUNCH_RELU_KERNEL(paddle::platform::CUDADeviceContext);
return;
}
#endif
LAUNCH_RELU_KERNEL(paddle::platform::CPUDeviceContext);
#undef LAUNCH_RELU_KERNEL
}
private:
const paddle::framework::Tensor &x_;
paddle::framework::Tensor *y_;
};
inline void ReluForward(const paddle::framework::Tensor &x,
paddle::framework::Tensor *y) {
custom_raw_op::ReluFunctor functor(x, y);
paddle::framework::VisitDataType(x.type(), functor);
}
} // namespace custom_raw_op
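
ReluForward delegates dtype dispatch to the framework's VisitDataType, which invokes functor.apply<T>() with T bound to the runtime dtype of x. A hand-rolled equivalent for two dtypes, to make the mechanism concrete (a sketch only; it assumes platform/enforce.h is included, and the real visitor covers all registered pod types):

// Manual stand-in for VisitDataType, for illustration only.
inline void ReluForwardManualDispatch(const paddle::framework::Tensor &x,
                                      paddle::framework::Tensor *y) {
  custom_raw_op::ReluFunctor functor(x, y);
  switch (x.type()) {
    case paddle::framework::proto::VarType::FP32:
      functor.apply<float>();
      break;
    case paddle::framework::proto::VarType::FP64:
      functor.apply<double>();
      break;
    default:
      PADDLE_THROW(paddle::platform::errors::Unimplemented(
          "Only FP32/FP64 are handled in this sketch."));
  }
}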
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.fluid.core as core
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
from utils import paddle_includes, extra_compile_args
if paddle.is_compiled_with_cuda():
sources = ['custom_raw_op_kernel_op.cc', 'custom_raw_op_kernel_op.cu']
extension = CUDAExtension
else:
sources = ['custom_raw_op_kernel_op.cc']
extension = CppExtension
cwd = os.path.dirname(os.path.abspath(__file__))
os.chdir(cwd)
if os.name == 'nt':
compile_dir = os.path.join(os.environ['work_dir'], os.environ['BUILD_DIR'])
else:
compile_dir = os.path.join(os.environ['PADDLE_ROOT'], 'build')
macros = []
if core.is_compiled_with_mkldnn():
macros.append(("PADDLE_WITH_MKLDNN", None))
if core.is_compiled_with_nccl():
macros.append(("PADDLE_WITH_NCCL", None))
include_dirs = list(paddle_includes) + [cwd]
setup(
name=os.getenv("MODULE_NAME", "custom_raw_op_kernel_op_setup"),
ext_modules=extension(
sources=sources,
include_dirs=include_dirs,
extra_compile_args=extra_compile_args,
_compile_dir=compile_dir,
define_macros=macros))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import shlex
import site
import sys
import importlib
import unittest
import numpy as np
MODULE_NAME = "custom_raw_op_kernel_op_lib"
def prepare_module_path():
    # NOTE(Aurelius84): Normally, users do not need the following code.
    # But we simulate a pip install in the current process, and the
    # interpreter does not notice that sys.path has been updated, so we
    # update it manually.
    # See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
if os.name == 'nt':
# NOTE(zhouwei25): getsitepackages on windows will return a list: [python install dir, site packages dir]
site_dir = site.getsitepackages()[1]
else:
site_dir = site.getsitepackages()[0]
custom_egg_path = [x for x in os.listdir(site_dir) if MODULE_NAME in x]
assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
custom_egg_path)
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
# FIXME(zengjinle): do not know how to get the _compile_dir argument
# on Windows CI when compiling the custom op. Skip it on Windows CI
# temporarily.
@unittest.skipIf(os.name == "nt", "Not supported on Windows yet.")
class TestCustomRawReluOp(unittest.TestCase):
@classmethod
def setUpClass(cls):
path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, "custom_raw_op_kernel_op_setup.py")
cmd = [sys.executable, path, "install", "--force"]
cmd = " ".join([shlex.quote(c) for c in cmd])
os.environ['MODULE_NAME'] = MODULE_NAME
assert os.system(cmd) == 0
prepare_module_path()
@classmethod
def tearDownClass(cls):
cmd = [sys.executable, "-m", "pip", "uninstall", "-y", MODULE_NAME]
cmd = " ".join([shlex.quote(c) for c in cmd])
assert os.system(cmd) == 0
def custom_raw_relu(self, x):
module = importlib.import_module(MODULE_NAME)
custom_raw_relu_op = getattr(module, "custom_raw_relu")
        self.assertIsNotNone(custom_raw_relu_op)
return custom_raw_relu_op(x)
def test_dygraph(self):
x = paddle.to_tensor(np.random.uniform(low=-1.0, high=1.0, size=[2, 3]))
y1 = self.custom_raw_relu(x)
y2 = paddle.nn.ReLU()(x)
self.assertTrue(np.array_equal(y1.numpy(), y2.numpy()))
def test_static(self):
paddle.enable_static()
shape = [2, 3]
x = paddle.static.data(name="x", dtype='float32', shape=shape)
y1 = self.custom_raw_relu(x)
y2 = paddle.nn.ReLU()(x)
exe = paddle.static.Executor()
exe.run(paddle.static.default_startup_program())
x_np = np.random.uniform(
low=-1.0, high=1.0, size=[2, 3]).astype('float32')
y1_value, y2_value = exe.run(paddle.static.default_main_program(),
feed={x.name: x_np},
fetch_list=[y1, y2])
self.assertTrue(np.array_equal(y1_value, y2_value))
paddle.disable_static()
if __name__ == "__main__":
unittest.main()