未验证 提交 d23bf89c 编写于 作者: L Leo Chen 提交者: GitHub

support list of list attribute for NPU (#31299)

* support list of list attribute for NPU

* fix compile problem

* fix reference
上级 77a0c41c
...@@ -45,6 +45,17 @@ using Attribute = boost::variant< ...@@ -45,6 +45,17 @@ using Attribute = boost::variant<
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
#ifdef PADDLE_WITH_ASCEND_CL
using NPUAttribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*, int64_t,
std::vector<BlockDesc*>, std::vector<int64_t>,
std::vector<double>, std::vector<std::vector<int64_t>>>;
using NPUAttributeMap = std::unordered_map<std::string, NPUAttribute>;
#endif
using OpCreator = std::function<OperatorBase*( using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/, const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
......
...@@ -70,7 +70,7 @@ NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) { ...@@ -70,7 +70,7 @@ NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) {
NpuOpRunner::NpuOpRunner(std::string op_type, const std::vector<Tensor> &inputs, NpuOpRunner::NpuOpRunner(std::string op_type, const std::vector<Tensor> &inputs,
const std::vector<Tensor> &outputs, const std::vector<Tensor> &outputs,
const AttributeMap &attrs) const NPUAttributeMap &attrs)
: op_type_(op_type) { : op_type_(op_type) {
attr_ = aclopCreateAttr(); attr_ = aclopCreateAttr();
AddInputs(inputs); AddInputs(inputs);
...@@ -85,7 +85,7 @@ NpuOpRunner::~NpuOpRunner() { ...@@ -85,7 +85,7 @@ NpuOpRunner::~NpuOpRunner() {
const std::string &NpuOpRunner::Type() { return op_type_; } const std::string &NpuOpRunner::Type() { return op_type_; }
NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name, NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
const Attribute &attr) { const NPUAttribute &attr) {
if (attr.type() == typeid(bool)) { if (attr.type() == typeid(bool)) {
PADDLE_ENFORCE_NPU_SUCCESS( PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrBool(attr_, name.c_str(), BOOST_GET_CONST(bool, attr))); aclopSetAttrBool(attr_, name.c_str(), BOOST_GET_CONST(bool, attr)));
...@@ -135,6 +135,16 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name, ...@@ -135,6 +135,16 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
} }
PADDLE_ENFORCE_NPU_SUCCESS( PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrListString(attr_, name.c_str(), s.size(), s.data())); aclopSetAttrListString(attr_, name.c_str(), s.size(), s.data()));
} else if (attr.type() == typeid(std::vector<std::vector<int64_t>>)) {
auto a = BOOST_GET_CONST(std::vector<std::vector<int64_t>>, attr);
std::vector<int64_t *> data;
std::vector<int> num;
for (auto &&v : a) {
data.push_back(v.data());
num.push_back(v.size());
}
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrListListInt(attr_, name.c_str(), data.size(), num.data(), data.data()));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Can not convert attribubte '%s' to convert to aclopAttr", name)); "Can not convert attribubte '%s' to convert to aclopAttr", name));
...@@ -142,7 +152,7 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name, ...@@ -142,7 +152,7 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
return *this; return *this;
} }
NpuOpRunner &NpuOpRunner::AddAttrs(const AttributeMap &attrs) { NpuOpRunner &NpuOpRunner::AddAttrs(const NPUAttributeMap &attrs) {
for (const auto &pair : attrs) { for (const auto &pair : attrs) {
AddAttr(pair.first, pair.second); AddAttr(pair.first, pair.second);
} }
......
...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#pragma once #pragma once
#include <paddle/fluid/framework/operator.h> #include <paddle/fluid/framework/operator.h>
#include <paddle/fluid/framework/type_defs.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -26,8 +28,8 @@ namespace operators { ...@@ -26,8 +28,8 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout; using DataLayout = framework::DataLayout;
using Attribute = framework::Attribute; using NPUAttribute = framework::NPUAttribute;
using AttributeMap = framework::AttributeMap; using NPUAttributeMap = framework::NPUAttributeMap;
class NpuOpRunner { class NpuOpRunner {
public: public:
...@@ -35,15 +37,15 @@ class NpuOpRunner { ...@@ -35,15 +37,15 @@ class NpuOpRunner {
explicit NpuOpRunner(std::string op_type, explicit NpuOpRunner(std::string op_type,
const std::vector<Tensor> &inputs = {}, const std::vector<Tensor> &inputs = {},
const std::vector<Tensor> &outputs = {}, const std::vector<Tensor> &outputs = {},
const AttributeMap &attrs = {}); const NPUAttributeMap &attrs = {});
~NpuOpRunner(); ~NpuOpRunner();
const std::string &Type(); const std::string &Type();
NpuOpRunner &AddAttr(const std::string &name, const Attribute &attr); NpuOpRunner &AddAttr(const std::string &name, const NPUAttribute &attr);
NpuOpRunner &AddAttrs(const AttributeMap &attrs); NpuOpRunner &AddAttrs(const NPUAttributeMap &attrs);
NpuOpRunner &AddInput(const Tensor &tensor); NpuOpRunner &AddInput(const Tensor &tensor);
...@@ -82,3 +84,4 @@ class NpuOpRunner { ...@@ -82,3 +84,4 @@ class NpuOpRunner {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册