Unverified Commit 7ac2f80f authored by Chen Weihang, committed by GitHub

[PTen] Add attr method for ArgumentMappingContext (#39130)

* add attr for arg map context

* add argument fn declare

* add attr test for get attr value method

* polish details
Parent ff7f9d06
......@@ -167,6 +167,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
framework_proto selected_rows_utils data_device_transform data_type_transform data_layout_transform)
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce)
cc_test(attribute_test SRCS attribute_test.cc DEPS attribute framework_proto proto_desc)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context)
......
......@@ -17,6 +17,39 @@ limitations under the License. */
namespace paddle {
namespace framework {
paddle::any GetAttrValue(const Attribute& attr) {
  if (attr.type() == typeid(int)) {
    return paddle::any(BOOST_GET_CONST(int, attr));
  } else if (attr.type() == typeid(float)) {
    return paddle::any(BOOST_GET_CONST(float, attr));
  } else if (attr.type() == typeid(std::string)) {
    return paddle::any(BOOST_GET_CONST(std::string, attr));
  } else if (attr.type() == typeid(std::vector<int>)) {
    return paddle::any(BOOST_GET_CONST(std::vector<int>, attr));
  } else if (attr.type() == typeid(std::vector<float>)) {
    return paddle::any(BOOST_GET_CONST(std::vector<float>, attr));
  } else if (attr.type() == typeid(std::vector<std::string>)) {
    return paddle::any(BOOST_GET_CONST(std::vector<std::string>, attr));
  } else if (attr.type() == typeid(bool)) {
    return paddle::any(BOOST_GET_CONST(bool, attr));
  } else if (attr.type() == typeid(std::vector<bool>)) {
    return paddle::any(BOOST_GET_CONST(std::vector<bool>, attr));
  } else if (attr.type() == typeid(BlockDesc*)) {
    return paddle::any(BOOST_GET_CONST(BlockDesc*, attr));
  } else if (attr.type() == typeid(int64_t)) {
    return paddle::any(BOOST_GET_CONST(int64_t, attr));
  } else if (attr.type() == typeid(std::vector<BlockDesc*>)) {
    return paddle::any(BOOST_GET_CONST(std::vector<BlockDesc*>, attr));
  } else if (attr.type() == typeid(std::vector<int64_t>)) {
    return paddle::any(BOOST_GET_CONST(std::vector<int64_t>, attr));
  } else if (attr.type() == typeid(std::vector<double>)) {
    return paddle::any(BOOST_GET_CONST(std::vector<double>, attr));
  } else {
    PADDLE_THROW(
        platform::errors::Unimplemented("Unsupported Attribute value type."));
  }
}

Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
  switch (attr_desc.type()) {
    case proto::AttrType::BOOLEAN: {
......
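For readers unfamiliar with the pattern above: GetAttrValue erases the boost::variant-backed Attribute into a paddle::any, and consumers recover the concrete type with paddle::any_cast. Below is a minimal standalone sketch of the same round trip, written with std::variant/std::any purely for illustration (Paddle itself uses boost::variant and its own paddle::any backport; all names here are illustrative):

// Standalone illustration of the variant -> any erasure round trip.
#include <any>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

using Attr = std::variant<int, float, std::string, std::vector<int>>;

std::any ToAny(const Attr& attr) {
  // std::visit dispatches on the held alternative, analogous to the
  // typeid chain in GetAttrValue above.
  return std::visit([](const auto& v) { return std::any(v); }, attr);
}

int main() {
  std::any erased = ToAny(Attr{42});
  std::cout << std::any_cast<int>(erased) << std::endl;  // prints 42
  return 0;
}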
......@@ -27,10 +27,15 @@ limitations under the License. */
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/utils/any.h"
namespace paddle {
namespace framework {
paddle::any GetAttrValue(const Attribute& attr);
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
......@@ -204,8 +209,6 @@ inline proto::AttrType AttrTypeID() {
return static_cast<proto::AttrType>(tmp.which() - 1);
}
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs)
......@@ -234,6 +237,22 @@ class AttrReader {
    return *attr_value;
  }

  inline const Attribute& GetAttr(const std::string& name) const {
    auto it = attrs_.find(name);
    bool found = it != attrs_.end();
    if (!found) {
      if (default_attrs_ != nullptr) {
        it = default_attrs_->find(name);
        found = it != default_attrs_->end();
      }
    }
    PADDLE_ENFORCE_EQ(found, true,
                      platform::errors::NotFound(
                          "Attribute (%s) should be in AttributeMap.", name));
    return it->second;
  }

 private:
  const AttributeMap& attrs_;
  const AttributeMap* default_attrs_;
......
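The new GetAttr above performs a two-level lookup: the op's own attrs_ first, then the shared default_attrs_ map, failing only when both miss. A minimal standalone sketch of that fallback pattern follows; FindWithFallback is a hypothetical name, and the real method inlines this logic against AttributeMap:

#include <map>
#include <stdexcept>
#include <string>

// Look up `name` in `primary`; if absent, consult the optional
// `fallback` map; throw only when both lookups miss.
template <typename Map>
const typename Map::mapped_type& FindWithFallback(const Map& primary,
                                                  const Map* fallback,
                                                  const std::string& name) {
  auto it = primary.find(name);
  if (it != primary.end()) return it->second;
  if (fallback != nullptr) {
    auto fit = fallback->find(name);
    if (fit != fallback->end()) return fit->second;
  }
  throw std::out_of_range("Attribute (" + name + ") should be in AttributeMap.");
}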
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/program_desc.h"
#include "gtest/gtest.h"
#include "paddle/utils/any.h"
TEST(Attribute, GetAttrValueToAny) {
  paddle::framework::Attribute x_int(100);
  auto rlt_int = paddle::framework::GetAttrValue(x_int);
  EXPECT_EQ(paddle::any_cast<int>(rlt_int), 100);

  float float_value = 3.14;
  paddle::framework::Attribute x_float(float_value);
  auto rlt_float = paddle::framework::GetAttrValue(x_float);
  EXPECT_NEAR(paddle::any_cast<float>(rlt_float), 3.14, 1e-6);

  std::string str_value("test");
  paddle::framework::Attribute x_str(str_value);
  auto rlt_str = paddle::framework::GetAttrValue(x_str);
  EXPECT_EQ(paddle::any_cast<std::string>(rlt_str), "test");

  std::vector<int> vec_int_var(2, 100);
  paddle::framework::Attribute x_vec_int = vec_int_var;
  auto rlt_vec_int = paddle::framework::GetAttrValue(x_vec_int);
  auto vec_int = paddle::any_cast<std::vector<int>>(rlt_vec_int);
  EXPECT_EQ(vec_int.size(), 2UL);
  EXPECT_EQ(vec_int[0], 100);
  EXPECT_EQ(vec_int[1], 100);

  std::vector<float> vec_float_var(2, 3.14);
  paddle::framework::Attribute x_vec_float = vec_float_var;
  auto rlt_vec_float = paddle::framework::GetAttrValue(x_vec_float);
  auto vec_float = paddle::any_cast<std::vector<float>>(rlt_vec_float);
  EXPECT_EQ(vec_float.size(), 2UL);
  EXPECT_NEAR(vec_float[0], 3.14, 1e-6);
  EXPECT_NEAR(vec_float[1], 3.14, 1e-6);

  std::vector<std::string> vec_str_var(2, "test");
  paddle::framework::Attribute x_vec_str = vec_str_var;
  auto rlt_vec_str = paddle::framework::GetAttrValue(x_vec_str);
  auto vec_str = paddle::any_cast<std::vector<std::string>>(rlt_vec_str);
  EXPECT_EQ(vec_str.size(), 2UL);
  EXPECT_EQ(vec_str[0], "test");
  EXPECT_EQ(vec_str[1], "test");

  paddle::framework::Attribute x_bool(true);
  auto rlt_bool = paddle::framework::GetAttrValue(x_bool);
  EXPECT_EQ(paddle::any_cast<bool>(rlt_bool), true);

  std::vector<bool> vec_bool_var(2, true);
  paddle::framework::Attribute x_vec_bool = vec_bool_var;
  auto rlt_vec_bool = paddle::framework::GetAttrValue(x_vec_bool);
  auto vec_bool = paddle::any_cast<std::vector<bool>>(rlt_vec_bool);
  EXPECT_EQ(vec_bool.size(), 2UL);
  EXPECT_EQ(vec_bool[0], true);
  EXPECT_EQ(vec_bool[1], true);

  paddle::framework::ProgramDesc prog;
  paddle::framework::proto::BlockDesc proto_block;
  paddle::framework::BlockDesc block_desc(&prog, &proto_block);
  paddle::framework::Attribute x_block_desc(&block_desc);
  auto rlt_block_desc = paddle::framework::GetAttrValue(x_block_desc);
  auto block_desc_ptr =
      paddle::any_cast<paddle::framework::BlockDesc*>(rlt_block_desc);
  EXPECT_NE(block_desc_ptr, nullptr);

  std::vector<paddle::framework::BlockDesc*> vec_block_desc_var;
  vec_block_desc_var.emplace_back(&block_desc);
  paddle::framework::Attribute x_vec_block_desc(vec_block_desc_var);
  auto rlt_vec_block_desc = paddle::framework::GetAttrValue(x_vec_block_desc);
  auto vec_block_desc =
      paddle::any_cast<std::vector<paddle::framework::BlockDesc*>>(
          rlt_vec_block_desc);
  EXPECT_EQ(vec_block_desc.size(), 1UL);
  EXPECT_NE(vec_block_desc[0], nullptr);

  int64_t int64_value = 100;
  paddle::framework::Attribute x_int64(int64_value);
  auto rlt_int64 = paddle::framework::GetAttrValue(x_int64);
  EXPECT_EQ(paddle::any_cast<int64_t>(rlt_int64), 100);

  std::vector<int64_t> vec_int64_var(2, 100);
  paddle::framework::Attribute x_vec_int64 = vec_int64_var;
  auto rlt_vec_int64 = paddle::framework::GetAttrValue(x_vec_int64);
  auto vec_int64 = paddle::any_cast<std::vector<int64_t>>(rlt_vec_int64);
  EXPECT_EQ(vec_int64.size(), 2UL);
  EXPECT_EQ(vec_int64[0], 100);
  EXPECT_EQ(vec_int64[1], 100);

  std::vector<double> vec_double_var(2, 3.14);
  paddle::framework::Attribute x_vec_double = vec_double_var;
  auto rlt_vec_double = paddle::framework::GetAttrValue(x_vec_double);
  auto vec_double = paddle::any_cast<std::vector<double>>(rlt_vec_double);
  EXPECT_EQ(vec_double.size(), 2UL);
  EXPECT_NEAR(vec_double[0], 3.14, 1e-6);
  EXPECT_NEAR(vec_double[1], 3.14, 1e-6);
}
......@@ -40,7 +40,7 @@ limitations under the License. */
#include "paddle/fluid/platform/variant.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h"
......@@ -454,8 +454,9 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
    return ctx_.HasOutput(name);
  }

  bool HasAttr(const std::string& name) const override {
    return ctx_.HasAttr(name);
  }

  paddle::any Attr(const std::string& name) const override {
    auto& attr = ctx_.GetAttr(name);
    return GetAttrValue(attr);
  }

  size_t InputSize(const std::string& name) const override {
......
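With Attr() in place, an argument mapping function can branch on an attribute's value rather than just its presence. A hedged sketch of such a consumer follows; the op name, attribute, and signature are hypothetical, not a real registration, and the KernelSignature constructor usage is assumed from pten's compat signature files:

// Hypothetical mapping fn: the attribute arrives type-erased in
// paddle::any and is recovered with paddle::any_cast.
pten::KernelSignature FooOpArgumentMapping(
    const pten::ArgumentMappingContext& ctx) {
  auto use_bias = paddle::any_cast<bool>(ctx.Attr("use_bias"));
  if (use_bias) {
    return pten::KernelSignature("foo", {"X"}, {"scale", "bias"}, {"Out"});
  }
  return pten::KernelSignature("foo", {"X"}, {"scale"}, {"Out"});
}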
......@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
......
# utils used for compatibility with the fluid op system
add_subdirectory(compat)
if(WITH_GPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
elseif(WITH_ROCM)
......@@ -8,7 +11,6 @@ endif()
cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils)
cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context)
cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)
cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector)
......
cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)
......@@ -12,41 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
namespace pten {
OpArgumentMappingFnMap& OpArgumentMappingFnMap::Instance() {
  static OpArgumentMappingFnMap g_op_arg_mapping_fn_map;
  return g_op_arg_mapping_fn_map;
}

bool OpArgumentMappingFnMap::Has(const std::string& op_type) const {
  return fn_map_.find(op_type) != fn_map_.end();
}

const ArgumentMappingFn& OpArgumentMappingFnMap::Get(
    const std::string& op_type) const {
  auto it = fn_map_.find(op_type);
  PADDLE_ENFORCE_NE(
      it,
      fn_map_.end(),
      paddle::platform::errors::NotFound(
          "Operator `%s`'s argument mapping function is not registered.",
          op_type));
  return it->second;
}

void OpArgumentMappingFnMap::Emplace(const std::string& op_type,
                                     const std::string api_name,
                                     ArgumentMappingFn fn) {
  name_map_.emplace(op_type, api_name);
  fn_map_.emplace(op_type, fn);
}

std::ostream& operator<<(std::ostream& os, KernelSignature signature) {
  os << "Kernel Signature - name: " << signature.name << "; inputs: "
     << paddle::string::join_strings(std::get<0>(signature.args), ", ")
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <string>
#include <tuple>
#include "paddle/utils/any.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
......@@ -28,22 +29,6 @@ using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>>;
// TODO(chenweihang): Add more methods if needed in future
class ArgumentMappingContext {
 public:
  virtual ~ArgumentMappingContext() = default;

  virtual bool HasInput(const std::string& name) const = 0;
  virtual bool HasOutput(const std::string& name) const = 0;
  virtual bool HasAttr(const std::string& name) const = 0;

  virtual size_t InputSize(const std::string& name) const = 0;
  virtual size_t OutputSize(const std::string& name) const = 0;

  virtual bool IsDenseTensorInput(const std::string& name) const = 0;
  virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
};

struct KernelSignature {
  std::string name;
  KernelArgsTuple args;
......@@ -64,23 +49,23 @@ struct KernelSignature {
std::ostream& operator<<(std::ostream& os, KernelSignature signature);

using ArgumentMappingFn = KernelSignature (*)(const ArgumentMappingContext&);

class OpArgumentMappingFnMap {
 public:
  static OpArgumentMappingFnMap& Instance();

  bool Has(const std::string& op_type) const;

  const ArgumentMappingFn& Get(const std::string& op_type) const;

  void Emplace(const std::string& op_type,
               const std::string api_name,
               ArgumentMappingFn fn);

 private:
  paddle::flat_hash_map<std::string, std::string> name_map_;
  paddle::flat_hash_map<std::string, ArgumentMappingFn> fn_map_;
};

// TODO(chenweihang): Add more methods if needed in future
class ArgumentMappingContext {
 public:
  virtual ~ArgumentMappingContext() = default;

  virtual bool HasInput(const std::string& name) const = 0;
  virtual bool HasOutput(const std::string& name) const = 0;

  // We can't use Attribute here: it would make pten rely on
  // boost::variant and BlockDesc.
  virtual paddle::any Attr(const std::string& name) const = 0;

  virtual size_t InputSize(const std::string& name) const = 0;
  virtual size_t OutputSize(const std::string& name) const = 0;

  virtual bool IsDenseTensorInput(const std::string& name) const = 0;
  virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
};
} // namespace pten
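One payoff of returning paddle::any: pten-side code, including tests, can implement this interface without depending on fluid's boost::variant-based Attribute. A hedged sketch of a minimal test double (class name and stubbed values are illustrative):

class FakeArgumentMappingContext : public pten::ArgumentMappingContext {
 public:
  bool HasInput(const std::string& name) const override { return name == "X"; }
  bool HasOutput(const std::string& name) const override {
    return name == "Out";
  }
  // No boost::variant needed: values are boxed directly into paddle::any.
  paddle::any Attr(const std::string& name) const override {
    return paddle::any(true);  // every attribute reads as `true` in this stub
  }
  size_t InputSize(const std::string& name) const override { return 1; }
  size_t OutputSize(const std::string& name) const override { return 1; }
  bool IsDenseTensorInput(const std::string& name) const override {
    return true;
  }
  bool IsSelectedRowsInput(const std::string& name) const override {
    return false;
  }
};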
......@@ -14,16 +14,25 @@
#pragma once

#include <functional>

namespace pten {

class Kernel;
class KernelKey;
class KernelArgsDef;
class KernelContext;
class KernelSignature;
class ArgumentMappingContext;
class InferMetaContext;

using KernelFn = void (*)(KernelContext* ctx);
using KernelArgsDefFn = void (*)(Kernel* kernel);
using KernelArgsParseFn = void (*)(const KernelKey& default_key,
                                   KernelArgsDef* args_def);
using ArgumentMappingFn =
    std::function<KernelSignature(const ArgumentMappingContext&)>;
using InferMetaFn = void (*)(InferMetaContext* ctx);

} // namespace pten
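Note that ArgumentMappingFn is declared here as a std::function rather than the raw function pointer previously used in arg_map_context.h, so capturing lambdas qualify as mapping fns alongside free functions. A hedged sketch (DummyMapping and its signature are illustrative, and the KernelSignature constructor usage is assumed):

// Both a free function and a lambda satisfy the std::function alias.
pten::KernelSignature DummyMapping(const pten::ArgumentMappingContext&) {
  return pten::KernelSignature("dummy", {"X"}, {}, {"Out"});
}

pten::ArgumentMappingFn fn_from_ptr = &DummyMapping;
pten::ArgumentMappingFn fn_from_lambda =
    [](const pten::ArgumentMappingContext& ctx) { return DummyMapping(ctx); };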
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
......@@ -47,7 +47,10 @@ class MetaTensor {
  virtual void share_lod(const MetaTensor& meta_tensor);

 private:
  // Because the lod differs between compile time and runtime,
  // `LoD` cannot appear in the public methods.
  const LoD& lod() const;

  TensorBase* tensor_;
};
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/compat/arg_map_context.h"
namespace pten {
......