diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index a90986815ed8a8d2de31e36f2f51259efab5ee1a..633d48bf7345a9249b50cf94350de365eab42e86 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -669,6 +669,7 @@ static void RunInferDtypeFunc(
     const paddle::InferDtypeFunc& func,
     const std::vector<std::string>& inputs,
     const std::vector<std::string>& outputs,
+    const std::vector<std::string>& attrs,
     const std::unordered_map<std::string, std::string>& inplace_map,
     const std::unordered_map<std::string, std::string>& inplace_reverse_map) {
   std::vector<DataType> input_dtypes;
@@ -711,8 +712,51 @@ static void RunInferDtypeFunc(
     }
   }
 
+  std::vector<paddle::any> custom_attrs;
+  for (auto& attr_str : attrs) {
+    auto attr_name_and_type = paddle::ParseAttrStr(attr_str);
+    auto attr_name = attr_name_and_type[0];
+    auto attr_type_str = attr_name_and_type[1];
+    if (attr_type_str == "bool") {
+      custom_attrs.emplace_back(
+          PADDLE_GET_CONST(bool, ctx->GetAttr(attr_name)));
+    } else if (attr_type_str == "int") {
+      custom_attrs.emplace_back(PADDLE_GET_CONST(int, ctx->GetAttr(attr_name)));
+    } else if (attr_type_str == "float") {
+      custom_attrs.emplace_back(
+          PADDLE_GET_CONST(float, ctx->GetAttr(attr_name)));
+    } else if (attr_type_str == "int64_t") {
+      custom_attrs.emplace_back(
+          PADDLE_GET_CONST(int64_t, ctx->GetAttr(attr_name)));
+    } else if (attr_type_str == "std::string") {
+      custom_attrs.emplace_back(
+          PADDLE_GET_CONST(std::string, ctx->GetAttr(attr_name)));
+    } else if (attr_type_str == "std::vector<int>") {
+      custom_attrs.emplace_back(
+          PADDLE_GET_CONST(std::vector<int>, ctx->GetAttr(attr_name)));
+    } else if (attr_type_str == "std::vector<float>") {
+      custom_attrs.emplace_back(
+          PADDLE_GET_CONST(std::vector<float>, ctx->GetAttr(attr_name)));
+    } else if (attr_type_str == "std::vector<int64_t>") {
+      custom_attrs.emplace_back(
+          PADDLE_GET_CONST(std::vector<int64_t>, ctx->GetAttr(attr_name)));
+    } else if (attr_type_str == "std::vector<std::string>") {
+      custom_attrs.emplace_back(
+          PADDLE_GET_CONST(std::vector<std::string>, ctx->GetAttr(attr_name)));
+    } else {
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Unsupported `%s` type value as custom attribute now. "
+          "Supported data types include `bool`, `int`, `float`, "
+          "`int64_t`, `std::string`, `std::vector<int>`, "
+          "`std::vector<float>`, `std::vector<int64_t>`, "
+          "`std::vector<std::string>`, Please check whether the attribute data "
+          "type and data type string are matched.",
+          attr_type_str));
+    }
+  }
+
   VLOG(3) << "Custom Operator: InferDtype - infer output dtype.";
-  auto output_dtypes = func(input_dtypes, vec_input_dtypes);
+  auto output_dtypes = func(input_dtypes, vec_input_dtypes, custom_attrs);
   if (inplace_map.empty()) {
     PADDLE_ENFORCE_EQ(outputs.size(),
                       output_dtypes.size(),
@@ -1016,6 +1060,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
   } else {
     info.infer_var_type_ = [op_inputs,
                             op_outputs,
+                            op_attrs,
                             op_inplace_map,
                             op_inplace_reverse_map,
                             infer_dtype_func](InferVarTypeContext* ctx) {
@@ -1023,6 +1068,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
                        infer_dtype_func,
                        op_inputs,
                        op_outputs,
+                       op_attrs,
                        op_inplace_map,
                        op_inplace_reverse_map);
     };
@@ -1051,6 +1097,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
       OpMetaInfoHelper::GetInplaceReverseMap(cur_grad_op);
   auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
   auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op);
+  auto& grad_infer_dtype_fn = OpMetaInfoHelper::GetInferDtypeFn(cur_grad_op);
 
   VLOG(3) << "Custom Operator: backward, op name: " << grad_op_name;
   VLOG(3) << "Custom Operator: backward, op inputs: "
@@ -1182,6 +1229,25 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
     };
   }
 
+  // Grad InferDtype
+  if (grad_infer_dtype_fn != nullptr) {
+    grad_info.infer_var_type_ =
+        [grad_op_inputs,
+         grad_op_outputs,
+         grad_op_attrs,
+         grad_op_inplace_map,
+         grad_op_inplace_reverse_map,
+         grad_infer_dtype_fn](InferVarTypeContext* ctx) {
+          RunInferDtypeFunc(ctx,
+                            grad_infer_dtype_fn,
+                            grad_op_inputs,
+                            grad_op_outputs,
+                            grad_op_attrs,
+                            grad_op_inplace_map,
+                            grad_op_inplace_reverse_map);
+        };
+  }
+
   // Kernel func
   RegisterOperatorKernel(grad_op_name,
                          grad_kernel_fn,
diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h
index 73a784a6eb987dd3709a219cd451a16206bd2353..c774cafcfd26a8a3e4e18319da727b9c8a595ea9 100644
--- a/paddle/phi/api/ext/op_meta_info.h
+++ b/paddle/phi/api/ext/op_meta_info.h
@@ -643,38 +643,74 @@ struct InferShapeFuncImpl {
 // Record Op Infer dtype core function
 using InferDtypeFunc = std::vector<DataType> (*)(
     const std::vector<DataType>& input_dtypes,
-    const std::vector<std::vector<DataType>>& vec_input_dtypes);
+    const std::vector<std::vector<DataType>>& vec_input_dtypes,
+    const std::vector<paddle::any>& attrs);
+
+#define PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPE(input_type) \
+  template <typename... Tail> \
+  struct InferDtypeCallHelper<input_type, Tail...> { \
+    template <int in_idx, int vec_in_idx, int attr_idx, typename... PreviousArgs> \
+    static Return InferDtype( \
+        const std::vector<DataType>& input_dtypes, \
+        const std::vector<std::vector<DataType>>& vec_input_dtypes, \
+        const std::vector<paddle::any>& attrs, \
+        const PreviousArgs&... pargs) { \
+      input_type arg = input_dtypes[in_idx]; \
+      return InferDtypeCallHelper<Tail...>:: \
+          template InferDtype<in_idx + 1, vec_in_idx, attr_idx>( \
+              input_dtypes, vec_input_dtypes, attrs, pargs..., arg); \
+    } \
+  }
 
-#define PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(input_type) \
+#define PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(input_type) \
+  template <typename... Tail> \
+  struct InferDtypeCallHelper<input_type, Tail...> { \
+    template <int in_idx, int vec_in_idx, int attr_idx, typename... PreviousArgs> \
+    static Return InferDtype( \
+        const std::vector<DataType>& input_dtypes, \
+        const std::vector<std::vector<DataType>>& vec_input_dtypes, \
+        const std::vector<paddle::any>& attrs, \
+        const PreviousArgs&... pargs) { \
+      input_type arg = vec_input_dtypes[vec_in_idx]; \
+      return InferDtypeCallHelper<Tail...>:: \
+          template InferDtype<in_idx, vec_in_idx + 1, attr_idx>( \
+              input_dtypes, vec_input_dtypes, attrs, pargs..., arg); \
+    } \
+  }
+
+#define PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(attr_type) \
   template <typename... Tail> \
-  struct InferDtypeCallHelper<input_type, Tail...> { \
-    template <int in_idx, int vec_in_idx, typename... PreviousArgs> \
+  struct InferDtypeCallHelper<attr_type, Tail...> { \
+    template <int in_idx, int vec_in_idx, int attr_idx, typename... PreviousArgs> \
     static Return InferDtype( \
         const std::vector<DataType>& input_dtypes, \
         const std::vector<std::vector<DataType>>& vec_input_dtypes, \
+        const std::vector<paddle::any>& attrs, \
         const PreviousArgs&... pargs) { \
-      input_type arg = input_dtypes[in_idx]; \
-      return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1, vec_in_idx>( \
-          input_dtypes, vec_input_dtypes, pargs..., arg); \
+      try { \
+        attr_type arg = paddle::any_cast<attr_type>(attrs[attr_idx]); \
+        return InferDtypeCallHelper<Tail...>:: \
+            template InferDtype<in_idx, vec_in_idx, attr_idx + 1>( \
+                input_dtypes, vec_input_dtypes, attrs, pargs..., arg); \
+      } catch (paddle::bad_any_cast&) { \
+        PD_THROW( \
+            "Attribute cast error in custom operator InferDtypeFn. " \
+            "Expected " #attr_type \
+            " value. InferDtypeFn's attribute list must be exactly same as " \
+            "Forward KernelFn's attribute list"); \
+      } \
     } \
   }
 
-#define PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(input_type) \
-  template <typename... Tail> \
-  struct InferDtypeCallHelper<input_type, Tail...> { \
-    template <int in_idx, int vec_in_idx, typename... PreviousArgs> \
-    static Return InferDtype( \
-        const std::vector<DataType>& input_dtypes, \
-        const std::vector<std::vector<DataType>>& vec_input_dtypes, \
-        const PreviousArgs&... pargs) { \
-      input_type arg = vec_input_dtypes[vec_in_idx]; \
-      return InferDtypeCallHelper<Tail...>:: \
-          template InferDtype<in_idx, vec_in_idx + 1>( \
-              input_dtypes, vec_input_dtypes, pargs..., arg); \
-    } \
-  }
-
 template <typename F, F f>
 struct InferDtypeFuncImpl;
 
@@ -682,35 +718,39 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
 struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
   static Return InferDtype(
       const std::vector<DataType>& input_dtypes,
-      const std::vector<std::vector<DataType>>& vec_input_dtypes) {
-    return InferDtypeCallHelper<Args..., TypeTag<int>>::template InferDtype<0,
-                                                                            0>(
-        input_dtypes, vec_input_dtypes);
+      const std::vector<std::vector<DataType>>& vec_input_dtypes,
+      const std::vector<paddle::any>& attrs) {
+    return InferDtypeCallHelper<Args..., TypeTag<int>>::
+        template InferDtype<0, 0, 0>(input_dtypes, vec_input_dtypes, attrs);
   }
 
  private:
  template <typename... RemainingArgs>
  struct InferDtypeCallHelper;
 
-  PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(const DataType&);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPE(const DataType&);
   PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(const std::vector<DataType>&);
 
   template <typename... Tail>
   struct InferDtypeCallHelper<const paddle::optional<DataType>&, Tail...> {
-    template <int in_idx, int vec_in_idx, typename... PreviousArgs>
+    template <int in_idx, int vec_in_idx, int attr_idx, typename... PreviousArgs>
     static Return InferDtype(
         const std::vector<DataType>& input_dtypes,
         const std::vector<std::vector<DataType>>& vec_input_dtypes,
+        const std::vector<paddle::any>& attrs,
         const PreviousArgs&... pargs) {
       const DataType& arg = input_dtypes[in_idx];
       if (arg == DataType::UNDEFINED) {
-        return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1, vec_in_idx>(
-            input_dtypes, vec_input_dtypes, pargs..., paddle::none);
+        return InferDtypeCallHelper<Tail...>::
+            template InferDtype<in_idx + 1, vec_in_idx, attr_idx>(
+                input_dtypes, vec_input_dtypes, attrs, pargs..., paddle::none);
       } else {
-        return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1, vec_in_idx>(
-            input_dtypes, vec_input_dtypes, pargs..., arg);
+        return InferDtypeCallHelper<Tail...>::
+            template InferDtype<in_idx + 1, vec_in_idx, attr_idx>(
+                input_dtypes, vec_input_dtypes, attrs, pargs..., arg);
       }
     }
   };
@@ -718,36 +758,65 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
   template <typename... Tail>
   struct InferDtypeCallHelper<const paddle::optional<std::vector<DataType>>&, Tail...> {
-    template <int in_idx, int vec_in_idx, typename... PreviousArgs>
+    template <int in_idx, int vec_in_idx, int attr_idx, typename... PreviousArgs>
     static Return InferDtype(
         const std::vector<DataType>& input_dtypes,
         const std::vector<std::vector<DataType>>& vec_input_dtypes,
+        const std::vector<paddle::any>& attrs,
         const PreviousArgs&... pargs) {
       const std::vector<DataType>& arg = vec_input_dtypes[vec_in_idx];
       if (arg.empty()) {
         return InferDtypeCallHelper<Tail...>::
-            template InferDtype<in_idx, vec_in_idx + 1>(
-                input_dtypes, vec_input_dtypes, pargs..., paddle::none);
+            template InferDtype<in_idx, vec_in_idx + 1, attr_idx>(
+                input_dtypes, vec_input_dtypes, attrs, pargs..., paddle::none);
       } else {
         return InferDtypeCallHelper<Tail...>::
-            template InferDtype<in_idx, vec_in_idx + 1>(
-                input_dtypes, vec_input_dtypes, pargs..., arg);
+            template InferDtype<in_idx, vec_in_idx + 1, attr_idx>(
+                input_dtypes, vec_input_dtypes, attrs, pargs..., arg);
       }
     }
   };
 
-  // NOTE(chenweihang): Used to be compatible with the 2.0.1 released
+  // NOTE(HongyuJia): Used to be compatible with the 2.0.1 released
   // interface, and will be deprecated in the future
-  PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(DataType);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPE(DataType);
   PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(std::vector<DataType>);
 
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(bool);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(int);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(float);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(int64_t);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::string&);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::vector<int>&);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::vector<float>&);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::vector<int64_t>&);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::vector<std::string>&);
+
+  // NOTE(HongyuJia): Used to be compatible with the 2.0.1 released
+  // interface, and will be deprecated in the future
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const bool&);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const int&);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const float&);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const int64_t&);
+
+  // NOTE(HongyuJia): Used to be compatible with the 2.1 released
+  // interface, but not recommended
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(std::string);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(std::vector<int>);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(std::vector<float>);
+  PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(std::vector<int64_t>);
+
   // end: base template
   template <typename T>
   struct InferDtypeCallHelper<TypeTag<T>> {
-    template <int in_idx, int vec_in_idx>
+    template <int in_idx, int vec_in_idx, int attr_idx>
     static Return InferDtype(
         const std::vector<DataType>& input_dtypes,
         const std::vector<std::vector<DataType>>& vec_input_dtypes,
+        const std::vector<paddle::any>& attrs,
         const Args&... args) {
       return impl_fn(args...);
     }
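The op_meta_info.h changes above thread a third compile-time cursor, attr_idx, through InferDtypeCallHelper so that attribute parameters are drawn from the paddle::any list while tensor dtypes keep coming from input_dtypes and vec_input_dtypes. The following is a deliberately small, standalone C++17 model of that recursion, written for illustration only: it supports just two parameter kinds, uses std::any instead of paddle::any, and the names FuncImpl, Helper and MyInferDtype are invented for the example.

#include <any>
#include <iostream>
#include <string>
#include <vector>

enum class DataType { FLOAT32, FLOAT64 };

template <typename T>
struct TypeTag {};

template <typename F, F f>
struct FuncImpl;

template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct FuncImpl<Return (*)(Args...), impl_fn> {
  static Return Call(const std::vector<DataType>& dtypes,
                     const std::vector<std::any>& attrs) {
    return Helper<Args..., TypeTag<int>>::template Call<0, 0>(dtypes, attrs);
  }

 private:
  template <typename... Remaining>
  struct Helper;

  // A DataType parameter consumes the next entry of `dtypes`.
  template <typename... Tail>
  struct Helper<DataType, Tail...> {
    template <int in_idx, int attr_idx, typename... Prev>
    static Return Call(const std::vector<DataType>& dtypes,
                       const std::vector<std::any>& attrs,
                       const Prev&... prev) {
      DataType arg = dtypes[in_idx];
      return Helper<Tail...>::template Call<in_idx + 1, attr_idx>(
          dtypes, attrs, prev..., arg);
    }
  };

  // A std::string parameter consumes the next entry of `attrs`.
  template <typename... Tail>
  struct Helper<std::string, Tail...> {
    template <int in_idx, int attr_idx, typename... Prev>
    static Return Call(const std::vector<DataType>& dtypes,
                       const std::vector<std::any>& attrs,
                       const Prev&... prev) {
      std::string arg = std::any_cast<std::string>(attrs[attr_idx]);
      return Helper<Tail...>::template Call<in_idx, attr_idx + 1>(
          dtypes, attrs, prev..., arg);
    }
  };

  // End of the parameter list: forward the collected arguments.
  template <typename T>
  struct Helper<TypeTag<T>> {
    template <int in_idx, int attr_idx>
    static Return Call(const std::vector<DataType>&,
                       const std::vector<std::any>&,
                       const Args&... args) {
      return impl_fn(args...);
    }
  };
};

// Hypothetical user InferDtypeFn: output dtype is chosen by a string attribute.
DataType MyInferDtype(DataType x_dtype, std::string target) {
  return target == "float64" ? DataType::FLOAT64 : x_dtype;
}

int main() {
  DataType out = FuncImpl<decltype(&MyInferDtype), &MyInferDtype>::Call(
      {DataType::FLOAT32}, {std::any(std::string("float64"))});
  std::cout << (out == DataType::FLOAT64 ? "float64" : "float32") << std::endl;
  return 0;
}

The real helper additionally handles optional inputs, vectors of dtypes, and the full set of attribute types listed in the macros above.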
diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc
index fcf6cd29f85cb5a5a06515e8b0822225bfe24df9..9e4085fe1cbd720ee13bc1b5b41c56203ddea06a 100644
--- a/paddle/phi/api/lib/op_meta_info.cc
+++ b/paddle/phi/api/lib/op_meta_info.cc
@@ -506,13 +506,6 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
 }
 
 OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
-  PADDLE_ENFORCE_EQ(
-      index_,
-      0UL,
-      phi::errors::Unimplemented(
-          "Currently, the InferDtypeFn setting of Grad Op is not supported, "
-          "And backward Tensor `X@GRAD` will use the dtype of forward Tensor "
-          "`X` by default."));
   info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
   return *this;
 }
diff --git a/test/custom_op/CMakeLists.txt b/test/custom_op/CMakeLists.txt
index fbdc8f9cc653e4e4e9b3f38bf1ec448fadf13e92..631e7016b647d716738b735529414af934d485bf 100644
--- a/test/custom_op/CMakeLists.txt
+++ b/test/custom_op/CMakeLists.txt
@@ -39,6 +39,8 @@ if(WITH_TESTING)
   py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
   py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
   py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
+  py_test(test_custom_cast_op_jit SRCS test_custom_cast_op_jit.py)
+  set_tests_properties(test_custom_cast_op_jit PROPERTIES TIMEOUT 180)
   py_test(test_custom_concat SRCS test_custom_concat.py)
   set_tests_properties(
     test_custom_concat PROPERTIES ENVIRONMENT
diff --git a/test/custom_op/custom_cast_op.cc b/test/custom_op/custom_cast_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9cfef93ea951ce970fa08397dde2048cb05aa44e
--- /dev/null
+++ b/test/custom_op/custom_cast_op.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+#include <vector>
+
+#include "paddle/extension.h"
+
+paddle::DataType ConvertDtype(const std::string& data_type) {
+  if (data_type == "float16") {
+    return paddle::DataType::FLOAT16;
+  } else if (data_type == "float32") {
+    return paddle::DataType::FLOAT32;
+  } else if (data_type == "float64") {
+    return paddle::DataType::FLOAT64;
+  } else {
+    PD_THROW("DataType Not Supported.");
+  }
+}
+
+std::vector<paddle::Tensor> CastForward(const paddle::Tensor& x,
+                                        const std::string& data_type) {
+  return {paddle::experimental::cast(x, ConvertDtype(data_type))};
+}
+
+std::vector<paddle::DataType> CastForwardInferDtype(
+    const paddle::DataType& input_dtype, const std::string& data_type) {
+  return {ConvertDtype(data_type)};
+}
+
+std::vector<paddle::Tensor> CastBackward(const paddle::Tensor& grad_out,
+                                         const std::string& data_type) {
+  return {paddle::experimental::cast(grad_out, ConvertDtype(data_type))};
+}
+
+std::vector<paddle::DataType> CastBackwardInferDtype(
+    const paddle::DataType& grad_out_dtype, const std::string& data_type) {
+  return {ConvertDtype(data_type)};
+}
+
+PD_BUILD_OP(custom_cast)
+    .Inputs({"X"})
+    .Attrs({"data_type: std::string"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(CastForward))
+    .SetInferDtypeFn(PD_INFER_DTYPE(CastForwardInferDtype));
+
+PD_BUILD_GRAD_OP(custom_cast)
+    .Inputs({paddle::Grad("Out")})
+    .Attrs({"data_type: std::string"})
+    .Outputs({paddle::Grad("X")})
+    .SetKernelFn(PD_KERNEL(CastBackward))
+    .SetInferDtypeFn(PD_INFER_DTYPE(CastBackwardInferDtype));
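One consequence of the FOR_ATTR helper, worth noting next to custom_cast_op.cc above: an InferDtypeFn must declare its attributes with the same types and order as the forward KernelFn, because each paddle::any entry is cast back to the declared parameter type. A hypothetical mismatch (not part of this patch, shown only for contrast) would compile but fail at infer-dtype time with the "Attribute cast error in custom operator InferDtypeFn" message added in op_meta_info.h:

// Hypothetical counter-example: "data_type" is registered as
// "data_type: std::string", so declaring it as int here makes the
// paddle::any_cast in PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR throw,
// and PD_THROW reports the attribute cast error at runtime.
std::vector<paddle::DataType> BadCastInferDtype(
    const paddle::DataType& input_dtype, int data_type) {
  return {input_dtype};
}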
diff --git a/test/custom_op/test_custom_cast_op_jit.py b/test/custom_op/test_custom_cast_op_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..24c344c8ad9856ccb8971274e6003a1c1be330a2
--- /dev/null
+++ b/test/custom_op/test_custom_cast_op_jit.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+from utils import (
+    extra_cc_args,
+    extra_nvcc_args,
+    paddle_includes,
+    paddle_libraries,
+)
+
+import paddle
+from paddle import static
+from paddle.utils.cpp_extension import get_build_directory, load
+from paddle.utils.cpp_extension.extension_utils import run_cmd
+
+# Because Windows don't use docker, the shared lib already exists in the
+# cache dir, it will not be compiled again unless the shared lib is removed.
+file = '{}\\custom_cast_module_jit\\custom_cast_module_jit.pyd'.format(
+    get_build_directory()
+)
+if os.name == 'nt' and os.path.isfile(file):
+    cmd = f'del {file}'
+    run_cmd(cmd, True)
+
+custom_module = load(
+    name='custom_cast_module_jit',
+    sources=['custom_cast_op.cc'],
+    extra_include_paths=paddle_includes,  # add for Coverage CI
+    extra_library_paths=paddle_libraries,
+    extra_cxx_cflags=extra_cc_args,  # test for cc flags
+    extra_cuda_cflags=extra_nvcc_args,  # test for nvcc flags
+    verbose=True,
+)
+
+
+def custom_cast_dynamic(device, dtype, np_x):
+    paddle.set_device(device)
+
+    x = paddle.to_tensor(np_x, dtype="float32")
+    x.stop_gradient = False
+
+    out = custom_module.custom_cast(x, dtype)
+    out.stop_gradient = False
+
+    out.backward()
+
+    assert str(out.dtype).split(".")[-1] == dtype
+    assert str(x.grad.dtype).split(".")[-1] == dtype
+
+
+def custom_cast_static(device, dtype, np_x):
+    paddle.enable_static()
+    paddle.set_device(device)
+
+    with static.scope_guard(static.Scope()):
+        with static.program_guard(static.Program()):
+            x = static.data(name='X', shape=[None, 8], dtype="float32")
+            x.stop_gradient = False
+            out = custom_module.custom_cast(x, dtype)
+            static.append_backward(out)
+
+            exe = static.Executor()
+            exe.run(static.default_startup_program())
+            # in static graph mode, x data has been covered by out
+            out_v, x_grad_v = exe.run(
+                static.default_main_program(),
+                feed={'X': np_x},
+                fetch_list=[out.name, x.name + "@GRAD"],
+            )
+
+    assert x_grad_v[0].dtype == dtype
+    assert out_v[0].dtype == dtype
+
+    paddle.disable_static()
+    return out_v
+
+
+class TestCustomCastOp(unittest.TestCase):
+    def setUp(self):
+        self.dtypes = ['float32', 'float64']
+
+    def test_static(self):
+        for dtype in self.dtypes:
+            x = np.random.uniform(-1, 1, [4, 8]).astype("float32")
+            custom_cast_static('cpu', dtype, x)
+
+    def test_dynamic(self):
+        for dtype in self.dtypes:
+            x = np.random.uniform(-1, 1, [4, 8]).astype("float32")
+            custom_cast_dynamic('cpu', dtype, x)
+
+
+if __name__ == '__main__':
+    unittest.main()