未验证 提交 7ac2f80f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add attr method for ArgumentMappingContext (#39130)

* add attr for arg map context

* add argument fn declare

* add attr test for get attr value method

* polish details
上级 ff7f9d06
...@@ -167,6 +167,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor ...@@ -167,6 +167,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
framework_proto selected_rows_utils data_device_transform data_type_transform data_layout_transform) framework_proto selected_rows_utils data_device_transform data_type_transform data_layout_transform)
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce) cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce)
cc_test(attribute_test SRCS attribute_test.cc DEPS attribute framework_proto proto_desc)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context) device_context)
......
...@@ -17,6 +17,39 @@ limitations under the License. */ ...@@ -17,6 +17,39 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
paddle::any GetAttrValue(const Attribute& attr) {
if (attr.type() == typeid(int)) {
return paddle::any(BOOST_GET_CONST(int, attr));
} else if (attr.type() == typeid(float)) {
return paddle::any(BOOST_GET_CONST(float, attr));
} else if (attr.type() == typeid(std::string)) {
return paddle::any(BOOST_GET_CONST(std::string, attr));
} else if (attr.type() == typeid(std::vector<int>)) {
return paddle::any(BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr.type() == typeid(std::vector<float>)) {
return paddle::any(BOOST_GET_CONST(std::vector<float>, attr));
} else if (attr.type() == typeid(std::vector<std::string>)) {
return paddle::any(BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr.type() == typeid(bool)) {
return paddle::any(BOOST_GET_CONST(bool, attr));
} else if (attr.type() == typeid(std::vector<bool>)) {
return paddle::any(BOOST_GET_CONST(std::vector<bool>, attr));
} else if (attr.type() == typeid(BlockDesc*)) {
return paddle::any(BOOST_GET_CONST(BlockDesc*, attr));
} else if (attr.type() == typeid(int64_t)) {
return paddle::any(BOOST_GET_CONST(int64_t, attr));
} else if (attr.type() == typeid(std::vector<BlockDesc*>)) {
return paddle::any(BOOST_GET_CONST(std::vector<BlockDesc*>, attr));
} else if (attr.type() == typeid(std::vector<int64_t>)) {
return paddle::any(BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (attr.type() == typeid(std::vector<double>)) {
return paddle::any(BOOST_GET_CONST(std::vector<double>, attr));
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported Attribute value type."));
}
}
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case proto::AttrType::BOOLEAN: { case proto::AttrType::BOOLEAN: {
......
...@@ -27,10 +27,15 @@ limitations under the License. */ ...@@ -27,10 +27,15 @@ limitations under the License. */
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "paddle/utils/any.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
paddle::any GetAttrValue(const Attribute& attr);
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
template <typename T> template <typename T>
struct ExtractAttribute { struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name) explicit ExtractAttribute(const std::string& attr_name)
...@@ -204,8 +209,6 @@ inline proto::AttrType AttrTypeID() { ...@@ -204,8 +209,6 @@ inline proto::AttrType AttrTypeID() {
return static_cast<proto::AttrType>(tmp.which() - 1); return static_cast<proto::AttrType>(tmp.which() - 1);
} }
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader { class AttrReader {
public: public:
explicit AttrReader(const AttributeMap& attrs) explicit AttrReader(const AttributeMap& attrs)
...@@ -234,6 +237,22 @@ class AttrReader { ...@@ -234,6 +237,22 @@ class AttrReader {
return *attr_value; return *attr_value;
} }
inline const Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
bool found = it != attrs_.end();
if (!found) {
if (default_attrs_ != nullptr) {
it = default_attrs_->find(name);
found = it != default_attrs_->end();
}
}
PADDLE_ENFORCE_EQ(found, true,
platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
return it->second;
}
private: private:
const AttributeMap& attrs_; const AttributeMap& attrs_;
const AttributeMap* default_attrs_; const AttributeMap* default_attrs_;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/program_desc.h"
#include "gtest/gtest.h"
#include "paddle/utils/any.h"
TEST(Attribute, GetAttrValueToAny) {
paddle::framework::Attribute x_int(100);
auto rlt_int = paddle::framework::GetAttrValue(x_int);
EXPECT_EQ(paddle::any_cast<int>(rlt_int), 100);
float float_value = 3.14;
paddle::framework::Attribute x_float(float_value);
auto rlt_float = paddle::framework::GetAttrValue(x_float);
EXPECT_NEAR(paddle::any_cast<float>(rlt_float), 3.14, 1e-6);
std::string str_value("test");
paddle::framework::Attribute x_str(str_value);
auto rlt_str = paddle::framework::GetAttrValue(x_str);
EXPECT_EQ(paddle::any_cast<std::string>(rlt_str), "test");
std::vector<int> vec_int_var(2, 100);
paddle::framework::Attribute x_vec_int = vec_int_var;
auto rlt_vec_int = paddle::framework::GetAttrValue(x_vec_int);
auto vec_int = paddle::any_cast<std::vector<int>>(rlt_vec_int);
EXPECT_EQ(vec_int.size(), 2UL);
EXPECT_EQ(vec_int[0], 100);
EXPECT_EQ(vec_int[1], 100);
std::vector<float> vec_float_var(2, 3.14);
paddle::framework::Attribute x_vec_float = vec_float_var;
auto rlt_vec_float = paddle::framework::GetAttrValue(x_vec_float);
auto vec_float = paddle::any_cast<std::vector<float>>(rlt_vec_float);
EXPECT_EQ(vec_float.size(), 2UL);
EXPECT_NEAR(vec_float[0], 3.14, 1e-6);
EXPECT_NEAR(vec_float[1], 3.14, 1e-6);
std::vector<std::string> vec_str_var(2, "test");
paddle::framework::Attribute x_vec_str = vec_str_var;
auto rlt_vec_str = paddle::framework::GetAttrValue(x_vec_str);
auto vec_str = paddle::any_cast<std::vector<std::string>>(rlt_vec_str);
EXPECT_EQ(vec_str.size(), 2UL);
EXPECT_EQ(vec_str[0], "test");
EXPECT_EQ(vec_str[1], "test");
paddle::framework::Attribute x_bool(true);
auto rlt_bool = paddle::framework::GetAttrValue(x_bool);
EXPECT_EQ(paddle::any_cast<bool>(rlt_bool), true);
std::vector<bool> vec_bool_var(2, true);
paddle::framework::Attribute x_vec_bool = vec_bool_var;
auto rlt_vec_bool = paddle::framework::GetAttrValue(x_vec_bool);
auto vec_bool = paddle::any_cast<std::vector<bool>>(rlt_vec_bool);
EXPECT_EQ(vec_bool.size(), 2UL);
EXPECT_EQ(vec_bool[0], true);
EXPECT_EQ(vec_bool[1], true);
paddle::framework::ProgramDesc prog;
paddle::framework::proto::BlockDesc proto_block;
paddle::framework::BlockDesc block_desc(&prog, &proto_block);
paddle::framework::Attribute x_block_desc(&block_desc);
auto rlt_block_desc = paddle::framework::GetAttrValue(x_block_desc);
auto block_desc_ptr =
paddle::any_cast<paddle::framework::BlockDesc*>(rlt_block_desc);
EXPECT_NE(block_desc_ptr, nullptr);
std::vector<paddle::framework::BlockDesc*> vec_block_desc_var;
vec_block_desc_var.emplace_back(&block_desc);
paddle::framework::Attribute x_vec_block_desc(vec_block_desc_var);
auto rlt_vec_block_desc = paddle::framework::GetAttrValue(x_vec_block_desc);
auto vec_block_desc =
paddle::any_cast<std::vector<paddle::framework::BlockDesc*>>(
rlt_vec_block_desc);
EXPECT_EQ(vec_block_desc.size(), 1UL);
EXPECT_NE(vec_block_desc[0], nullptr);
int64_t int64_value = 100;
paddle::framework::Attribute x_int64(int64_value);
auto rlt_int64 = paddle::framework::GetAttrValue(x_int64);
EXPECT_EQ(paddle::any_cast<int64_t>(rlt_int64), 100);
std::vector<int64_t> vec_int64_var(2, 100);
paddle::framework::Attribute x_vec_int64 = vec_int64_var;
auto rlt_vec_int64 = paddle::framework::GetAttrValue(x_vec_int64);
auto vec_int64 = paddle::any_cast<std::vector<int64_t>>(rlt_vec_int64);
EXPECT_EQ(vec_int64.size(), 2UL);
EXPECT_EQ(vec_int64[0], 100);
EXPECT_EQ(vec_int64[1], 100);
std::vector<double> vec_double_var(2, 3.14);
paddle::framework::Attribute x_vec_double = vec_double_var;
auto rlt_vec_double = paddle::framework::GetAttrValue(x_vec_double);
auto vec_double = paddle::any_cast<std::vector<double>>(rlt_vec_double);
EXPECT_EQ(vec_double.size(), 2UL);
EXPECT_NEAR(vec_double[0], 3.14, 1e-6);
EXPECT_NEAR(vec_double[1], 3.14, 1e-6);
}
...@@ -40,7 +40,7 @@ limitations under the License. */ ...@@ -40,7 +40,7 @@ limitations under the License. */
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/flat_hash_map.h"
#include "paddle/pten/core/arg_map_context.h" #include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/kernel_context.h" #include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_factory.h"
...@@ -454,8 +454,9 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { ...@@ -454,8 +454,9 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.HasOutput(name); return ctx_.HasOutput(name);
} }
bool HasAttr(const std::string& name) const override { paddle::any Attr(const std::string& name) const override {
return ctx_.HasAttr(name); auto& attr = ctx_.GetAttr(name);
return GetAttrValue(attr);
} }
size_t InputSize(const std::string& name) const override { size_t InputSize(const std::string& name) const override {
......
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/core/arg_map_context.h" #include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
...@@ -85,6 +85,6 @@ template <> ...@@ -85,6 +85,6 @@ template <>
struct ConvertToPtenContext<platform::CPUDeviceContext> { struct ConvertToPtenContext<platform::CPUDeviceContext> {
using TYPE = pten::CPUContext; using TYPE = pten::CPUContext;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
# utils used for compatible for fluid op system
add_subdirectory(compat)
if(WITH_GPU) if(WITH_GPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info) cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
elseif(WITH_ROCM) elseif(WITH_ROCM)
...@@ -8,7 +11,6 @@ endif() ...@@ -8,7 +11,6 @@ endif()
cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils) cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils)
cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context) cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context)
cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)
cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce) cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector) cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector)
......
cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)
...@@ -12,41 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,41 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/pten/core/arg_map_context.h" #include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace pten { namespace pten {
OpArgumentMappingFnMap& OpArgumentMappingFnMap::Instance() {
static OpArgumentMappingFnMap g_op_arg_mapping_fn_map;
return g_op_arg_mapping_fn_map;
}
bool OpArgumentMappingFnMap::Has(const std::string& op_type) const {
return fn_map_.find(op_type) != fn_map_.end();
}
const ArgumentMappingFn& OpArgumentMappingFnMap::Get(
const std::string& op_type) const {
auto it = fn_map_.find(op_type);
PADDLE_ENFORCE_NE(
it,
fn_map_.end(),
paddle::platform::errors::NotFound(
"Operator `%s`'s argument mapping funciton is not registered.",
op_type));
return it->second;
}
void OpArgumentMappingFnMap::Emplace(const std::string& op_type,
const std::string api_name,
ArgumentMappingFn fn) {
name_map_.emplace(op_type, api_name);
fn_map_.emplace(op_type, fn);
}
std::ostream& operator<<(std::ostream& os, KernelSignature signature) { std::ostream& operator<<(std::ostream& os, KernelSignature signature) {
os << "Kernel Signature - name: " << signature.name << "; inputs: " os << "Kernel Signature - name: " << signature.name << "; inputs: "
<< paddle::string::join_strings(std::get<0>(signature.args), ", ") << paddle::string::join_strings(std::get<0>(signature.args), ", ")
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <string> #include <string>
#include <tuple> #include <tuple>
#include "paddle/utils/any.h"
#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
...@@ -28,22 +29,6 @@ using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>, ...@@ -28,22 +29,6 @@ using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>, paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>>; paddle::SmallVector<std::string>>;
// TODO(chenweihang): Add more methods if needed in future
class ArgumentMappingContext {
public:
virtual ~ArgumentMappingContext() = default;
virtual bool HasInput(const std::string& name) const = 0;
virtual bool HasOutput(const std::string& name) const = 0;
virtual bool HasAttr(const std::string& name) const = 0;
virtual size_t InputSize(const std::string& name) const = 0;
virtual size_t OutputSize(const std::string& name) const = 0;
virtual bool IsDenseTensorInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
};
struct KernelSignature { struct KernelSignature {
std::string name; std::string name;
KernelArgsTuple args; KernelArgsTuple args;
...@@ -64,23 +49,23 @@ struct KernelSignature { ...@@ -64,23 +49,23 @@ struct KernelSignature {
std::ostream& operator<<(std::ostream& os, KernelSignature signature); std::ostream& operator<<(std::ostream& os, KernelSignature signature);
using ArgumentMappingFn = KernelSignature (*)(const ArgumentMappingContext&); // TODO(chenweihang): Add more methods if needed in future
class ArgumentMappingContext {
class OpArgumentMappingFnMap {
public: public:
static OpArgumentMappingFnMap& Instance(); virtual ~ArgumentMappingContext() = default;
bool Has(const std::string& op_type) const; virtual bool HasInput(const std::string& name) const = 0;
virtual bool HasOutput(const std::string& name) const = 0;
const ArgumentMappingFn& Get(const std::string& op_type) const; // now we can't use Attribute here, it will cause pten relay on
// boost::variant and BlockDesc
virtual paddle::any Attr(const std::string& name) const = 0;
void Emplace(const std::string& op_type, virtual size_t InputSize(const std::string& name) const = 0;
const std::string api_name, virtual size_t OutputSize(const std::string& name) const = 0;
ArgumentMappingFn fn);
private: virtual bool IsDenseTensorInput(const std::string& name) const = 0;
paddle::flat_hash_map<std::string, std::string> name_map_; virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
paddle::flat_hash_map<std::string, ArgumentMappingFn> fn_map_;
}; };
} // namespace pten } // namespace pten
...@@ -14,16 +14,25 @@ ...@@ -14,16 +14,25 @@
#pragma once #pragma once
#include <functional>
namespace pten { namespace pten {
class Kernel; class Kernel;
class KernelKey; class KernelKey;
class KernelArgsDef; class KernelArgsDef;
class KernelContext; class KernelContext;
class KernelSignature;
class ArgumentMappingContext;
class InferMetaContext;
using KernelFn = void (*)(KernelContext* ctx); using KernelFn = void (*)(KernelContext* ctx);
using KernelArgsDefFn = void (*)(Kernel* kernel); using KernelArgsDefFn = void (*)(Kernel* kernel);
using KernelArgsParseFn = void (*)(const KernelKey& default_key, using KernelArgsParseFn = void (*)(const KernelKey& default_key,
KernelArgsDef* args_def); KernelArgsDef* args_def);
using ArgumentMappingFn =
std::function<KernelSignature(const ArgumentMappingContext&)>;
using InferMetaFn = void (*)(InferMetaContext* ctx);
} // namespace pten } // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
......
...@@ -47,7 +47,10 @@ class MetaTensor { ...@@ -47,7 +47,10 @@ class MetaTensor {
virtual void share_lod(const MetaTensor& meta_tensor); virtual void share_lod(const MetaTensor& meta_tensor);
private: private:
// Because the lod in compiletime and runtime is different,
// so `LoD` cannot in public methods
const LoD& lod() const; const LoD& lod() const;
TensorBase* tensor_; TensorBase* tensor_;
}; };
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/pten/core/arg_map_context.h" #include "paddle/pten/core/compat/arg_map_context.h"
namespace pten { namespace pten {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册