From e60fd1f6a8da123f4c0129d5790b906a8c44477e Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Tue, 23 Feb 2021 04:54:38 -0600
Subject: [PATCH] [CustomOp] Split test and add inference test (#31078)

* split test & add inference test
* add timeout config
* change to setup install
* change to jit compile
* add verbose for test
* fix load setup name repeat
* polish details
* resolve conflict
* fix code format error
---
 .../fluid/tests/custom_op/CMakeLists.txt      |  17 ++-
 .../{relu_op_simple.cc => custom_relu_op.cc}  |  30 +----
 .../{relu_op_simple.cu => custom_relu_op.cu}  |  22 +---
 ...lu_op3_simple.cc => custom_relu_op_dup.cc} |   6 +-
 ...install_simple.py => custom_relu_setup.py} |   7 +-
 .../tests/custom_op/multi_out_test_op.cc      |  76 +++++++++++
 .../custom_op/test_custom_relu_op_jit.py      |  86 +++++++++++++
 ..._setup.py => test_custom_relu_op_setup.py} | 120 +++++++++++++++---
 ...{test_dispatch.py => test_dispatch_jit.py} |   0
 ...custom_op_jit.py => test_multi_out_jit.py} | 109 ++++++----------
 .../utils/cpp_extension/cpp_extension.py      |   2 +-
 11 files changed, 324 insertions(+), 151 deletions(-)
 rename python/paddle/fluid/tests/custom_op/{relu_op_simple.cc => custom_relu_op.cc} (81%)
 rename python/paddle/fluid/tests/custom_op/{relu_op_simple.cu => custom_relu_op.cu} (75%)
 rename python/paddle/fluid/tests/custom_op/{relu_op3_simple.cc => custom_relu_op_dup.cc} (92%)
 rename python/paddle/fluid/tests/custom_op/{setup_install_simple.py => custom_relu_setup.py} (79%)
 create mode 100644 python/paddle/fluid/tests/custom_op/multi_out_test_op.cc
 create mode 100644 python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
 rename python/paddle/fluid/tests/custom_op/{test_simple_custom_op_setup.py => test_custom_relu_op_setup.py} (53%)
 rename python/paddle/fluid/tests/custom_op/{test_dispatch.py => test_dispatch_jit.py} (100%)
 rename python/paddle/fluid/tests/custom_op/{test_simple_custom_op_jit.py => test_multi_out_jit.py} (62%)

diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt
index d7acab4d03..1307df1fc1 100644
--- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt
+++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt
@@ -1,17 +1,20 @@
 # New custom OP can support Windows/Linux now
-# 'test_simple_custom_op_jit/test_simple_custom_op_setup' compile .cc and .cu file
-py_test(test_simple_custom_op_setup SRCS test_simple_custom_op_setup.py)
-py_test(test_simple_custom_op_jit SRCS test_simple_custom_op_jit.py)
+# 'test_custom_relu_op_setup/jit' compile the .cc and .cu files
+py_test(test_custom_relu_op_setup SRCS test_custom_relu_op_setup.py)
+py_test(test_custom_relu_op_jit SRCS test_custom_relu_op_jit.py)
 
 # Compiling shared library will cost some time, but running process is very fast.
-set_tests_properties(test_simple_custom_op_setup PROPERTIES TIMEOUT 250)
-set_tests_properties(test_simple_custom_op_jit PROPERTIES TIMEOUT 180)
+set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250)
+set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180)
 
 py_test(test_sysconfig SRCS test_sysconfig.py)
 
 # 'test_dispatch' compile .cc file
-py_test(test_dispatch SRCS test_dispatch.py)
-set_tests_properties(test_dispatch PROPERTIES TIMEOUT 180)
+py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
+set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 180)
+
+py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
+set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 180)
 
 if(NOT LINUX)
     return()
diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
similarity index 81%
rename from python/paddle/fluid/tests/custom_op/relu_op_simple.cc
rename to python/paddle/fluid/tests/custom_op/custom_relu_op.cc
index b02ecba682..0e358e24ae 100644
--- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
@@ -17,13 +17,6 @@
 
 #include "paddle/extension.h"
 
-template <typename data_t>
-void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) {
-  for (int i = 0; i < x_numel; ++i) {
-    out_data[i] = value;
-  }
-}
-
 template <typename data_t>
 void relu_cpu_forward_kernel(const data_t* x_data,
                              data_t* out_data,
@@ -53,21 +46,8 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
     relu_cpu_forward_kernel<data_t>(
         x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
   }));
-  // fake multi output: Fake_float64 with float64 dtype
-  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU);
-  fake_float64.reshape(x.shape());
-
-  fill_constant_cpu_kernel<double>(
-      fake_float64.mutable_data<double>(x.place()), x.size(), 0.);
-
-  // fake multi output: ZFake_int32 with int32 dtype
-  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU);
-  zfake_int32.reshape(x.shape());
-
-  fill_constant_cpu_kernel<int32_t>(
-      zfake_int32.mutable_data<int32_t>(x.place()), x.size(), 1);
 
-  return {out, fake_float64, zfake_int32};
+  return {out};
 }
 
 std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
@@ -117,16 +97,16 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
 }
 
 std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
-  return {x_shape, x_shape, x_shape};
+  return {x_shape};
 }
 
 std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
-  return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
+  return {x_dtype};
 }
 
-PD_BUILD_OP("relu2")
+PD_BUILD_OP("custom_relu")
     .Inputs({"X"})
-    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
+    .Outputs({"Out"})
     .SetKernelFn(PD_KERNEL(ReluForward))
     .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType));
diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
similarity index 75%
rename from python/paddle/fluid/tests/custom_op/relu_op_simple.cu
rename to python/paddle/fluid/tests/custom_op/custom_relu_op.cu
index 2ef6a5c145..a9ce517607 100644
--- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -14,16 +14,6 @@
 
 #include "paddle/extension.h"
 
-template <typename data_t>
-__global__ void fill_constant_cuda_kernel(data_t* y,
-                                          const int num,
-                                          data_t value) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
-    y[i] = value;
-  }
-}
-
 template <typename data_t>
 __global__ void relu_cuda_forward_kernel(const data_t* x,
                                          data_t* y,
@@ -57,18 +47,8 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
     relu_cuda_forward_kernel<data_t><<<grid, block>>>(
         x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
   }));
-  // fake multi output: Fake_1
-  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kGPU);
-  fake_float64.reshape(x.shape());
-  fill_constant_cuda_kernel<double><<<grid, block>>>(
-      fake_float64.mutable_data<double>(x.place()), numel, 0.);
-  // fake multi output: ZFake_1
-  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kGPU);
-  zfake_int32.reshape(x.shape());
-  fill_constant_cuda_kernel<int32_t><<<grid, block>>>(
-      zfake_int32.mutable_data<int32_t>(x.place()), numel, 1);
 
-  return {out, fake_float64, zfake_int32};
+  return {out};
 }
 
 std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
diff --git a/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op_dup.cc
similarity index 92%
rename from python/paddle/fluid/tests/custom_op/relu_op3_simple.cc
rename to python/paddle/fluid/tests/custom_op/custom_relu_op_dup.cc
index ec64bce187..7319bdd762 100644
--- a/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op_dup.cc
@@ -29,11 +29,11 @@ std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape);
 
 std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
 
-// Reuse codes in `relu_op_simple.cc/cu` to register another custom operator
+// Reuse the code in `custom_relu_op.cc/cu` to register another custom operator
 // to test jointly compile multi operators at same time.
-PD_BUILD_OP("relu3")
+PD_BUILD_OP("custom_relu_dup")
     .Inputs({"X"})
-    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
+    .Outputs({"Out"})
     .SetKernelFn(PD_KERNEL(ReluForward))
     .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType));
diff --git a/python/paddle/fluid/tests/custom_op/setup_install_simple.py b/python/paddle/fluid/tests/custom_op/custom_relu_setup.py
similarity index 79%
rename from python/paddle/fluid/tests/custom_op/setup_install_simple.py
rename to python/paddle/fluid/tests/custom_op/custom_relu_setup.py
index ed236ccbd4..598b850c87 100644
--- a/python/paddle/fluid/tests/custom_op/setup_install_simple.py
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_setup.py
@@ -17,11 +17,14 @@ import os
 from utils import paddle_includes, extra_compile_args
 from paddle.utils.cpp_extension import CUDAExtension, setup
 
+# custom_relu_op_dup.cc is only used for the multi-op test; it does not
+# add a new op. If you want to test a single op only, remove this
+# source file.
setup(
-    name='simple_setup_relu2',
+    name='custom_relu_module_setup',
     ext_modules=CUDAExtension(  # test for not specific name here.
         sources=[
-            'relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'
+            'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
         ],  # test for multi ops
         include_dirs=paddle_includes,
         extra_compile_args=extra_compile_args))
diff --git a/python/paddle/fluid/tests/custom_op/multi_out_test_op.cc b/python/paddle/fluid/tests/custom_op/multi_out_test_op.cc
new file mode 100644
index 0000000000..bece0f4984
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_op/multi_out_test_op.cc
@@ -0,0 +1,76 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <vector>
+
+#include "paddle/extension.h"
+
+template <typename data_t>
+void assign_cpu_kernel(const data_t* x_data,
+                       data_t* out_data,
+                       int64_t x_numel) {
+  for (int i = 0; i < x_numel; ++i) {
+    out_data[i] = x_data[i];
+  }
+}
+
+template <typename data_t>
+void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) {
+  for (int i = 0; i < x_numel; ++i) {
+    out_data[i] = value;
+  }
+}
+
+std::vector<paddle::Tensor> MultiOutCPU(const paddle::Tensor& x) {
+  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
+  out.reshape(x.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      x.type(), "assign_cpu_kernel", ([&] {
+        assign_cpu_kernel<data_t>(
+            x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
+      }));
+
+  // fake multi output: Fake_float64 with float64 dtype
+  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU);
+  fake_float64.reshape(x.shape());
+
+  fill_constant_cpu_kernel<double>(
+      fake_float64.mutable_data<double>(x.place()), x.size(), 0.);
+
+  // fake multi output: ZFake_int32 with int32 dtype
+  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU);
+  zfake_int32.reshape(x.shape());
+
+  fill_constant_cpu_kernel<int32_t>(
+      zfake_int32.mutable_data<int32_t>(x.place()), x.size(), 1);
+
+  return {out, fake_float64, zfake_int32};
+}
+
+std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
+  return {x_shape, x_shape, x_shape};
+}
+
+std::vector<paddle::DataType> InferDtype(paddle::DataType x_dtype) {
+  return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
+}
+
+PD_BUILD_OP("multi_out")
+    .Inputs({"X"})
+    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
+    .SetKernelFn(PD_KERNEL(MultiOutCPU))
+    .SetInferShapeFn(PD_INFER_SHAPE(InferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
new file mode 100644
index 0000000000..018e654429
--- /dev/null
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+import unittest
+import paddle
+import numpy as np
+from paddle.utils.cpp_extension import load, get_build_directory
+from paddle.utils.cpp_extension.extension_utils import run_cmd
+from utils import paddle_includes, extra_compile_args
+from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static
+
+# Because Windows doesn't use docker, the shared lib already exists in the
+# cache dir; it will not be compiled again unless the shared lib is removed.
+if os.name == 'nt':
+    cmd = 'del {}\\custom_relu_module_jit.pyd'.format(get_build_directory())
+    run_cmd(cmd, True)
+
+# Compile and load the custom op just-in-time.
+# custom_relu_op_dup.cc is only used for the multi-op test; it does not
+# add a new op. If you want to test a single op only, remove this
+# source file.
+custom_module = load(
+    name='custom_relu_module_jit',
+    sources=[
+        'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
+    ],
+    extra_include_paths=paddle_includes,  # add for Coverage CI
+    extra_cflags=extra_compile_args,  # add for Coverage CI
+    verbose=True)
+
+
+class TestJITLoad(unittest.TestCase):
+    def setUp(self):
+        self.custom_ops = [
+            custom_module.custom_relu, custom_module.custom_relu_dup
+        ]
+        self.dtypes = ['float32', 'float64']
+        self.devices = ['cpu', 'gpu']
+
+    def test_static(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                for custom_op in self.custom_ops:
+                    out = custom_relu_static(custom_op, device, dtype, x)
+                    pd_out = custom_relu_static(custom_op, device, dtype, x,
+                                                False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
+
+    def test_dynamic(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                for custom_op in self.custom_ops:
+                    out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
+                                                      x)
+                    pd_out, pd_x_grad = custom_relu_dynamic(custom_op, device,
+                                                            dtype, x, False)
+                    self.assertTrue(
+                        np.array_equal(out, pd_out),
+                        "custom op out: {},\n paddle api out: {}".format(
+                            out, pd_out))
+                    self.assertTrue(
+                        np.array_equal(x_grad, pd_x_grad),
+                        "custom op x grad: {},\n paddle api x grad: {}".format(
+                            x_grad, pd_x_grad))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
similarity index 53%
rename from python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py
rename to python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
index f312508d39..6781915e02 100644
--- a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
@@ -23,13 +23,13 @@ import numpy as np
 from paddle.utils.cpp_extension.extension_utils import run_cmd
 
 
-def relu2_dynamic(func, device, dtype, np_x, use_func=True):
+def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
     paddle.set_device(device)
 
     t = paddle.to_tensor(np_x)
     t.stop_gradient = False
 
-    out = func(t)[0] if use_func else paddle.nn.functional.relu(t)
+    out = func(t) if use_func else paddle.nn.functional.relu(t)
     out.stop_gradient = False
 
     out.backward()
 
     return out.numpy(), t.grad
 
 
@@ -37,7 +37,12 @@
-def relu2_static(func, device, dtype, np_x, use_func=True):
+def custom_relu_static(func,
+                       device,
+                       dtype,
+                       np_x,
+                       use_func=True,
+                       test_infer=False):
     paddle.enable_static()
     paddle.set_device(device)
 
@@ -45,8 +50,7 @@
     with static.program_guard(static.Program()):
         x = static.data(name='X', shape=[None, 8], dtype=dtype)
         x.stop_gradient = False
-        # out, fake_float64, fake_int32
-        out = func(x)[0] if use_func else paddle.nn.functional.relu(x)
+        out = func(x) if use_func else paddle.nn.functional.relu(x)
         static.append_backward(out)
 
         exe = static.Executor()
@@ -60,7 +64,7 @@ def relu2_static(func, device, dtype, np_x, use_func=True):
     return out_v
 
 
-def relu2_static_pe(func, device, dtype, np_x, use_func=True):
+def custom_relu_static_pe(func, device, dtype, np_x, use_func=True):
     paddle.enable_static()
     paddle.set_device(device)
 
@@ -69,7 +73,7 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True):
     with static.program_guard(static.Program()):
         x = static.data(name='X', shape=[None, 8], dtype=dtype)
         x.stop_gradient = False
-        out = func(x)[0] if use_func else paddle.nn.functional.relu(x)
+        out = func(x) if use_func else paddle.nn.functional.relu(x)
         static.append_backward(out)
 
         exe = static.Executor()
@@ -87,16 +91,58 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True):
     return out_v
 
 
+def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
+    paddle.set_device(device)
+
+    with static.scope_guard(static.Scope()):
+        with static.program_guard(static.Program()):
+            # simple module
+            data = static.data(
+                name='data', shape=[None, 1, 28, 28], dtype='float32')
+            label = static.data(name='label', shape=[None, 1], dtype='int64')
+
+            hidden = static.nn.fc(data, size=128)
+            hidden = func(hidden)
+            hidden = static.nn.fc(hidden, size=128)
+            predict = static.nn.fc(hidden, size=10, activation='softmax')
+            loss = paddle.nn.functional.cross_entropy(
+                input=predict, label=label)
+            avg_loss = paddle.mean(loss)
+
+            opt = paddle.optimizer.SGD(learning_rate=0.1)
+            opt.minimize(avg_loss)
+
+            # run startup program
+            exe = static.Executor()
+            exe.run(static.default_startup_program())
+
+            # train
+            for i in range(4):
+                avg_loss_v = exe.run(static.default_main_program(),
+                                     feed={'data': np_data,
+                                           'label': np_label},
+                                     fetch_list=[avg_loss])
+
+            # save inference model
+            static.save_inference_model(path_prefix, [data], [predict], exe)
+
+            # get train predict value
+            predict_v = exe.run(static.default_main_program(),
+                                feed={'data': np_data,
+                                      'label': np_label},
+                                fetch_list=[predict])
+
+    return predict_v
+
+
 class TestNewCustomOpSetUpInstall(unittest.TestCase):
     def setUp(self):
         cur_dir = os.path.dirname(os.path.abspath(__file__))
         # compile, install the custom op egg into site-packages under background
         if os.name == 'nt':
-            cmd = 'cd /d {} && python setup_install_simple.py install'.format(
+            cmd = 'cd /d {} && python custom_relu_setup.py install'.format(
                 cur_dir)
         else:
-            cmd = 'cd {} && python setup_install_simple.py install'.format(
-                cur_dir)
+            cmd = 'cd {} && python custom_relu_setup.py install'.format(cur_dir)
         run_cmd(cmd)
 
         # NOTE(Aurelius84): Normally, it's no need to add following codes for users.
@@ -110,26 +156,36 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
         else:
             site_dir = site.getsitepackages()[0]
         custom_egg_path = [
-            x for x in os.listdir(site_dir) if 'simple_setup_relu2' in x
+            x for x in os.listdir(site_dir) if 'custom_relu_module_setup' in x
         ]
         assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
            custom_egg_path)
         sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
 
         # usage: import the package directly
-        import simple_setup_relu2
-        self.custom_ops = [simple_setup_relu2.relu2, simple_setup_relu2.relu3]
+        import custom_relu_module_setup
+        # `custom_relu_dup` is the same op as `custom_relu`
+        self.custom_ops = [
+            custom_relu_module_setup.custom_relu,
+            custom_relu_module_setup.custom_relu_dup
+        ]
 
         self.dtypes = ['float32', 'float64']
         self.devices = ['cpu', 'gpu']
 
+        # config seed
+        SEED = 2021
+        paddle.seed(SEED)
+        paddle.framework.random._manual_program_seed(SEED)
+
     def test_static(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
-                    out = relu2_static(custom_op, device, dtype, x)
-                    pd_out = relu2_static(custom_op, device, dtype, x, False)
+                    out = custom_relu_static(custom_op, device, dtype, x)
+                    pd_out = custom_relu_static(custom_op, device, dtype, x,
+                                                False)
                     self.assertTrue(
                         np.array_equal(out, pd_out),
                         "custom op out: {},\n paddle api out: {}".format(
@@ -140,8 +196,9 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
-                    out = relu2_static_pe(custom_op, device, dtype, x)
-                    pd_out = relu2_static_pe(custom_op, device, dtype, x, False)
+                    out = custom_relu_static_pe(custom_op, device, dtype, x)
+                    pd_out = custom_relu_static_pe(custom_op, device, dtype, x,
+                                                   False)
                     self.assertTrue(
                         np.array_equal(out, pd_out),
                         "custom op out: {},\n paddle api out: {}".format(
@@ -152,9 +209,10 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
             for dtype in self.dtypes:
                 x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                 for custom_op in self.custom_ops:
-                    out, x_grad = relu2_dynamic(custom_op, device, dtype, x)
-                    pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype,
-                                                      x, False)
+                    out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
+                                                      x)
+                    pd_out, pd_x_grad = custom_relu_dynamic(custom_op, device,
+                                                            dtype, x, False)
                     self.assertTrue(
                         np.array_equal(out, pd_out),
                         "custom op out: {},\n paddle api out: {}".format(
@@ -164,6 +222,28 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                         "custom op x grad: {},\n paddle api x grad: {}".format(
                             x_grad, pd_x_grad))
 
+    def test_static_save_and_load_inference_model(self):
+        paddle.enable_static()
+        np_data = np.random.random((1, 1, 28, 28)).astype("float32")
+        np_label = np.random.random((1, 1)).astype("int64")
+        path_prefix = "custom_op_inference/custom_relu"
+        for device in self.devices:
+            predict = custom_relu_static_inference(
+                self.custom_ops[0], device, np_data, np_label, path_prefix)
+            # load inference model
+            with static.scope_guard(static.Scope()):
+                exe = static.Executor()
+                [inference_program, feed_target_names,
+                 fetch_targets] = static.load_inference_model(path_prefix, exe)
+                predict_infer = exe.run(inference_program,
+                                        feed={feed_target_names[0]: np_data},
+                                        fetch_list=fetch_targets)
+                self.assertTrue(
+                    np.array_equal(predict, predict_infer),
+                    "custom op predict: {},\n custom op infer predict: {}".
+                    format(predict, predict_infer))
+        paddle.disable_static()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/custom_op/test_dispatch.py b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
similarity index 100%
rename from python/paddle/fluid/tests/custom_op/test_dispatch.py
rename to python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py b/python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
similarity index 62%
rename from python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py
rename to python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
index f4d3c4f659..00cd689ca6 100644
--- a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
@@ -15,88 +15,51 @@
 import os
 import subprocess
 import unittest
-import paddle
 import numpy as np
+
+import paddle
 from paddle.utils.cpp_extension import load, get_build_directory
 from paddle.utils.cpp_extension.extension_utils import run_cmd
 from utils import paddle_includes, extra_compile_args
-from test_simple_custom_op_setup import relu2_dynamic, relu2_static
 
 # Because Windows don't use docker, the shared lib already exists in the
 # cache dir, it will not be compiled again unless the shared lib is removed.
 if os.name == 'nt':
-    cmd = 'del {}\\simple_jit_relu2.pyd'.format(get_build_directory())
+    cmd = 'del {}\\multi_out_jit.pyd'.format(get_build_directory())
     run_cmd(cmd, True)
 
 # Compile and load custom op Just-In-Time.
-custom_module = load(
-    name='simple_jit_relu2',
-    sources=['relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'],
+multi_out_module = load(
+    name='multi_out_jit',
+    sources=['multi_out_test_op.cc'],
     extra_include_paths=paddle_includes,  # add for Coverage CI
     extra_cflags=extra_compile_args,  # add for Coverage CI
     verbose=True)
 
 
-class TestJITLoad(unittest.TestCase):
-    def setUp(self):
-        self.custom_ops = [custom_module.relu2, custom_module.relu3]
-        self.dtypes = ['float32', 'float64']
-        self.devices = ['cpu', 'gpu']
-
-    def test_static(self):
-        for device in self.devices:
-            for dtype in self.dtypes:
-                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                for custom_op in self.custom_ops:
-                    out = relu2_static(custom_op, device, dtype, x)
-                    pd_out = relu2_static(custom_op, device, dtype, x, False)
-                    self.assertTrue(
-                        np.array_equal(out, pd_out),
-                        "custom op out: {},\n paddle api out: {}".format(
-                            out, pd_out))
-
-    def test_dynamic(self):
-        for device in self.devices:
-            for dtype in self.dtypes:
-                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                for custom_op in self.custom_ops:
-                    out, x_grad = relu2_dynamic(custom_op, device, dtype, x)
-                    pd_out, pd_x_grad = relu2_dynamic(custom_op, device, dtype,
-                                                      x, False)
-                    self.assertTrue(
-                        np.array_equal(out, pd_out),
-                        "custom op out: {},\n paddle api out: {}".format(
-                            out, pd_out))
-                    self.assertTrue(
-                        np.array_equal(x_grad, pd_x_grad),
-                        "custom op x grad: {},\n paddle api x grad: {}".format(
-                            x_grad, pd_x_grad))
-
-
 class TestMultiOutputDtypes(unittest.TestCase):
     def setUp(self):
-        self.custom_op = custom_module.relu2
+        self.custom_op = multi_out_module.multi_out
         self.dtypes = ['float32', 'float64']
-        self.devices = ['cpu', 'gpu']
+        self.devices = ['cpu']
 
-    def test_static(self):
-        paddle.enable_static()
-        for device in self.devices:
-            for dtype in self.dtypes:
-                res = self.run_static(device, dtype)
-                self.check_multi_outputs(res)
-        paddle.disable_static()
+    def run_static(self, device, dtype):
+        paddle.set_device(device)
+        x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
 
-    def test_dynamic(self):
-        for device in self.devices:
-            for dtype in self.dtypes:
-                paddle.set_device(device)
-                x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
-                x = paddle.to_tensor(x_data)
+        with paddle.static.scope_guard(paddle.static.Scope()):
+            with paddle.static.program_guard(paddle.static.Program()):
+                x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype)
                 outs = self.custom_op(x)
-                self.assertTrue(len(outs) == 3)
-                self.check_multi_outputs(outs, True)
+                exe = paddle.static.Executor()
+                exe.run(paddle.static.default_startup_program())
+                res = exe.run(paddle.static.default_main_program(),
+                              feed={'X': x_data},
+                              fetch_list=outs)
+
+                return res
 
     def check_multi_outputs(self, outs, is_dynamic=False):
         out, zero_float64, one_int32 = outs
@@ -112,22 +75,24 @@ class TestMultiOutputDtypes(unittest.TestCase):
         self.assertTrue(
             np.array_equal(one_int32, np.ones([4, 8]).astype('int32')))
 
-    def run_static(self, device, dtype):
-        paddle.set_device(device)
-        x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+    def test_static(self):
+        paddle.enable_static()
+        for device in self.devices:
+            for dtype in self.dtypes:
+                res = self.run_static(device, dtype)
+                self.check_multi_outputs(res)
+        paddle.disable_static()
 
-        with paddle.static.scope_guard(paddle.static.Scope()):
-            with paddle.static.program_guard(paddle.static.Program()):
-                x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype)
+    def test_dynamic(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                paddle.set_device(device)
+                x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                x = paddle.to_tensor(x_data)
                 outs = self.custom_op(x)
-                exe = paddle.static.Executor()
-                exe.run(paddle.static.default_startup_program())
-                res = exe.run(paddle.static.default_main_program(),
-                              feed={'X': x_data},
-                              fetch_list=outs)
-
-                return res
+                self.assertTrue(len(outs) == 3)
+                self.check_multi_outputs(outs, True)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/utils/cpp_extension/cpp_extension.py b/python/paddle/utils/cpp_extension/cpp_extension.py
index 8c0893b16c..2789a89978 100644
--- a/python/paddle/utils/cpp_extension/cpp_extension.py
+++ b/python/paddle/utils/cpp_extension/cpp_extension.py
@@ -558,7 +558,7 @@ def load(name,
 
     log_v("build_directory: {}".format(build_directory), verbose)
 
-    file_path = os.path.join(build_directory, "setup.py")
+    file_path = os.path.join(build_directory, "{}_setup.py".format(name))
 
     sources = [os.path.abspath(source) for source in sources]
 
     # TODO(Aurelius84): split cflags and cuda_flags
-- 
GitLab
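
For reference, the JIT path exercised by test_custom_relu_op_jit.py boils down to the flow below. This is a minimal sketch, not part of the patch, assuming the custom_relu_op.cc/.cu sources from this patch sit in the working directory and a CUDA-enabled Paddle 2.x build; the extra include paths and cflags used by the Coverage CI are omitted.

import numpy as np
import paddle
from paddle.utils.cpp_extension import load

# The first call compiles the sources into a shared library inside the
# extension build directory (note the "{}_setup.py".format(name) change
# above, which keeps repeated load() calls with different names from
# clobbering one shared setup.py), then loads the registered ops.
custom_module = load(
    name='custom_relu_module_jit',
    sources=['custom_relu_op.cc', 'custom_relu_op.cu'],
    verbose=True)

x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))
x.stop_gradient = False
# After this patch custom_relu returns a single tensor, so the old
# `func(t)[0]` indexing in the test helpers is gone.
out = custom_module.custom_relu(x)
out.backward()
print(out.numpy(), x.grad)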
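
The setup-install path covered by test_custom_relu_op_setup.py is the same op compiled ahead of time. A sketch of the intended usage, again not part of the patch, assuming `python custom_relu_setup.py install` has finished and the installed egg is on sys.path (the test above appends the egg directory manually; a normal install should not need that):

import numpy as np
import paddle
# Installed by custom_relu_setup.py; importing the module registers both
# `custom_relu` and its duplicate registration `custom_relu_dup`.
import custom_relu_module_setup

x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))
out = custom_relu_module_setup.custom_relu(x)
# `custom_relu_dup` is the same kernel under a second name, so both
# should match paddle.nn.functional.relu exactly, as the tests assert.
assert np.array_equal(out.numpy(), paddle.nn.functional.relu(x).numpy())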