Unverified commit f2dc29a9, authored by Aurelius84, committed by GitHub

[CustomOp] Support output dtypes in generated Python API (#31045)

Parent 615d8a22
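
In short, after this change the Python API generated for a custom op returns one variable per registered output, and the output dtypes come from the op's InferDtypeFn instead of a hard-coded 'float32'. Below is a minimal usage sketch (not part of this diff) for the relu2 op defined in the sources that follow; the module name and source paths are illustrative only.

import numpy as np
import paddle
from paddle.utils.cpp_extension import load

# jit-compile the custom op sources shown in this diff (paths are hypothetical)
custom_module = load(name='custom_relu2', sources=['relu_op.cc', 'relu_op.cu'])

paddle.set_device('cpu')
x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))

# relu2 is registered with Outputs({"Out", "Fake_float64", "ZFake_int32"}),
# so the generated API now returns one tensor per output, with dtypes taken
# from ReluInferDType rather than defaulting to float32.
out, fake_float64, zfake_int32 = custom_module.relu2(x)
assert 'float64' in str(fake_float64.dtype)  # zeros, float64
assert 'int32' in str(zfake_int32.dtype)     # ones, int32
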
......@@ -33,7 +33,7 @@ std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype);
// to test jointly compile multi operators at same time.
PD_BUILD_OP("relu3")
.Inputs({"X"})
.Outputs({"Out"})
.Outputs({"Out", "Fake_float64", "ZFake_int32"})
.SetKernelFn(PD_KERNEL(ReluForward))
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
......
......@@ -17,6 +17,13 @@
#include "paddle/extension.h"
template <typename data_t>
void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) {
for (int i = 0; i < x_numel; ++i) {
out_data[i] = value;
}
}
template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
......@@ -46,8 +53,21 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
}));
// fake multi output: Fake_float64 with float64 dtype
auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU);
fake_float64.reshape(x.shape());
fill_constant_cpu_kernel<double>(
fake_float64.mutable_data<double>(x.place()), x.size(), 0.);
// fake multi output: ZFake_int32 with int32 dtype
auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU);
zfake_int32.reshape(x.shape());
fill_constant_cpu_kernel<int32_t>(
zfake_int32.mutable_data<int32_t>(x.place()), x.size(), 1);
return {out};
return {out, fake_float64, zfake_int32};
}
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
......@@ -97,16 +117,16 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
}
std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
return {x_shape};
return {x_shape, x_shape, x_shape};
}
std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
return {x_dtype};
return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
}
PD_BUILD_OP("relu2")
.Inputs({"X"})
.Outputs({"Out"})
.Outputs({"Out", "Fake_float64", "ZFake_int32"})
.SetKernelFn(PD_KERNEL(ReluForward))
.SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
......
......@@ -14,6 +14,16 @@
#include "paddle/extension.h"
template <typename data_t>
__global__ void fill_constant_cuda_kernel(data_t* y,
const int num,
data_t value) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = value;
}
}
template <typename data_t>
__global__ void relu_cuda_forward_kernel(const data_t* x,
data_t* y,
......@@ -47,8 +57,18 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
relu_cuda_forward_kernel<data_t><<<grid, block>>>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
}));
// fake multi output: Fake_float64 with float64 dtype
auto fake_float64 = paddle::Tensor(paddle::PlaceType::kGPU);
fake_float64.reshape(x.shape());
fill_constant_cuda_kernel<double><<<grid, block>>>(
fake_float64.mutable_data<double>(x.place()), numel, 0.);
// fake multi output: ZFake_int32 with int32 dtype
auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kGPU);
zfake_int32.reshape(x.shape());
fill_constant_cuda_kernel<int32_t><<<grid, block>>>(
zfake_int32.mutable_data<int32_t>(x.place()), numel, 1);
return {out};
return {out, fake_float64, zfake_int32};
}
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
......
......@@ -64,5 +64,62 @@ class TestJITLoad(unittest.TestCase):
x_grad, pd_x_grad))
class TestMultiOutputDtypes(unittest.TestCase):
def setUp(self):
self.custom_op = custom_module.relu2
self.dtypes = ['float32', 'float64']
self.devices = ['cpu', 'gpu']
def test_static(self):
paddle.enable_static()
for device in self.devices:
for dtype in self.dtypes:
res = self.run_static(device, dtype)
self.check_multi_outputs(res)
paddle.disable_static()
def test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
paddle.set_device(device)
x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
x = paddle.to_tensor(x_data)
outs = self.custom_op(x)
self.assertTrue(len(outs) == 3)
self.check_multi_outputs(outs, True)
def check_multi_outputs(self, outs, is_dynamic=False):
out, zero_float64, one_int32 = outs
if is_dynamic:
zero_float64 = zero_float64.numpy()
one_int32 = one_int32.numpy()
# Fake_float64
self.assertTrue('float64' in str(zero_float64.dtype))
self.assertTrue(
np.array_equal(zero_float64, np.zeros([4, 8]).astype('float64')))
# ZFake_int32
self.assertTrue('int32' in str(one_int32.dtype))
self.assertTrue(
np.array_equal(one_int32, np.ones([4, 8]).astype('int32')))
def run_static(self, device, dtype):
paddle.set_device(device)
x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
with paddle.static.scope_guard(paddle.static.Scope()):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype)
outs = self.custom_op(x)
exe = paddle.static.Executor()
exe.run(paddle.static.default_startup_program())
res = exe.run(paddle.static.default_main_program(),
feed={'X': x_data},
fetch_list=outs)
return res
if __name__ == '__main__':
unittest.main()
......@@ -29,7 +29,7 @@ def relu2_dynamic(func, device, dtype, np_x, use_func=True):
t = paddle.to_tensor(np_x)
t.stop_gradient = False
out = func(t) if use_func else paddle.nn.functional.relu(t)
out = func(t)[0] if use_func else paddle.nn.functional.relu(t)
out.stop_gradient = False
out.backward()
......@@ -45,17 +45,18 @@ def relu2_static(func, device, dtype, np_x, use_func=True):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
out = func(x) if use_func else paddle.nn.functional.relu(x)
# out, fake_float64, zfake_int32
out = func(x)[0] if use_func else paddle.nn.functional.relu(x)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static mode, x data has been covered by out
out_v = exe.run(static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name])
paddle.disable_static()
return out_v
......@@ -68,7 +69,7 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
x.stop_gradient = False
out = func(x) if use_func else paddle.nn.functional.relu(x)
out = func(x)[0] if use_func else paddle.nn.functional.relu(x)
static.append_backward(out)
exe = static.Executor()
......@@ -82,6 +83,7 @@ def relu2_static_pe(func, device, dtype, np_x, use_func=True):
feed={'X': np_x},
fetch_list=[out.name])
paddle.disable_static()
return out_v
......
......@@ -402,12 +402,9 @@ def parse_op_info(op_name):
op_proto = OpProtoHolder.instance().get_op_proto(op_name)
in_names = [x.name for x in op_proto.inputs]
assert len(op_proto.outputs) == 1
out_name = op_proto.outputs[0].name
out_names = [x.name for x in op_proto.outputs]
# TODO(Aurelius84): parse necessary out_dtype of custom op
out_infos = {out_name: ['float32']}
return in_names, out_infos
return in_names, out_names
def _import_module_from_library(module_name, build_directory, verbose=False):
......@@ -450,13 +447,10 @@ def _generate_python_module(module_name,
def _custom_api_content(op_name):
params_str, ins_str = _get_api_inputs_str(op_name)
params_str, ins_str, outs_str = _get_api_inputs_str(op_name)
API_TEMPLATE = textwrap.dedent("""
from paddle.fluid.layer_helper import LayerHelper
from paddle.utils.cpp_extension import parse_op_info
_, _out_infos = parse_op_info('{op_name}')
def {op_name}({inputs}):
helper = LayerHelper("{op_name}", **locals())
......@@ -464,21 +458,22 @@ def _custom_api_content(op_name):
# prepare inputs and output
ins = {ins}
outs = {{}}
for out_name in _out_infos:
outs[out_name] = [helper.create_variable(dtype=dtype) for dtype in _out_infos[out_name]]
out_names = {out_names}
for out_name in out_names:
# Set 'float32' temporarily; the actual dtype of the output variable will be
# inferred at runtime.
outs[out_name] = helper.create_variable(dtype='float32')
helper.append_op(type="{op_name}", inputs=ins, outputs=outs)
res = list(outs.values())[0]
if len(res) == 1:
return res[0]
else:
return res
res = [outs[out_name] for out_name in out_names]
return res[0] if len(res)==1 else res
""").lstrip()
# generate python api file
api_content = API_TEMPLATE.format(
op_name=op_name, inputs=params_str, ins=ins_str)
op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str)
return api_content
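
For reference, a hand-expanded sketch of what this template would generate for the relu2 op registered above (lines elided from the diff are omitted, and the real generated file may differ in details such as output ordering):

from paddle.fluid.layer_helper import LayerHelper

def relu2(x):
    helper = LayerHelper("relu2", **locals())

    # prepare inputs and outputs
    ins = {'X': x}
    outs = {}
    out_names = ['Out', 'Fake_float64', 'ZFake_int32']
    for out_name in out_names:
        # 'float32' is only a placeholder; the real dtype of each output is
        # inferred at runtime via the op's InferDtypeFn (ReluInferDType here).
        outs[out_name] = helper.create_variable(dtype='float32')

    helper.append_op(type="relu2", inputs=ins, outputs=outs)

    res = [outs[out_name] for out_name in out_names]
    # a single-output op returns the variable itself, otherwise a list
    return res[0] if len(res) == 1 else res
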
......@@ -509,13 +504,15 @@ def _get_api_inputs_str(op_name):
"""
Returns string of api parameters and inputs dict.
"""
in_names, _ = parse_op_info(op_name)
in_names, out_names = parse_op_info(op_name)
# e.g: x, y, z
params_str = ','.join([p.lower() for p in in_names])
# e.g: {'X': x, 'Y': y, 'Z': z}
ins_str = "{%s}" % ','.join(
["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names])
return params_str, ins_str
# e.g: ['Out', 'Index']
outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names])
return params_str, ins_str, outs_str
def _write_setup_file(name,
......