From f2dc29a9fabcfd0d9d5f277019e5290483a8c650 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 19 Feb 2021 16:07:49 +0800 Subject: [PATCH] [CustomOp] Support output dtypes in generated Python API (#31045) --- .../fluid/tests/custom_op/relu_op3_simple.cc | 2 +- .../fluid/tests/custom_op/relu_op_simple.cc | 28 +++++++-- .../fluid/tests/custom_op/relu_op_simple.cu | 22 ++++++- .../custom_op/test_simple_custom_op_jit.py | 57 +++++++++++++++++++ .../custom_op/test_simple_custom_op_setup.py | 10 ++-- .../utils/cpp_extension/extension_utils.py | 35 ++++++------ 6 files changed, 125 insertions(+), 29 deletions(-) diff --git a/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc b/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc index 9df808a38a..ec64bce187 100644 --- a/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc +++ b/python/paddle/fluid/tests/custom_op/relu_op3_simple.cc @@ -33,7 +33,7 @@ std::vector ReluInferDType(paddle::DataType x_dtype); // to test jointly compile multi operators at same time. PD_BUILD_OP("relu3") .Inputs({"X"}) - .Outputs({"Out"}) + .Outputs({"Out", "Fake_float64", "ZFake_int32"}) .SetKernelFn(PD_KERNEL(ReluForward)) .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType)) diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc b/python/paddle/fluid/tests/custom_op/relu_op_simple.cc index 5abd1b77da..b02ecba682 100644 --- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cc +++ b/python/paddle/fluid/tests/custom_op/relu_op_simple.cc @@ -17,6 +17,13 @@ #include "paddle/extension.h" +template +void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) { + for (int i = 0; i < x_numel; ++i) { + out_data[i] = value; + } +} + template void relu_cpu_forward_kernel(const data_t* x_data, data_t* out_data, @@ -46,8 +53,21 @@ std::vector relu_cpu_forward(const paddle::Tensor& x) { relu_cpu_forward_kernel( x.data(), out.mutable_data(x.place()), x.size()); })); + // fake multi output: Fake_float64 with float64 dtype + auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU); + fake_float64.reshape(x.shape()); + + fill_constant_cpu_kernel( + fake_float64.mutable_data(x.place()), x.size(), 0.); + + // fake multi output: ZFake_int32 with int32 dtype + auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU); + zfake_int32.reshape(x.shape()); + + fill_constant_cpu_kernel( + zfake_int32.mutable_data(x.place()), x.size(), 1); - return {out}; + return {out, fake_float64, zfake_int32}; } std::vector relu_cpu_backward(const paddle::Tensor& x, @@ -97,16 +117,16 @@ std::vector ReluBackward(const paddle::Tensor& x, } std::vector> ReluInferShape(std::vector x_shape) { - return {x_shape}; + return {x_shape, x_shape, x_shape}; } std::vector ReluInferDType(paddle::DataType x_dtype) { - return {x_dtype}; + return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32}; } PD_BUILD_OP("relu2") .Inputs({"X"}) - .Outputs({"Out"}) + .Outputs({"Out", "Fake_float64", "ZFake_int32"}) .SetKernelFn(PD_KERNEL(ReluForward)) .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType)) diff --git a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu b/python/paddle/fluid/tests/custom_op/relu_op_simple.cu index a9ce517607..2ef6a5c145 100644 --- a/python/paddle/fluid/tests/custom_op/relu_op_simple.cu +++ b/python/paddle/fluid/tests/custom_op/relu_op_simple.cu @@ -14,6 +14,16 @@ #include "paddle/extension.h" +template +__global__ void fill_constant_cuda_kernel(data_t* y, + const int num, + data_t value) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = value; + } +} + template __global__ void relu_cuda_forward_kernel(const data_t* x, data_t* y, @@ -47,8 +57,18 @@ std::vector relu_cuda_forward(const paddle::Tensor& x) { relu_cuda_forward_kernel<<>>( x.data(), out.mutable_data(x.place()), numel); })); + // fake multi output: Fake_1 + auto fake_float64 = paddle::Tensor(paddle::PlaceType::kGPU); + fake_float64.reshape(x.shape()); + fill_constant_cuda_kernel<<>>( + fake_float64.mutable_data(x.place()), numel, 0.); + // fake multi output: ZFake_1 + auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kGPU); + zfake_int32.reshape(x.shape()); + fill_constant_cuda_kernel<<>>( + zfake_int32.mutable_data(x.place()), numel, 1); - return {out}; + return {out, fake_float64, zfake_int32}; } std::vector relu_cuda_backward(const paddle::Tensor& x, diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py index 926ab4064a..2c0dc1a4ca 100644 --- a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_jit.py @@ -64,5 +64,62 @@ class TestJITLoad(unittest.TestCase): x_grad, pd_x_grad)) +class TestMultiOutputDtypes(unittest.TestCase): + def setUp(self): + self.custom_op = custom_module.relu2 + self.dtypes = ['float32', 'float64'] + self.devices = ['cpu', 'gpu'] + + def test_static(self): + paddle.enable_static() + for device in self.devices: + for dtype in self.dtypes: + res = self.run_static(device, dtype) + self.check_multi_outputs(res) + paddle.disable_static() + + def test_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + paddle.set_device(device) + x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + x = paddle.to_tensor(x_data) + outs = self.custom_op(x) + + self.assertTrue(len(outs) == 3) + self.check_multi_outputs(outs, True) + + def check_multi_outputs(self, outs, is_dynamic=False): + out, zero_float64, one_int32 = outs + if is_dynamic: + zero_float64 = zero_float64.numpy() + one_int32 = one_int32.numpy() + # Fake_float64 + self.assertTrue('float64' in str(zero_float64.dtype)) + self.assertTrue( + np.array_equal(zero_float64, np.zeros([4, 8]).astype('float64'))) + # ZFake_int32 + self.assertTrue('int32' in str(one_int32.dtype)) + self.assertTrue( + np.array_equal(one_int32, np.ones([4, 8]).astype('int32'))) + + def run_static(self, device, dtype): + paddle.set_device(device) + x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype) + + with paddle.static.scope_guard(paddle.static.Scope()): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype) + outs = self.custom_op(x) + + exe = paddle.static.Executor() + exe.run(paddle.static.default_startup_program()) + res = exe.run(paddle.static.default_main_program(), + feed={'X': x_data}, + fetch_list=outs) + + return res + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py index dd69aef86a..cfa2db0ba2 100644 --- a/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py +++ b/python/paddle/fluid/tests/custom_op/test_simple_custom_op_setup.py @@ -29,7 +29,7 @@ def relu2_dynamic(func, device, dtype, np_x, use_func=True): t = paddle.to_tensor(np_x) t.stop_gradient = False - out = func(t) if use_func else paddle.nn.functional.relu(t) + out = func(t)[0] if use_func else paddle.nn.functional.relu(t) out.stop_gradient = False out.backward() @@ -45,17 +45,18 @@ def relu2_static(func, device, dtype, np_x, use_func=True): with static.program_guard(static.Program()): x = static.data(name='X', shape=[None, 8], dtype=dtype) x.stop_gradient = False - out = func(x) if use_func else paddle.nn.functional.relu(x) + # out, fake_float64, fake_int32 + out = func(x)[0] if use_func else paddle.nn.functional.relu(x) static.append_backward(out) exe = static.Executor() exe.run(static.default_startup_program()) - # in static mode, x data has been covered by out out_v = exe.run(static.default_main_program(), feed={'X': np_x}, fetch_list=[out.name]) + paddle.disable_static() return out_v @@ -68,7 +69,7 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True): with static.program_guard(static.Program()): x = static.data(name='X', shape=[None, 8], dtype=dtype) x.stop_gradient = False - out = func(x) if use_func else paddle.nn.functional.relu(x) + out = func(x)[0] if use_func else paddle.nn.functional.relu(x) static.append_backward(out) exe = static.Executor() @@ -82,6 +83,7 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True): feed={'X': np_x}, fetch_list=[out.name]) + paddle.disable_static() return out_v diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index ea855c7e2c..6f784730c9 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -402,12 +402,9 @@ def parse_op_info(op_name): op_proto = OpProtoHolder.instance().get_op_proto(op_name) in_names = [x.name for x in op_proto.inputs] - assert len(op_proto.outputs) == 1 - out_name = op_proto.outputs[0].name + out_names = [x.name for x in op_proto.outputs] - # TODO(Aurelius84): parse necessary out_dtype of custom op - out_infos = {out_name: ['float32']} - return in_names, out_infos + return in_names, out_names def _import_module_from_library(module_name, build_directory, verbose=False): @@ -450,13 +447,10 @@ def _generate_python_module(module_name, def _custom_api_content(op_name): - params_str, ins_str = _get_api_inputs_str(op_name) + params_str, ins_str, outs_str = _get_api_inputs_str(op_name) API_TEMPLATE = textwrap.dedent(""" from paddle.fluid.layer_helper import LayerHelper - from paddle.utils.cpp_extension import parse_op_info - - _, _out_infos = parse_op_info('{op_name}') def {op_name}({inputs}): helper = LayerHelper("{op_name}", **locals()) @@ -464,21 +458,22 @@ def _custom_api_content(op_name): # prepare inputs and output ins = {ins} outs = {{}} - for out_name in _out_infos: - outs[out_name] = [helper.create_variable(dtype=dtype) for dtype in _out_infos[out_name]] + out_names = {out_names} + for out_name in out_names: + # Set 'float32' temporarily, and the actual dtype of output variable will be inferred + # in runtime. + outs[out_name] = helper.create_variable(dtype='float32') helper.append_op(type="{op_name}", inputs=ins, outputs=outs) - res = list(outs.values())[0] - if len(res) == 1: - return res[0] - else: - return res + res = [outs[out_name] for out_name in out_names] + + return res[0] if len(res)==1 else res """).lstrip() # generate python api file api_content = API_TEMPLATE.format( - op_name=op_name, inputs=params_str, ins=ins_str) + op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str) return api_content @@ -509,13 +504,15 @@ def _get_api_inputs_str(op_name): """ Returns string of api parameters and inputs dict. """ - in_names, _ = parse_op_info(op_name) + in_names, out_names = parse_op_info(op_name) # e.g: x, y, z params_str = ','.join([p.lower() for p in in_names]) # e.g: {'X': x, 'Y': y, 'Z': z} ins_str = "{%s}" % ','.join( ["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names]) - return params_str, ins_str + # e.g: ['Out', 'Index'] + outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names]) + return params_str, ins_str, outs_str def _write_setup_file(name, -- GitLab