Unverified commit f2dc29a9, authored by Aurelius84 and committed by GitHub

[CustomOp] Support output dtypes in generated Python API (#31045)

Parent: 615d8a22
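In short, the generated Python API for a custom operator no longer assumes a single float32 output: one output variable is created per registered output name, and the actual dtypes are inferred at runtime. A minimal dynamic-mode usage sketch, assuming the compiled extension has been imported as custom_module (the build step itself is not part of this diff):

    import numpy as np
    import paddle

    # custom_module is assumed to be the compiled extension built from the
    # relu2 sources shown in this commit (e.g. via paddle.utils.cpp_extension).
    x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))

    # relu2 registers three outputs, so the generated API returns a list of Tensors.
    out, fake_float64, zfake_int32 = custom_module.relu2(x)

    assert 'float64' in str(fake_float64.numpy().dtype)  # filled with 0.0
    assert 'int32' in str(zfake_int32.numpy().dtype)     # filled with 1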
@@ -33,7 +33,7 @@ std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
 // to test jointly compile multi operators at same time.
 PD_BUILD_OP("relu3")
     .Inputs({"X"})
-    .Outputs({"Out"})
+    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
     .SetKernelFn(PD_KERNEL(ReluForward))
     .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
......
@@ -17,6 +17,13 @@
 #include "paddle/extension.h"

+template <typename data_t>
+void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) {
+  for (int i = 0; i < x_numel; ++i) {
+    out_data[i] = value;
+  }
+}
+
 template <typename data_t>
 void relu_cpu_forward_kernel(const data_t* x_data,
                              data_t* out_data,
@@ -46,8 +53,21 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
         relu_cpu_forward_kernel<data_t>(
             x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
       }));
+  // fake multi output: Fake_float64 with float64 dtype
+  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU);
+  fake_float64.reshape(x.shape());
+  fill_constant_cpu_kernel<double>(
+      fake_float64.mutable_data<double>(x.place()), x.size(), 0.);
+
+  // fake multi output: ZFake_int32 with int32 dtype
+  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU);
+  zfake_int32.reshape(x.shape());
+  fill_constant_cpu_kernel<int32_t>(
+      zfake_int32.mutable_data<int32_t>(x.place()), x.size(), 1);

-  return {out};
+  return {out, fake_float64, zfake_int32};
 }

 std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
@@ -97,16 +117,16 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
 }

 std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
-  return {x_shape};
+  return {x_shape, x_shape, x_shape};
 }

 std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
-  return {x_dtype};
+  return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
 }

 PD_BUILD_OP("relu2")
     .Inputs({"X"})
-    .Outputs({"Out"})
+    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
     .SetKernelFn(PD_KERNEL(ReluForward))
     .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
......
@@ -14,6 +14,16 @@
 #include "paddle/extension.h"

+template <typename data_t>
+__global__ void fill_constant_cuda_kernel(data_t* y,
+                                          const int num,
+                                          data_t value) {
+  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
+    y[i] = value;
+  }
+}
+
 template <typename data_t>
 __global__ void relu_cuda_forward_kernel(const data_t* x,
                                          data_t* y,
@@ -47,8 +57,18 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
         relu_cuda_forward_kernel<data_t><<<grid, block>>>(
             x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
       }));
+  // fake multi output: Fake_float64 with float64 dtype
+  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kGPU);
+  fake_float64.reshape(x.shape());
+  fill_constant_cuda_kernel<double><<<grid, block>>>(
+      fake_float64.mutable_data<double>(x.place()), numel, 0.);
+
+  // fake multi output: ZFake_int32 with int32 dtype
+  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kGPU);
+  zfake_int32.reshape(x.shape());
+  fill_constant_cuda_kernel<int32_t><<<grid, block>>>(
+      zfake_int32.mutable_data<int32_t>(x.place()), numel, 1);

-  return {out};
+  return {out, fake_float64, zfake_int32};
 }

 std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
......
@@ -64,5 +64,62 @@ class TestJITLoad(unittest.TestCase):
                         x_grad, pd_x_grad))


+class TestMultiOutputDtypes(unittest.TestCase):
+    def setUp(self):
+        self.custom_op = custom_module.relu2
+        self.dtypes = ['float32', 'float64']
+        self.devices = ['cpu', 'gpu']
+
+    def test_static(self):
+        paddle.enable_static()
+        for device in self.devices:
+            for dtype in self.dtypes:
+                res = self.run_static(device, dtype)
+                self.check_multi_outputs(res)
+        paddle.disable_static()
+
+    def test_dynamic(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                paddle.set_device(device)
+                x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                x = paddle.to_tensor(x_data)
+                outs = self.custom_op(x)
+
+                self.assertTrue(len(outs) == 3)
+                self.check_multi_outputs(outs, True)
+
+    def check_multi_outputs(self, outs, is_dynamic=False):
+        out, zero_float64, one_int32 = outs
+        if is_dynamic:
+            zero_float64 = zero_float64.numpy()
+            one_int32 = one_int32.numpy()
+        # Fake_float64
+        self.assertTrue('float64' in str(zero_float64.dtype))
+        self.assertTrue(
+            np.array_equal(zero_float64, np.zeros([4, 8]).astype('float64')))
+        # ZFake_int32
+        self.assertTrue('int32' in str(one_int32.dtype))
+        self.assertTrue(
+            np.array_equal(one_int32, np.ones([4, 8]).astype('int32')))
+
+    def run_static(self, device, dtype):
+        paddle.set_device(device)
+        x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+
+        with paddle.static.scope_guard(paddle.static.Scope()):
+            with paddle.static.program_guard(paddle.static.Program()):
+                x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype)
+                outs = self.custom_op(x)
+
+                exe = paddle.static.Executor()
+                exe.run(paddle.static.default_startup_program())
+                res = exe.run(paddle.static.default_main_program(),
+                              feed={'X': x_data},
+                              fetch_list=outs)
+
+                return res
+
+
 if __name__ == '__main__':
     unittest.main()
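Because relu2 now returns a list of three Tensors, call sites that only need the real output keep index 0; the next diff applies this pattern to the existing relu2_* test helpers. A minimal sketch of the pattern, with custom_relu standing in for the compiled op:

    import paddle.nn.functional as F

    def relu_out_only(custom_relu, t, use_func=True):
        # The custom op returns [Out, Fake_float64, ZFake_int32]; keep only Out
        # so downstream code (backward pass, numpy comparison) stays unchanged.
        return custom_relu(t)[0] if use_func else F.relu(t)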
@@ -29,7 +29,7 @@ def relu2_dynamic(func, device, dtype, np_x, use_func=True):

     t = paddle.to_tensor(np_x)
     t.stop_gradient = False

-    out = func(t) if use_func else paddle.nn.functional.relu(t)
+    out = func(t)[0] if use_func else paddle.nn.functional.relu(t)
     out.stop_gradient = False

     out.backward()
@@ -45,17 +45,18 @@ def relu2_static(func, device, dtype, np_x, use_func=True):
         with static.program_guard(static.Program()):
             x = static.data(name='X', shape=[None, 8], dtype=dtype)
             x.stop_gradient = False
-            out = func(x) if use_func else paddle.nn.functional.relu(x)
+            # out, fake_float64, fake_int32
+            out = func(x)[0] if use_func else paddle.nn.functional.relu(x)
             static.append_backward(out)

             exe = static.Executor()
             exe.run(static.default_startup_program())
             # in static mode, x data has been covered by out
             out_v = exe.run(static.default_main_program(),
                             feed={'X': np_x},
                             fetch_list=[out.name])

+    paddle.disable_static()
     return out_v
@@ -68,7 +69,7 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True):
         with static.program_guard(static.Program()):
             x = static.data(name='X', shape=[None, 8], dtype=dtype)
             x.stop_gradient = False
-            out = func(x) if use_func else paddle.nn.functional.relu(x)
+            out = func(x)[0] if use_func else paddle.nn.functional.relu(x)
             static.append_backward(out)

             exe = static.Executor()
@@ -82,6 +83,7 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True):
                             feed={'X': np_x},
                             fetch_list=[out.name])

+    paddle.disable_static()
     return out_v
......
@@ -402,12 +402,9 @@ def parse_op_info(op_name):
     op_proto = OpProtoHolder.instance().get_op_proto(op_name)

     in_names = [x.name for x in op_proto.inputs]
-    assert len(op_proto.outputs) == 1
-    out_name = op_proto.outputs[0].name
-    # TODO(Aurelius84): parse necessary out_dtype of custom op
-    out_infos = {out_name: ['float32']}
-    return in_names, out_infos
+    out_names = [x.name for x in op_proto.outputs]
+
+    return in_names, out_names


 def _import_module_from_library(module_name, build_directory, verbose=False):
@@ -450,13 +447,10 @@

 def _custom_api_content(op_name):
-    params_str, ins_str = _get_api_inputs_str(op_name)
+    params_str, ins_str, outs_str = _get_api_inputs_str(op_name)

     API_TEMPLATE = textwrap.dedent("""
         from paddle.fluid.layer_helper import LayerHelper
-        from paddle.utils.cpp_extension import parse_op_info
-
-        _, _out_infos = parse_op_info('{op_name}')

         def {op_name}({inputs}):
             helper = LayerHelper("{op_name}", **locals())
@@ -464,21 +458,22 @@ def _custom_api_content(op_name):
             # prepare inputs and output
             ins = {ins}
             outs = {{}}
-            for out_name in _out_infos:
-                outs[out_name] = [helper.create_variable(dtype=dtype) for dtype in _out_infos[out_name]]
+            out_names = {out_names}
+            for out_name in out_names:
+                # Set 'float32' temporarily, and the actual dtype of output variable will be inferred
+                # in runtime.
+                outs[out_name] = helper.create_variable(dtype='float32')

             helper.append_op(type="{op_name}", inputs=ins, outputs=outs)

-            res = list(outs.values())[0]
-            if len(res) == 1:
-                return res[0]
-            else:
-                return res
+            res = [outs[out_name] for out_name in out_names]
+
+            return res[0] if len(res)==1 else res
         """).lstrip()

     # generate python api file
     api_content = API_TEMPLATE.format(
-        op_name=op_name, inputs=params_str, ins=ins_str)
+        op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str)

     return api_content
@@ -509,13 +504,15 @@ def _get_api_inputs_str(op_name):
     """
     Returns string of api parameters and inputs dict.
     """
-    in_names, _ = parse_op_info(op_name)
+    in_names, out_names = parse_op_info(op_name)
     # e.g: x, y, z
     params_str = ','.join([p.lower() for p in in_names])
     # e.g: {'X': x, 'Y': y, 'Z': z}
     ins_str = "{%s}" % ','.join(
         ["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names])
-    return params_str, ins_str
+    # e.g: ['Out', 'Index']
+    outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names])
+    return params_str, ins_str, outs_str


 def _write_setup_file(name,
......
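For reference, a hand-formatted sketch of the module that _custom_api_content would now generate for the relu2 op above. For this op, _get_api_inputs_str yields params_str = 'x', ins_str = "{'X' : x}", and outs_str = "['Out','Fake_float64','ZFake_int32']", so the rendered API looks roughly like this (names and whitespace are illustrative, not copied from a real build):

    from paddle.fluid.layer_helper import LayerHelper

    def relu2(x):
        helper = LayerHelper("relu2", **locals())

        # prepare inputs and output
        ins = {'X' : x}
        outs = {}
        out_names = ['Out', 'Fake_float64', 'ZFake_int32']
        for out_name in out_names:
            # Set 'float32' temporarily; the actual dtype of the output variable
            # will be inferred at runtime.
            outs[out_name] = helper.create_variable(dtype='float32')

        helper.append_op(type="relu2", inputs=ins, outputs=outs)

        res = [outs[out_name] for out_name in out_names]

        return res[0] if len(res) == 1 else res

Each output variable is created with a placeholder float32 dtype; the real dtypes (float64 and int32 for the fake outputs here) come from the op's registered InferDtype function when the op actually runs.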