Unverified Commit e8cdb49a authored by Chen Weihang, committed by GitHub

[CustomOp] Support attributes as func input in custom op (#31128)

* add simple attr support and test

* add int, float attr support

* support other attribute

* add custom attrs test in cmake

* polish details

* fix test failed

* add backward test

* update test flags
Parent ffbf7135
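
In short, a custom operator can now declare typed attributes via `.Attrs(...)` and receive them as plain function arguments after its tensor inputs. A minimal sketch, assuming the API added in this patch (the op name "my_scale" and its "scale" attribute are illustrative, not taken from the diff):

    #include <cstdint>
    #include <vector>
    #include "paddle/extension.h"

    std::vector<paddle::Tensor> ScaleForward(const paddle::Tensor& x,
                                             float scale) {
      auto out = paddle::Tensor(paddle::PlaceType::kCPU);
      out.reshape(x.shape());
      PD_DISPATCH_FLOATING_TYPES(x.type(), "scale_cpu_kernel", ([&] {
        const auto* x_data = x.data<data_t>();
        auto* out_data = out.mutable_data<data_t>();
        for (int64_t i = 0; i < x.size(); ++i) {
          out_data[i] = static_cast<data_t>(scale) * x_data[i];
        }
      }));
      return {out};
    }

    PD_BUILD_OP("my_scale")
        .Inputs({"X"})
        .Outputs({"Out"})
        .Attrs({"scale: float"})  // "<name>:<type>" format
        .SetKernelFn(PD_KERNEL(ScaleForward));

On the Python side, the generated API then takes the attribute as a trailing argument, e.g. `my_scale(x, 2.0)`, mirroring the test added in this commit.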
......@@ -81,6 +81,26 @@ inline std::string Grad(const std::string& var_name) {
using KernelFunc = std::vector<Tensor> (*)(std::vector<Tensor> inputs,
std::vector<boost::any> attrs);
#define PD_SPECIALIZE_ComputeCallHelper(attr_type) \
template <typename... Tail> \
struct ComputeCallHelper<attr_type, Tail...> { \
template <int in_idx, int attr_idx, typename... PreviousArgs> \
static Return Compute(std::vector<Tensor> inputs, \
std::vector<boost::any> attrs, \
const PreviousArgs&... pargs) { \
try { \
attr_type arg = boost::any_cast<attr_type>(attrs[attr_idx]); \
return ComputeCallHelper<Tail...>::template Compute<in_idx, \
attr_idx + 1>( \
inputs, attrs, pargs..., arg); \
} catch (boost::bad_any_cast&) { \
PD_THROW( \
"Attribute cast error in custom operator. Expected " #attr_type \
" value."); \
} \
} \
}
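// For example, PD_SPECIALIZE_ComputeCallHelper(int) generates the
// ComputeCallHelper<int, Tail...> specialization: it boost::any_casts
// attrs[attr_idx] to int, appends the value to the argument pack, and
// recurses on Tail with attr_idx + 1.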
template <typename T>
struct TypeTag {};
......@@ -114,26 +134,20 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
}
};
// TODO(chenweihang): add support for attribute input
// int attribute input (not used now)
template <typename... Tail>
struct ComputeCallHelper<int, Tail...> {
template <int in_idx, int attr_idx, typename... PreviousArgs>
static Return Compute(std::vector<Tensor> inputs,
std::vector<boost::any> attrs,
const PreviousArgs&... pargs) {
try {
int arg = boost::any_cast<int>(attrs[attr_idx]);
return ComputeCallHelper<Tail...>::template Compute<in_idx,
attr_idx + 1>(
inputs, attrs, pargs..., arg);
} catch (boost::bad_any_cast&) {
PD_THROW(
"Attribute cast error in custom operator. Expected int value.");
}
}
};
PD_SPECIALIZE_ComputeCallHelper(bool);
PD_SPECIALIZE_ComputeCallHelper(int);
PD_SPECIALIZE_ComputeCallHelper(float);
PD_SPECIALIZE_ComputeCallHelper(int64_t);
PD_SPECIALIZE_ComputeCallHelper(std::string);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<float>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
// TODO(chenweihang): support other attribute type if needed.
// Why not support other attribute type here?
// - boost::blank, std::vector<bool> and std::vector<double>
// are not used in op
// - BlockDesc* and std::vector<BlockDesc*> are used in framework
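// With these specializations in place, a kernel signature such as
//   std::vector<Tensor> Kernel(const Tensor& x, int n, std::vector<float> w);
// (illustrative) is unpacked left to right: tensor parameters consume
// inputs[in_idx] and attribute parameters consume attrs[attr_idx], in the
// same order as the `.Attrs(...)` declaration.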
// end: base template
template <typename T>
struct ComputeCallHelper<TypeTag<T>> {
......@@ -245,10 +259,23 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
class PD_DLL_DECL OpMetaInfo {
public:
explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}
// format: {"<name1>", "<name2>", ...}
OpMetaInfo& Inputs(std::vector<std::string>&& inputs);
// format: {"<name1>", "<name2>", ...}
OpMetaInfo& Outputs(std::vector<std::string>&& outputs);
// format: {"<name1>:<type1>", "<name1>:<type1>", ...}
OpMetaInfo& Attrs(std::vector<std::string>&& attrs);
// format: PD_KERNEL(...)
OpMetaInfo& SetKernelFn(KernelFunc&& func);
// format: PD_INFER_SHAPE(...)
OpMetaInfo& SetInferShapeFn(InferShapeFunc&& func);
// format: PD_INFER_DTYPE(...)
OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);
private:
......@@ -297,6 +324,7 @@ class PD_DLL_DECL OpMetaInfoBuilder {
explicit OpMetaInfoBuilder(std::string&& name);
OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
......
......@@ -32,6 +32,10 @@ OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) {
outputs_ = std::forward<std::vector<std::string>>(outputs);
return *this;
}
OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) {
attrs_ = std::forward<std::vector<std::string>>(attrs);
return *this;
}
OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
kernel_fn_ = std::forward<KernelFunc>(func);
return *this;
......@@ -78,6 +82,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
info_ptr_->Attrs(std::forward<std::vector<std::string>>(attrs));
return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
return *this;
......
......@@ -73,6 +73,24 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
}
std::vector<std::string> ParseAttrStr(const std::string& attr) {
auto split_pos = attr.find_first_of(":");
PADDLE_ENFORCE_NE(split_pos, std::string::npos,
platform::errors::InvalidArgument(
"Invalid attribute string format. Attribute string "
"format is `<name>:<type>`."));
std::vector<std::string> rlt;
// 1. name
rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos)));
// 2. type
rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));
VLOG(1) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
return rlt;
}
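// Example (illustrative): ParseAttrStr("int_attr: int") returns
// {"int_attr", "int"}; spaces around the colon are trimmed.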
} // namespace detail
////////////////// Kernel Define ////////////////////
......@@ -81,7 +99,8 @@ inline bool IsMemberOf(const std::vector<std::string>& vec,
static void RunKernelFunc(const framework::ExecutionContext& ctx,
const paddle::KernelFunc& func,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(1) << "Custom Operator: Start run KernelFunc.";
std::vector<paddle::Tensor> custom_ins;
for (auto& in_name : inputs) {
......@@ -98,10 +117,43 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
custom_ins.emplace_back(custom_in);
}
std::vector<boost::any> attrs;
std::vector<boost::any> custom_attrs;
for (auto& attr_str : attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.",
attr_type_str));
}
}
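  // Note: custom_attrs preserves the declaration order of `.Attrs(...)`,
  // which must match the attribute order of the kernel function signature,
  // because ComputeCallHelper consumes attrs[attr_idx] positionally.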
VLOG(1) << "Run ComputeFunc.";
auto outs = func(custom_ins, attrs);
auto outs = func(custom_ins, custom_attrs);
VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
for (size_t i = 0; i < outputs.size(); ++i) {
......@@ -164,7 +216,51 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
for (auto& out_name : outputs_) {
AddOutput(out_name, "The output " + out_name + " of Custom Operator.");
}
// TODO(chenweihang): support attrs in later PR
for (auto& attr : attrs_) {
auto attr_name_and_type = detail::ParseAttrStr(attr);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
AddAttr<bool>(attr_name, "custom operator bool attribute.")
.SetDefault(false);
} else if (attr_type_str == "int") {
AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
} else if (attr_type_str == "float") {
AddAttr<float>(attr_name, "custom operator float attribute.")
.SetDefault(1.0f);
} else if (attr_type_str == "int64_t") {
AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
.SetDefault(1);
} else if (attr_type_str == "std::string") {
} else if (attr_type_str == "std::string") {
AddAttr<std::string>(attr_name, "custom operator std::string attribute.")
.SetDefault("");
} else if (attr_type_str == "std::vector<int>") {
AddAttr<std::vector<int>>(attr_name,
"custom operator std::vector<int> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<float>") {
AddAttr<std::vector<float>>(
attr_name, "custom operator std::vector<float> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<int64_t>") {
AddAttr<std::vector<int64_t>>(
attr_name, "custom operator std::vector<int64_t> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<std::string>") {
AddAttr<std::vector<std::string>>(
attr_name, "custom operator std::vector<std::string> attribute.")
.SetDefault({});
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>, "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.",
attr_type_str));
}
}
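  // The defaults above are placeholders so that attribute checking passes;
  // the actual values are supplied via `attrs` when the generated Python
  // API appends the op.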
AddComment(R"DOC(
Custom Operator.
......@@ -227,7 +323,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name;
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
}
// TODO(chenweihang): support attrs in later PR
grad_op->SetAttrMap(this->Attrs());
}
private:
......@@ -287,7 +383,7 @@ class CustomGradOpMaker<imperative::OpBase>
VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
}
// TODO(chenweihang): support attrs in later PR
grad_op->SetAttrMap(this->Attrs());
}
private:
......@@ -303,21 +399,24 @@ void RegisterOperatorKernelWithPlace(const std::string& name,
const proto::VarType::Type type,
const PlaceType& place,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
OpKernelType key(type,
CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place));
VLOG(1) << "Custom Operator: op kernel key: " << key;
OperatorWithKernel::AllOpKernels()[name][key] =
[kernel_func, inputs, outputs](const framework::ExecutionContext& ctx) {
[kernel_func, inputs, outputs,
attrs](const framework::ExecutionContext& ctx) {
VLOG(1) << "Custom Operator: run custom kernel func in lambda.";
RunKernelFunc(ctx, kernel_func, inputs, outputs);
RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
};
}
void RegisterOperatorKernel(const std::string& name,
const paddle::KernelFunc& kernel_func,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(1) << "Custom Operator: op name in kernel: " << name;
// NOTE [ Dummy Op Kernel Key ]
// TODO(chenweihang): Because the execution engine needs to get the device context based
......@@ -325,9 +424,11 @@ void RegisterOperatorKernel(const std::string& name,
// device. But this is not entirely correct: if the user only provides a CPU kernel
// but calls the API on a GPU device, it will cause an error.
RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
PlaceType::kCPU, inputs, outputs);
PlaceType::kCPU, inputs, outputs, attrs);
#ifdef PADDLE_WITH_CUDA
RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
PlaceType::kGPU, inputs, outputs);
PlaceType::kGPU, inputs, outputs, attrs);
#endif
}
void RegisterOperatorWithMetaInfo(
......@@ -350,6 +451,8 @@ void RegisterOperatorWithMetaInfo(
<< string::join_strings(op_inputs, ',');
VLOG(1) << "Custom Operator: forward, op outputs: "
<< string::join_strings(op_outputs, ',');
VLOG(1) << "Custom Operator: forward, op attrs: "
<< string::join_strings(op_attrs, ',');
// Op
info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs,
......@@ -426,7 +529,7 @@ void RegisterOperatorWithMetaInfo(
};
// Kernel func
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs);
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);
// If grad op or double grad op exists
std::string cur_op_name = op_name;
......@@ -436,6 +539,7 @@ void RegisterOperatorWithMetaInfo(
auto& grad_op_name = OpMetaInfoHelper::GetOpName(cur_grad_op);
auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op);
auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name;
......@@ -489,7 +593,7 @@ void RegisterOperatorWithMetaInfo(
// Kernel func
RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs,
grad_op_outputs);
grad_op_outputs, grad_op_attrs);
// update current info
OpInfoMap::Instance().Insert(cur_op_name, info);
......
......@@ -13,10 +13,13 @@ py_test(test_sysconfig SRCS test_sysconfig.py)
# 'test_dispatch' compile .cc file
py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_dispatch_jit PROPERTIES TIMEOUT 120)
py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 180)
set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)
py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)
if(NOT LINUX)
return()
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdlib>
#include <iostream>
#include <vector>
#include "paddle/extension.h"
template <typename data_t>
void assign_cpu_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel) {
for (int64_t i = 0; i < x_numel; ++i) {
out_data[i] = x_data[i];
}
}
std::vector<paddle::Tensor> AttrTestForward(
const paddle::Tensor& x,
bool bool_attr,
int int_attr,
float float_attr,
int64_t int64_attr,
std::string str_attr,
std::vector<int> int_vec_attr,
std::vector<float> float_vec_attr,
std::vector<int64_t> int64_vec_attr,
std::vector<std::string> str_vec_attr) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
// Check attrs value
if (bool_attr != true) {
throw std::runtime_error("bool_attr value error.");
}
if (int_attr != 10) {
throw std::runtime_error("int_attr value error.");
}
if (std::abs(float_attr - 3.14) > 1e-6) {
throw std::runtime_error("float_attr value error.");
}
if (int64_attr != 10000000000) {
throw std::runtime_error("int64_attr value error.");
}
if (str_attr != "StrAttr") {
throw std::runtime_error("str_attr value error.");
}
if (int_vec_attr.size() != 3) {
throw std::runtime_error("int_vec_attr size error.");
} else {
for (auto& value : int_vec_attr) {
if (value != 10) {
throw std::runtime_error("int_vec_attr value error.");
}
}
}
if (float_vec_attr.size() != 3) {
throw std::runtime_error("float_vec_attr size error.");
} else {
for (auto& value : float_vec_attr) {
if (std::abs(value - 3.14) > 1e-6) {
throw std::runtime_error("float_vec_attr value error.");
}
}
}
if (int64_vec_attr.size() != 3) {
throw std::runtime_error("int64_vec_attr size error.");
} else {
for (auto& value : int64_vec_attr) {
if (value != 10000000000) {
throw std::runtime_error("int64_vec_attr value error.");
}
}
}
if (str_vec_attr.size() != 3) {
throw std::runtime_error("str_vec_attr size error.");
} else {
for (auto& value : str_vec_attr) {
if (value != "StrAttr") {
throw std::runtime_error("str_vec_attr value error.");
}
}
}
return {out};
}
// The attrs of the backward op must be a subset of the attrs of the forward op
std::vector<paddle::Tensor> AttrTestBackward(
const paddle::Tensor& grad_out,
int int_attr,
std::vector<float> float_vec_attr,
std::vector<std::string> str_vec_attr) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
grad_x.reshape(grad_out.shape());
PD_DISPATCH_FLOATING_TYPES(grad_out.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
grad_out.data<data_t>(),
grad_x.mutable_data<data_t>(),
grad_out.size());
}));
if (int_attr != 10) {
throw std::runtime_error("int_attr value error.");
}
if (float_vec_attr.size() != 3) {
throw std::runtime_error("float_vec_attr size error.");
} else {
for (auto& value : float_vec_attr) {
if (std::abs(value - 3.14) > 1e-6) {
throw std::runtime_error("float_vec_attr value error.");
}
}
}
if (str_vec_attr.size() != 3) {
throw std::runtime_error("str_vec_attr size error.");
} else {
for (auto& value : str_vec_attr) {
if (value != "StrAttr") {
throw std::runtime_error("str_vec_attr value error.");
}
}
}
return {grad_x};
}
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
return {x_shape};
}
std::vector<paddle::DataType> InferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
PD_BUILD_OP("attr_test")
.Inputs({"X"})
.Outputs({"Out"})
.Attrs({"bool_attr: bool",
"int_attr: int",
"float_attr: float",
"int64_attr: int64_t",
"str_attr: std::string",
"int_vec_attr: std::vector<int>",
"float_vec_attr: std::vector<float>",
"int64_vec_attr: std::vector<int64_t>",
"str_vec_attr: std::vector<std::string>"})
.SetKernelFn(PD_KERNEL(AttrTestForward))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType))
.SetBackwardOp("attr_test_grad")
.Inputs({paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.Attrs({"int_attr: int",
"float_vec_attr: std::vector<float>",
"str_vec_attr: std::vector<std::string>"})
.SetKernelFn(PD_KERNEL(AttrTestBackward));
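// From Python, the generated API is then called with the attributes as
// trailing positional arguments, e.g.
//   custom_attrs.attr_test(x, bool_attr, int_attr, ..., str_vec_attr)
// as exercised by test_custom_attrs_jit.py below.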
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
from paddle.utils.cpp_extension import load, get_build_directory
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension.extension_utils import run_cmd
# Because Windows does not use docker, the shared lib may already exist in the
# cache dir; it will not be compiled again unless the shared lib is removed.
file = '{}\\custom_attrs_jit\\custom_attrs_jit.pyd'.format(get_build_directory())
if os.name == 'nt' and os.path.isfile(file):
cmd = 'del {}'.format(file)
run_cmd(cmd, True)
# Compile and load custom op Just-In-Time.
custom_attrs = load(
name='custom_attrs_jit',
sources=['attr_test_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_compile_args, # add for Coverage CI
verbose=True)
class TestJitCustomAttrs(unittest.TestCase):
def test_attr_value(self):
paddle.set_device('cpu')
# prepare test value
bool_attr = True
int_attr = 10
float_attr = 3.14
int64_attr = 10000000000
str_attr = "StrAttr"
int_vec_attr = [10, 10, 10]
float_vec_attr = [3.14, 3.14, 3.14]
int64_vec_attr = [10000000000, 10000000000, 10000000000]
str_vec_attr = ["StrAttr", "StrAttr", "StrAttr"]
x = paddle.ones([2, 2], dtype='float32')
x.stop_gradient = False
out = custom_attrs.attr_test(
x, bool_attr, int_attr, float_attr, int64_attr, str_attr,
int_vec_attr, float_vec_attr, int64_vec_attr, str_vec_attr)
out.stop_gradient = False
out.backward()
self.assertTrue(np.array_equal(x.numpy(), out.numpy()))
if __name__ == '__main__':
unittest.main()
......@@ -85,6 +85,14 @@ information
'''
USING_NEW_CUSTOM_OP_LOAD_METHOD = True
DEFAULT_OP_ATTR_NAMES = [
core.op_proto_and_checker_maker.kOpRoleAttrName(),
core.op_proto_and_checker_maker.kOpRoleVarAttrName(),
core.op_proto_and_checker_maker.kOpNameScopeAttrName(),
core.op_proto_and_checker_maker.kOpCreationCallstackAttrName(),
core.op_proto_and_checker_maker.kOpDeviceAttrName()
]
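# NOTE: these are framework-injected attributes (op role, name scope,
# callstack, device); parse_op_info filters them out below so the generated
# Python API only exposes user-declared attributes.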
# NOTE(chenweihang): In order to be compatible with
# the two custom op define method, after removing
......@@ -469,8 +477,11 @@ def parse_op_info(op_name):
in_names = [x.name for x in op_proto.inputs]
out_names = [x.name for x in op_proto.outputs]
attr_names = [
x.name for x in op_proto.attrs if x.name not in DEFAULT_OP_ATTR_NAMES
]
return in_names, out_names
return in_names, out_names, attr_names
def _import_module_from_library(module_name, build_directory, verbose=False):
......@@ -516,7 +527,7 @@ def _generate_python_module(module_name,
def _custom_api_content(op_name):
params_str, ins_str, outs_str = _get_api_inputs_str(op_name)
params_str, ins_str, attrs_str, outs_str = _get_api_inputs_str(op_name)
API_TEMPLATE = textwrap.dedent("""
from paddle.fluid.layer_helper import LayerHelper
......@@ -526,6 +537,7 @@ def _custom_api_content(op_name):
# prepare inputs and outputs
ins = {ins}
attrs = {attrs}
outs = {{}}
out_names = {out_names}
for out_name in out_names:
......@@ -533,7 +545,7 @@ def _custom_api_content(op_name):
# in runtime.
outs[out_name] = helper.create_variable(dtype='float32')
helper.append_op(type="{op_name}", inputs=ins, outputs=outs)
helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs)
res = [outs[out_name] for out_name in out_names]
......@@ -542,7 +554,11 @@ def _custom_api_content(op_name):
# generate python api file
api_content = API_TEMPLATE.format(
op_name=op_name, inputs=params_str, ins=ins_str, out_names=outs_str)
op_name=op_name,
inputs=params_str,
ins=ins_str,
attrs=attrs_str,
out_names=outs_str)
return api_content
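# For example (illustrative): for an op with input "X", attribute
# "int_attr: int" and output "Out", the generated function takes
# (x, int_attr), builds ins = {'X': x} and attrs = {'int_attr': int_attr},
# and appends the op with both.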
......@@ -573,15 +589,21 @@ def _get_api_inputs_str(op_name):
"""
Returns strings of api parameters, inputs dict, attrs dict and output names.
"""
in_names, out_names = parse_op_info(op_name)
in_names, out_names, attr_names = parse_op_info(op_name)
# e.g: x, y, z
params_str = ','.join([p.lower() for p in in_names])
param_names = in_names + attr_names
params_str = ','.join([p.lower() for p in param_names])
# e.g: {'X': x, 'Y': y, 'Z': z}
ins_str = "{%s}" % ','.join(
["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names])
# e.g: {'num': n}
attrs_str = "{%s}" % ",".join([
"'{}' : {}".format(attr_name, attr_name.lower())
for attr_name in attr_names
])
# e.g: ['Out', 'Index']
outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names])
return params_str, ins_str, outs_str
return params_str, ins_str, attrs_str, outs_str
def _write_setup_file(name,
......