diff --git a/paddle/fluid/extension/include/op_meta_info.h b/paddle/fluid/extension/include/op_meta_info.h index c16f61374f7cba5dd727fe5d22449bbeca772de8..1bc044f647fbae0c4666ecda9e2a2fc3dc8ef214 100644 --- a/paddle/fluid/extension/include/op_meta_info.h +++ b/paddle/fluid/extension/include/op_meta_info.h @@ -81,6 +81,26 @@ inline std::string Grad(const std::string& var_name) { using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs, std::vector<boost::any> attrs); +#define PD_SPECIALIZE_ComputeCallHelper(attr_type) \ + template <typename... Tail> \ + struct ComputeCallHelper<attr_type, Tail...> { \ + template <int in_idx, int attr_idx, typename... PreviousArgs> \ + static Return Compute(std::vector<Tensor> inputs, \ + std::vector<boost::any> attrs, \ + const PreviousArgs&... pargs) { \ + try { \ + attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \ + return ComputeCallHelper<Tail...>::template Compute<in_idx, attr_idx + 1>( \ + inputs, attrs, pargs..., arg); \ + } catch (boost::bad_any_cast&) { \ + PD_THROW( \ + "Attribute cast error in custom operator. Expected " #attr_type \ + " value."); \ + } \ + } \ + } + template <typename T> struct TypeTag {}; @@ -114,26 +134,20 @@ struct KernelFuncImpl { } }; - // TODO(chenweihang): add support for attribute input - // int attribute input (not used now) - template <typename... Tail> - struct ComputeCallHelper<int, Tail...> { - template <int in_idx, int attr_idx, typename... PreviousArgs> - static Return Compute(std::vector<Tensor> inputs, - std::vector<boost::any> attrs, - const PreviousArgs&... pargs) { - try { - int arg = boost::any_cast<int>(attrs[attr_idx]); - return ComputeCallHelper<Tail...>::template Compute<in_idx, attr_idx + 1>( - inputs, attrs, pargs..., arg); - } catch (boost::bad_any_cast&) { - PD_THROW( - "Attribute cast error in custom operator. Expected int value."); - } - } - }; - + PD_SPECIALIZE_ComputeCallHelper(bool); + PD_SPECIALIZE_ComputeCallHelper(int); + PD_SPECIALIZE_ComputeCallHelper(float); + PD_SPECIALIZE_ComputeCallHelper(int64_t); + PD_SPECIALIZE_ComputeCallHelper(std::string); + PD_SPECIALIZE_ComputeCallHelper(std::vector<int>); + PD_SPECIALIZE_ComputeCallHelper(std::vector<float>); + PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>); + PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>); + // TODO(chenweihang): support other attribute types if needed. + // Why are other attribute types not supported here? + // - boost::blank, std::vector<bool> and std::vector<double> + // are not used in op + // - BlockDesc* and std::vector<BlockDesc*> are used in framework // end: base template template <typename T> struct ComputeCallHelper<TypeTag<T>> { @@ -245,10 +259,23 @@ struct InferDtypeFuncImpl { class PD_DLL_DECL OpMetaInfo { public: explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {} + + // format: {"<name1>", "<name2>", ...} OpMetaInfo& Inputs(std::vector<std::string>&& inputs); + + // format: {"<name1>", "<name2>", ...} OpMetaInfo& Outputs(std::vector<std::string>&& outputs); + + // format: {"<name1>:<type1>", "<name2>:<type2>", ...} + OpMetaInfo& Attrs(std::vector<std::string>&& attrs); + + // format: PD_KERNEL(...) OpMetaInfo& SetKernelFn(KernelFunc&& func); + + // format: PD_INFER_SHAPE(...) OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func); + + // format: PD_INFER_DTYPE(...)
OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func); private: @@ -297,6 +324,7 @@ class PD_DLL_DECL OpMetaInfoBuilder { explicit OpMetaInfoBuilder(std::string&& name); OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs); OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs); + OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs); OpMetaInfoBuilder& SetKernelFn(KernelFunc func); OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func); OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func); diff --git a/paddle/fluid/extension/src/op_meta_info.cc b/paddle/fluid/extension/src/op_meta_info.cc index 0273dfd5d07a69a30e1ca00c3f2a42b9ff8a8c50..d362282b8d9d24c287e51643d3aca72d9fd36c50 100644 --- a/paddle/fluid/extension/src/op_meta_info.cc +++ b/paddle/fluid/extension/src/op_meta_info.cc @@ -32,6 +32,10 @@ OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) { outputs_ = std::forward<std::vector<std::string>>(outputs); return *this; } +OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) { + attrs_ = std::forward<std::vector<std::string>>(attrs); + return *this; +} OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) { kernel_fn_ = std::forward<KernelFunc>(func); return *this; @@ -78,6 +82,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs( return *this; } +OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) { + info_ptr_->Attrs(std::forward<std::vector<std::string>>(attrs)); + return *this; +} + OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) { info_ptr_->SetKernelFn(std::forward<KernelFunc>(func)); return *this; diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 1e2a77e915dea4e19046c68e176ba49637ece9ac..03a8cc366e7f2e8bb3baa2dd65ee609533cb8137 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -73,6 +73,24 @@ inline bool IsMemberOf(const std::vector<std::string>& vec, return std::find(vec.cbegin(), vec.cend(), name) != vec.cend(); } +std::vector<std::string> ParseAttrStr(const std::string& attr) { + auto split_pos = attr.find_first_of(":"); + PADDLE_ENFORCE_NE(split_pos, std::string::npos, + platform::errors::InvalidArgument( + "Invalid attribute string format. Attribute string " + "format is `<name>:<type>`.")); + + std::vector<std::string> rlt; + // 1. name + rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos))); + // 2. type
+ rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1))); + + VLOG(1) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1]; + + return rlt; +} + } // namespace detail ////////////////// Kernel Define //////////////////// @@ -81,7 +99,8 @@ inline bool IsMemberOf(const std::vector<std::string>& vec, static void RunKernelFunc(const framework::ExecutionContext& ctx, const paddle::KernelFunc& func, const std::vector<std::string>& inputs, - const std::vector<std::string>& outputs) { + const std::vector<std::string>& outputs, + const std::vector<std::string>& attrs) { VLOG(1) << "Custom Operator: Start run KernelFunc."; std::vector<paddle::Tensor> custom_ins; for (auto& in_name : inputs) { @@ -98,10 +117,43 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, custom_ins.emplace_back(custom_in); } - std::vector<boost::any> attrs; + std::vector<boost::any> custom_attrs; + for (auto& attr_str : attrs) { + auto attr_name_and_type = detail::ParseAttrStr(attr_str); + auto attr_name = attr_name_and_type[0]; + auto attr_type_str = attr_name_and_type[1]; + if (attr_type_str == "bool") { + custom_attrs.emplace_back(ctx.Attr<bool>(attr_name)); + } else if (attr_type_str == "int") { + custom_attrs.emplace_back(ctx.Attr<int>(attr_name)); + } else if (attr_type_str == "float") { + custom_attrs.emplace_back(ctx.Attr<float>(attr_name)); + } else if (attr_type_str == "int64_t") { + custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name)); + } else if (attr_type_str == "std::string") { + custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name)); + } else if (attr_type_str == "std::vector<int>") { + custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name)); + } else if (attr_type_str == "std::vector<float>") { + custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name)); + } else if (attr_type_str == "std::vector<int64_t>") { + custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name)); + } else if (attr_type_str == "std::vector<std::string>") { + custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported `%s` type value as custom attribute now. "
" + "Supported data types include `bool`, `int`, `float`, " + "`int64_t`, `std::string`, `std::vector`, " + "`std::vector`, `std::vector, " + "`std::vector`, Please check whether " + "the attribute data type and data type string are matched.", + attr_type_str)); + } + } VLOG(1) << "Run ComputeFunc."; - auto outs = func(custom_ins, attrs); + auto outs = func(custom_ins, custom_attrs); VLOG(1) << "Custom Operator: Share outputs into ExecutionContext."; for (size_t i = 0; i < outputs.size(); ++i) { @@ -164,7 +216,51 @@ class CustomOpMaker : public OpProtoAndCheckerMaker { for (auto& out_name : outputs_) { AddOutput(out_name, "The output " + out_name + "of Custom Operator."); } - // TODO(chenweihang): support attrs in later PR + for (auto& attr : attrs_) { + auto attr_name_and_type = detail::ParseAttrStr(attr); + auto attr_name = attr_name_and_type[0]; + auto attr_type_str = attr_name_and_type[1]; + if (attr_type_str == "bool") { + AddAttr(attr_name, "custom operator bool attribute.") + .SetDefault(false); + } else if (attr_type_str == "int") { + AddAttr(attr_name, "custom operator int attribute.").SetDefault(1); + } else if (attr_type_str == "float") { + AddAttr(attr_name, "custom operator float attribute.") + .SetDefault(1.0f); + } else if (attr_type_str == "int64_t") { + AddAttr(attr_name, "custom operator int64_t attribute.") + .SetDefault(1); + } else if (attr_type_str == "std::string") { + AddAttr(attr_name, "custom operator int attribute.") + .SetDefault(""); + } else if (attr_type_str == "std::vector") { + AddAttr>(attr_name, + "custom operator std::vector attribute.") + .SetDefault({}); + } else if (attr_type_str == "std::vector") { + AddAttr>( + attr_name, "custom operator std::vector attribute.") + .SetDefault({}); + } else if (attr_type_str == "std::vector") { + AddAttr>( + attr_name, "custom operator std::vector attribute.") + .SetDefault({}); + } else if (attr_type_str == "std::vector") { + AddAttr>( + attr_name, "custom operator std::vector attribute.") + .SetDefault({}); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported `%s` type value as custom attribute now. " + "Supported data types include `bool`, `int`, `float`, " + "`int64_t`, `std::string`, `std::vector`, " + "`std::vector`, `std::vector, " + "`std::vector`, Please check whether " + "the attribute data type and data type string are matched.", + attr_type_str)); + } + } AddComment(R"DOC( Custom Operator. 
@@ -227,7 +323,7 @@ class CustomGradOpMaker : public SingleGradOpMaker { VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name; grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name))); } - // TODO(chenweihang): support attrs in later PR + grad_op->SetAttrMap(this->Attrs()); } private: @@ -287,7 +383,7 @@ class CustomGradOpMaker VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name; grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name))); } - // TODO(chenweihang): support attrs in later PR + grad_op->SetAttrMap(this->Attrs()); } private: @@ -303,21 +399,24 @@ void RegisterOperatorKernelWithPlace(const std::string& name, const proto::VarType::Type type, const PlaceType& place, const std::vector<std::string>& inputs, - const std::vector<std::string>& outputs) { + const std::vector<std::string>& outputs, + const std::vector<std::string>& attrs) { OpKernelType key(type, CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place)); VLOG(1) << "Custom Operator: op kernel key: " << key; OperatorWithKernel::AllOpKernels()[name][key] = - [kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) { + [kernel_func, inputs, outputs, + attrs](const framework::ExecutionContext& ctx) { VLOG(1) << "Custom Operator: run custom kernel func in lambda."; - RunKernelFunc(ctx, kernel_func, inputs, outputs); + RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs); }; } void RegisterOperatorKernel(const std::string& name, const paddle::KernelFunc& kernel_func, const std::vector<std::string>& inputs, - const std::vector<std::string>& outputs) { + const std::vector<std::string>& outputs, + const std::vector<std::string>& attrs) { VLOG(1) << "Custom Operator: op name in kernel: " << name; // NOTE [ Dummy Op Kernel Key ] // TODO(chenweihang): Because execute engine need get device context based // device. But this is not entirely correct, if user only give a cpu kernel, // but call api in gpu device, it will cause error.
RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, - PlaceType::kCPU, inputs, outputs); + PlaceType::kCPU, inputs, outputs, attrs); +#ifdef PADDLE_WITH_CUDA RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW, - PlaceType::kGPU, inputs, outputs); + PlaceType::kGPU, inputs, outputs, attrs); +#endif } void RegisterOperatorWithMetaInfo( @@ -350,6 +451,8 @@ void RegisterOperatorWithMetaInfo( << string::join_strings(op_inputs, ','); VLOG(1) << "Custom Operator: forward, op outputs: " << string::join_strings(op_outputs, ','); + VLOG(1) << "Custom Operator: forward, op attrs: " + << string::join_strings(op_attrs, ','); // Op info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs, @@ -426,7 +529,7 @@ void RegisterOperatorWithMetaInfo( }; // Kernel func - RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs); + RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs); // If grad op or double grad op exists std::string cur_op_name = op_name; @@ -436,6 +539,7 @@ void RegisterOperatorWithMetaInfo( auto& grad_op_name = OpMetaInfoHelper::GetOpName(cur_grad_op); auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op); auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op); + auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op); auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op); VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name; @@ -489,7 +593,7 @@ void RegisterOperatorWithMetaInfo( // Kernel func RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs, - grad_op_outputs); + grad_op_outputs, grad_op_attrs); // update current info OpInfoMap::Instance().Insert(cur_op_name, info); diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index 10d8b898c7589b3bc7c8a7e07ca094618de1405e..3f85f4ef50a223949ef60678b61e97be29aea471 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -13,10 +13,13 @@ py_test(test_sysconfig SRCS test_sysconfig.py) # 'test_dispatch' compile .cc file py_test(test_dispatch_jit SRCS test_dispatch_jit.py) -set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 180) +set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 120) py_test(test_multi_out_jit SRCS test_multi_out_jit.py) -set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 180) +set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120) + +py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) +set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120) if(NOT LINUX) return() diff --git a/python/paddle/fluid/tests/custom_op/attr_test_op.cc b/python/paddle/fluid/tests/custom_op/attr_test_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..474d3d2d4e2b3b566620a11d41564fb662bd35e3 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/attr_test_op.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cmath> +#include <stdexcept> +#include <string> +#include <vector> + +#include "paddle/extension.h" + +template <typename data_t> +void assign_cpu_kernel(const data_t* x_data, + data_t* out_data, + int64_t x_numel) { + for (int i = 0; i < x_numel; ++i) { + out_data[i] = x_data[i]; + } +} + +std::vector<paddle::Tensor> AttrTestForward( + const paddle::Tensor& x, + bool bool_attr, + int int_attr, + float float_attr, + int64_t int64_attr, + std::string str_attr, + std::vector<int> int_vec_attr, + std::vector<float> float_vec_attr, + std::vector<int64_t> int64_vec_attr, + std::vector<std::string> str_vec_attr) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU); + out.reshape(x.shape()); + + PD_DISPATCH_FLOATING_TYPES( + x.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel<data_t>( + x.data<data_t>(), out.mutable_data<data_t>(), x.size()); + })); + + // Check attrs value + if (bool_attr != true) { + throw std::runtime_error("bool_attr value error."); + } + if (int_attr != 10) { + throw std::runtime_error("int_attr value error."); + } + if (std::abs(float_attr - 3.14) > 1e-6) { + throw std::runtime_error("float_attr value error."); + } + if (int64_attr != 10000000000) { + throw std::runtime_error("int64_attr value error."); + } + if (str_attr != "StrAttr") { + throw std::runtime_error("str_attr value error."); + } + + if (int_vec_attr.size() != 3) { + throw std::runtime_error("int_vec_attr size error."); + } else { + for (auto& value : int_vec_attr) { + if (value != 10) { + throw std::runtime_error("int_vec_attr value error."); + } + } + } + + if (float_vec_attr.size() != 3) { + throw std::runtime_error("float_vec_attr size error."); + } else { + for (auto& value : float_vec_attr) { + if (std::abs(value - 3.14) > 1e-6) { + throw std::runtime_error("float_vec_attr value error."); + } + } + } + + if (int64_vec_attr.size() != 3) { + throw std::runtime_error("int64_vec_attr size error."); + } else { + for (auto& value : int64_vec_attr) { + if (value != 10000000000) { + throw std::runtime_error("int64_vec_attr value error."); + } + } + } + + if (str_vec_attr.size() != 3) { + throw std::runtime_error("str_vec_attr size error."); + } else { + for (auto& value : str_vec_attr) { + if (value != "StrAttr") { + throw std::runtime_error("str_vec_attr value error."); + } + } + } + + return {out}; +} + +// The attrs of the backward op must be a subset of the attrs of the forward op +std::vector<paddle::Tensor> AttrTestBackward( + const paddle::Tensor& grad_out, + int int_attr, + std::vector<float> float_vec_attr, + std::vector<std::string> str_vec_attr) { + auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU); + grad_x.reshape(grad_out.shape()); + + PD_DISPATCH_FLOATING_TYPES(grad_out.type(), "assign_cpu_kernel", ([&] { + assign_cpu_kernel<data_t>( + grad_out.data<data_t>(), + grad_x.mutable_data<data_t>(), + grad_out.size()); + })); + + if (int_attr != 10) { + throw std::runtime_error("int_attr value error."); + } + + if (float_vec_attr.size() != 3) { + throw std::runtime_error("float_vec_attr size error."); + } else { + for (auto& value : float_vec_attr) { + if (std::abs(value - 3.14) > 1e-6) { + throw std::runtime_error("float_vec_attr value error."); + } + } + } + + if (str_vec_attr.size() != 3) { + throw std::runtime_error("str_vec_attr size error."); + } else { + for (auto& value : str_vec_attr) { + if (value != "StrAttr") { + throw std::runtime_error("str_vec_attr value error."); + } + } + } + + return {grad_x}; +} + +std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) { + return {x_shape}; +} + +std::vector<paddle::DataType> InferDType(paddle::DataType x_dtype) { + return {x_dtype}; +} +
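+// NOTE: Each attribute below is declared as a "name: type" string. On the
+// framework side, detail::ParseAttrStr() splits the string at the first ':'
+// into the attribute name and its C++ type string, so the type string must be
+// one of the supported types (bool, int, float, int64_t, std::string, or
+// std::vector of int/float/int64_t/std::string) and must correspond, in
+// order, to the non-Tensor parameters of the kernel registered via PD_KERNEL.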
+PD_BUILD_OP("attr_test") + .Inputs({"X"}) + .Outputs({"Out"}) + .Attrs({"bool_attr: bool", + "int_attr: int", + "float_attr: float", + "int64_attr: int64_t", + "str_attr: std::string", + "int_vec_attr: std::vector", + "float_vec_attr: std::vector", + "int64_vec_attr: std::vector", + "str_vec_attr: std::vector"}) + .SetKernelFn(PD_KERNEL(AttrTestForward)) + .SetInferShapeFn(PD_INFER_SHAPE(InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InferDType)) + .SetBackwardOp("attr_test_grad") + .Inputs({paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .Attrs({"int_attr: int", + "float_vec_attr: std::vector", + "str_vec_attr: std::vector"}) + .SetKernelFn(PD_KERNEL(AttrTestBackward)); diff --git a/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..754f76cab86f083923423652055be191982e5b14 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np + +import paddle +from paddle.utils.cpp_extension import load, get_build_directory +from utils import paddle_includes, extra_compile_args +from paddle.utils.cpp_extension.extension_utils import run_cmd + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = '{}\\custom_attrs_jit\\custom_attrs_jit.pyd'.format(get_build_directory( +)) +if os.name == 'nt' and os.path.isfile(file): + cmd = 'del {}'.format(file) + run_cmd(cmd, True) + +# Compile and load custom op Just-In-Time. 
+custom_attrs = load( + name='custom_attrs_jit', + sources=['attr_test_op.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_compile_args, # add for Coverage CI + verbose=True) + + +class TestJitCustomAttrs(unittest.TestCase): + def test_attr_value(self): + paddle.set_device('cpu') + # prepare test value + bool_attr = True + int_attr = 10 + float_attr = 3.14 + int64_attr = 10000000000 + str_attr = "StrAttr" + int_vec_attr = [10, 10, 10] + float_vec_attr = [3.14, 3.14, 3.14] + int64_vec_attr = [10000000000, 10000000000, 10000000000] + str_vec_attr = ["StrAttr", "StrAttr", "StrAttr"] + + x = paddle.ones([2, 2], dtype='float32') + x.stop_gradient = False + out = custom_attrs.attr_test( + x, bool_attr, int_attr, float_attr, int64_attr, str_attr, + int_vec_attr, float_vec_attr, int64_vec_attr, str_vec_attr) + out.stop_gradient = False + out.backward() + + self.assertTrue(np.array_equal(x.numpy(), out.numpy())) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index ee8505623af224f6c4026b5451523a75a5930420..82e91c3b737b45a4a94b6a9cd1af65566c8c7237 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -85,6 +85,14 @@ information ''' USING_NEW_CUSTOM_OP_LOAD_METHOD = True +DEFAULT_OP_ATTR_NAMES = [ + core.op_proto_and_checker_maker.kOpRoleAttrName(), + core.op_proto_and_checker_maker.kOpRoleVarAttrName(), + core.op_proto_and_checker_maker.kOpNameScopeAttrName(), + core.op_proto_and_checker_maker.kOpCreationCallstackAttrName(), + core.op_proto_and_checker_maker.kOpDeviceAttrName() +] + # NOTE(chenweihang): In order to be compatible with # the two custom op define method, after removing @@ -469,8 +477,11 @@ def parse_op_info(op_name): in_names = [x.name for x in op_proto.inputs] out_names = [x.name for x in op_proto.outputs] + attr_names = [ + x.name for x in op_proto.attrs if x.name not in DEFAULT_OP_ATTR_NAMES + ] - return in_names, out_names + return in_names, out_names, attr_names def _import_module_from_library(module_name, build_directory, verbose=False): @@ -516,7 +527,7 @@ def _generate_python_module(module_name, def _custom_api_content(op_name): - params_str, ins_str, outs_str = _get_api_inputs_str(op_name) + params_str, ins_str, attrs_str, outs_str = _get_api_inputs_str(op_name) API_TEMPLATE = textwrap.dedent(""" from paddle.fluid.layer_helper import LayerHelper @@ -526,6 +537,7 @@ def _custom_api_content(op_name): # prepare inputs and outputs ins = {ins} + attrs = {attrs} outs = {{}} out_names = {out_names} for out_name in out_names: @@ -533,7 +545,7 @@ def _custom_api_content(op_name): # in runtime. outs[out_name] = helper.create_variable(dtype='float32') - helper.append_op(type="{op_name}", inputs=ins, outputs=outs) + helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs) res = [outs[out_name] for out_name in out_names] @@ -542,7 +554,11 @@ def _custom_api_content(op_name): # generate python api file api_content = API_TEMPLATE.format( - op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str) + op_name=op_name, + inputs=params_str, + ins=ins_str, + attrs=attrs_str, + out_names=outs_str) return api_content @@ -573,15 +589,21 @@ def _get_api_inputs_str(op_name): """ Returns string of api parameters and inputs dict. 
""" - in_names, out_names = parse_op_info(op_name) + in_names, out_names, attr_names = parse_op_info(op_name) # e.g: x, y, z - params_str = ','.join([p.lower() for p in in_names]) + param_names = in_names + attr_names + params_str = ','.join([p.lower() for p in param_names]) # e.g: {'X': x, 'Y': y, 'Z': z} ins_str = "{%s}" % ','.join( ["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names]) + # e.g: {'num': n} + attrs_str = "{%s}" % ",".join([ + "'{}' : {}".format(attr_name, attr_name.lower()) + for attr_name in attr_names + ]) # e.g: ['Out', 'Index'] outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names]) - return params_str, ins_str, outs_str + return params_str, ins_str, attrs_str, outs_str def _write_setup_file(name,