diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 286a8684127a9fcbc42e98b89828d6acb87b859c..5c3b24463ef4b0f3974b748d388e0fdc26beaae8 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -167,6 +167,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows_utils data_device_transform data_type_transform data_layout_transform) cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce) +cc_test(attribute_test SRCS attribute_test.cc DEPS attribute framework_proto proto_desc) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc device_context) diff --git a/paddle/fluid/framework/attribute.cc b/paddle/fluid/framework/attribute.cc index 63934d17f996420b24f37ca982e7c439ea3db662..cf7a7c3c9f43dde58cc356fe5dc8e7f92bc1053f 100644 --- a/paddle/fluid/framework/attribute.cc +++ b/paddle/fluid/framework/attribute.cc @@ -17,6 +17,39 @@ limitations under the License. 
*/ namespace paddle { namespace framework { +paddle::any GetAttrValue(const Attribute& attr) { + if (attr.type() == typeid(int)) { + return paddle::any(BOOST_GET_CONST(int, attr)); + } else if (attr.type() == typeid(float)) { + return paddle::any(BOOST_GET_CONST(float, attr)); + } else if (attr.type() == typeid(std::string)) { + return paddle::any(BOOST_GET_CONST(std::string, attr)); + } else if (attr.type() == typeid(std::vector<int>)) { + return paddle::any(BOOST_GET_CONST(std::vector<int>, attr)); + } else if (attr.type() == typeid(std::vector<float>)) { + return paddle::any(BOOST_GET_CONST(std::vector<float>, attr)); + } else if (attr.type() == typeid(std::vector<std::string>)) { + return paddle::any(BOOST_GET_CONST(std::vector<std::string>, attr)); + } else if (attr.type() == typeid(bool)) { + return paddle::any(BOOST_GET_CONST(bool, attr)); + } else if (attr.type() == typeid(std::vector<bool>)) { + return paddle::any(BOOST_GET_CONST(std::vector<bool>, attr)); + } else if (attr.type() == typeid(BlockDesc*)) { + return paddle::any(BOOST_GET_CONST(BlockDesc*, attr)); + } else if (attr.type() == typeid(int64_t)) { + return paddle::any(BOOST_GET_CONST(int64_t, attr)); + } else if (attr.type() == typeid(std::vector<BlockDesc*>)) { + return paddle::any(BOOST_GET_CONST(std::vector<BlockDesc*>, attr)); + } else if (attr.type() == typeid(std::vector<int64_t>)) { + return paddle::any(BOOST_GET_CONST(std::vector<int64_t>, attr)); + } else if (attr.type() == typeid(std::vector<double>)) { + return paddle::any(BOOST_GET_CONST(std::vector<double>, attr)); + } else { + PADDLE_THROW( + platform::errors::Unimplemented("Unsupported Attribute value type.")); + } +} + Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { case proto::AttrType::BOOLEAN: { diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 37d399b7779a7a3dbe743e061f74598a2bcdb377..7026cc7cf1aa3acdc27728350b7572a0aa8970f7 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -27,10 +27,15 @@ limitations under the License.
*/ #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" +#include "paddle/utils/any.h" namespace paddle { namespace framework { +paddle::any GetAttrValue(const Attribute& attr); + +Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); + template <typename T> struct ExtractAttribute { explicit ExtractAttribute(const std::string& attr_name) @@ -204,8 +209,6 @@ inline proto::AttrType AttrTypeID() { return static_cast<proto::AttrType>(tmp.which() - 1); } -Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); - class AttrReader { public: explicit AttrReader(const AttributeMap& attrs) @@ -234,6 +237,22 @@ class AttrReader { return *attr_value; } + inline const Attribute& GetAttr(const std::string& name) const { + auto it = attrs_.find(name); + bool found = it != attrs_.end(); + if (!found) { + if (default_attrs_ != nullptr) { + it = default_attrs_->find(name); + found = it != default_attrs_->end(); + } + } + PADDLE_ENFORCE_EQ(found, true, + platform::errors::NotFound( + "Attribute (%s) should be in AttributeMap.", name)); + + return it->second; + } + private: const AttributeMap& attrs_; const AttributeMap* default_attrs_; diff --git a/paddle/fluid/framework/attribute_test.cc b/paddle/fluid/framework/attribute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..27a6afb49f5e817c6c09ab5adda260059f75b4a4 --- /dev/null +++ b/paddle/fluid/framework/attribute_test.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <string> +#include <vector> + +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/program_desc.h" + +#include "gtest/gtest.h" +#include "paddle/utils/any.h" + +TEST(Attribute, GetAttrValueToAny) { + paddle::framework::Attribute x_int(100); + auto rlt_int = paddle::framework::GetAttrValue(x_int); + EXPECT_EQ(paddle::any_cast<int>(rlt_int), 100); + + float float_value = 3.14; + paddle::framework::Attribute x_float(float_value); + auto rlt_float = paddle::framework::GetAttrValue(x_float); + EXPECT_NEAR(paddle::any_cast<float>(rlt_float), 3.14, 1e-6); + + std::string str_value("test"); + paddle::framework::Attribute x_str(str_value); + auto rlt_str = paddle::framework::GetAttrValue(x_str); + EXPECT_EQ(paddle::any_cast<std::string>(rlt_str), "test"); + + std::vector<int> vec_int_var(2, 100); + paddle::framework::Attribute x_vec_int = vec_int_var; + auto rlt_vec_int = paddle::framework::GetAttrValue(x_vec_int); + auto vec_int = paddle::any_cast<std::vector<int>>(rlt_vec_int); + EXPECT_EQ(vec_int.size(), 2UL); + EXPECT_EQ(vec_int[0], 100); + EXPECT_EQ(vec_int[1], 100); + + std::vector<float> vec_float_var(2, 3.14); + paddle::framework::Attribute x_vec_float = vec_float_var; + auto rlt_vec_float = paddle::framework::GetAttrValue(x_vec_float); + auto vec_float = paddle::any_cast<std::vector<float>>(rlt_vec_float); + EXPECT_EQ(vec_float.size(), 2UL); + EXPECT_NEAR(vec_float[0], 3.14, 1e-6); + EXPECT_NEAR(vec_float[1], 3.14, 1e-6); + + std::vector<std::string> vec_str_var(2, "test"); + paddle::framework::Attribute x_vec_str = vec_str_var; + auto rlt_vec_str = paddle::framework::GetAttrValue(x_vec_str); + auto vec_str =
paddle::any_cast<std::vector<std::string>>(rlt_vec_str); + EXPECT_EQ(vec_str.size(), 2UL); + EXPECT_EQ(vec_str[0], "test"); + EXPECT_EQ(vec_str[1], "test"); + + paddle::framework::Attribute x_bool(true); + auto rlt_bool = paddle::framework::GetAttrValue(x_bool); + EXPECT_EQ(paddle::any_cast<bool>(rlt_bool), true); + + std::vector<bool> vec_bool_var(2, true); + paddle::framework::Attribute x_vec_bool = vec_bool_var; + auto rlt_vec_bool = paddle::framework::GetAttrValue(x_vec_bool); + auto vec_bool = paddle::any_cast<std::vector<bool>>(rlt_vec_bool); + EXPECT_EQ(vec_bool.size(), 2UL); + EXPECT_EQ(vec_bool[0], true); + EXPECT_EQ(vec_bool[1], true); + + paddle::framework::ProgramDesc prog; + paddle::framework::proto::BlockDesc proto_block; + paddle::framework::BlockDesc block_desc(&prog, &proto_block); + paddle::framework::Attribute x_block_desc(&block_desc); + auto rlt_block_desc = paddle::framework::GetAttrValue(x_block_desc); + auto block_desc_ptr = + paddle::any_cast<paddle::framework::BlockDesc*>(rlt_block_desc); + EXPECT_NE(block_desc_ptr, nullptr); + + std::vector<paddle::framework::BlockDesc*> vec_block_desc_var; + vec_block_desc_var.emplace_back(&block_desc); + paddle::framework::Attribute x_vec_block_desc(vec_block_desc_var); + auto rlt_vec_block_desc = paddle::framework::GetAttrValue(x_vec_block_desc); + auto vec_block_desc = + paddle::any_cast<std::vector<paddle::framework::BlockDesc*>>( + rlt_vec_block_desc); + EXPECT_EQ(vec_block_desc.size(), 1UL); + EXPECT_NE(vec_block_desc[0], nullptr); + + int64_t int64_value = 100; + paddle::framework::Attribute x_int64(int64_value); + auto rlt_int64 = paddle::framework::GetAttrValue(x_int64); + EXPECT_EQ(paddle::any_cast<int64_t>(rlt_int64), 100); + + std::vector<int64_t> vec_int64_var(2, 100); + paddle::framework::Attribute x_vec_int64 = vec_int64_var; + auto rlt_vec_int64 = paddle::framework::GetAttrValue(x_vec_int64); + auto vec_int64 = paddle::any_cast<std::vector<int64_t>>(rlt_vec_int64); + EXPECT_EQ(vec_int64.size(), 2UL); + EXPECT_EQ(vec_int64[0], 100); + EXPECT_EQ(vec_int64[1], 100); + + std::vector<double> vec_double_var(2, 3.14); + paddle::framework::Attribute x_vec_double = vec_double_var; + auto
rlt_vec_double = paddle::framework::GetAttrValue(x_vec_double); + auto vec_double = paddle::any_cast<std::vector<double>>(rlt_vec_double); + EXPECT_EQ(vec_double.size(), 2UL); + EXPECT_NEAR(vec_double[0], 3.14, 1e-6); + EXPECT_NEAR(vec_double[1], 3.14, 1e-6); +} diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 8e000ef9985bd90854fdc452d6ff56bcd428b387..40c80ec5f2d654b57a72290398e323e1ce91e156 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -40,7 +40,7 @@ limitations under the License. */ #include "paddle/fluid/platform/variant.h" #include "paddle/utils/flat_hash_map.h" -#include "paddle/pten/core/arg_map_context.h" +#include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/kernel_context.h" #include "paddle/pten/core/kernel_factory.h" @@ -454,8 +454,9 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { return ctx_.HasOutput(name); } - bool HasAttr(const std::string& name) const override { - return ctx_.HasAttr(name); + paddle::any Attr(const std::string& name) const override { + auto& attr = ctx_.GetAttr(name); + return GetAttrValue(attr); } size_t InputSize(const std::string& name) const override { diff --git a/paddle/fluid/framework/pten_utils.h b/paddle/fluid/framework/pten_utils.h index a4493f3d3e5c08ad926e56771f09635640db59e3..ab129c6313dabfecf3d7cd1968b66485e48ec211 100644 --- a/paddle/fluid/framework/pten_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -25,7 +25,7 @@ limitations under the License.
*/ #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" #include "paddle/pten/api/lib/utils/tensor_utils.h" -#include "paddle/pten/core/arg_map_context.h" +#include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/kernel_factory.h" #include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" @@ -85,6 +85,6 @@ template <> struct ConvertToPtenContext { using TYPE = pten::CPUContext; }; - + } // namespace framework } // namespace paddle diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index f6f0e1f3e26ecc56e0d48bad0a22b8054f0664e2..cd3a1755a9df4350cc7dc2638e248c9ba1f101e0 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -1,3 +1,6 @@ +# utils used for compatible for fluid op system +add_subdirectory(compat) + if(WITH_GPU) cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info) elseif(WITH_ROCM) @@ -8,7 +11,6 @@ endif() cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils) cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context) -cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce) cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce) cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector) diff --git a/paddle/pten/core/compat/CMakeLists.txt b/paddle/pten/core/compat/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..253f60daf1f890caccdeb02908c1b4fb3d6c62da --- /dev/null +++ b/paddle/pten/core/compat/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce) diff --git a/paddle/pten/core/arg_map_context.cc b/paddle/pten/core/compat/arg_map_context.cc similarity index 53% rename from paddle/pten/core/arg_map_context.cc rename to paddle/pten/core/compat/arg_map_context.cc index 
d7aea11ddf043a30d3434427d8e25dd3fc97e3ac..3914a8a684eda937cf54283f72a04bec67cf64af 100644 --- a/paddle/pten/core/arg_map_context.cc +++ b/paddle/pten/core/compat/arg_map_context.cc @@ -12,41 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/pten/core/arg_map_context.h" +#include "paddle/pten/core/compat/arg_map_context.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/string_helper.h" namespace pten { -OpArgumentMappingFnMap& OpArgumentMappingFnMap::Instance() { - static OpArgumentMappingFnMap g_op_arg_mapping_fn_map; - return g_op_arg_mapping_fn_map; -} - -bool OpArgumentMappingFnMap::Has(const std::string& op_type) const { - return fn_map_.find(op_type) != fn_map_.end(); -} - -const ArgumentMappingFn& OpArgumentMappingFnMap::Get( - const std::string& op_type) const { - auto it = fn_map_.find(op_type); - PADDLE_ENFORCE_NE( - it, - fn_map_.end(), - paddle::platform::errors::NotFound( - "Operator `%s`'s argument mapping funciton is not registered.", - op_type)); - return it->second; -} - -void OpArgumentMappingFnMap::Emplace(const std::string& op_type, - const std::string api_name, - ArgumentMappingFn fn) { - name_map_.emplace(op_type, api_name); - fn_map_.emplace(op_type, fn); -} - std::ostream& operator<<(std::ostream& os, KernelSignature signature) { os << "Kernel Signature - name: " << signature.name << "; inputs: " << paddle::string::join_strings(std::get<0>(signature.args), ", ") diff --git a/paddle/pten/core/arg_map_context.h b/paddle/pten/core/compat/arg_map_context.h similarity index 79% rename from paddle/pten/core/arg_map_context.h rename to paddle/pten/core/compat/arg_map_context.h index be9eb3af76a36704d4d05fbf2ce39bc8d4a0d37c..e7dfc0706544c9ce3f33d9e56bf406089da7f5a2 100644 --- a/paddle/pten/core/arg_map_context.h +++ b/paddle/pten/core/compat/arg_map_context.h @@ -18,6 +18,7 
@@ limitations under the License. */ #include <ostream> #include <string> +#include "paddle/utils/any.h" #include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" @@ -28,22 +29,6 @@ using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>, paddle::SmallVector<std::string>, paddle::SmallVector<std::string>>; -// TODO(chenweihang): Add more methods if needed in future -class ArgumentMappingContext { - public: - virtual ~ArgumentMappingContext() = default; - - virtual bool HasInput(const std::string& name) const = 0; - virtual bool HasOutput(const std::string& name) const = 0; - virtual bool HasAttr(const std::string& name) const = 0; - - virtual size_t InputSize(const std::string& name) const = 0; - virtual size_t OutputSize(const std::string& name) const = 0; - - virtual bool IsDenseTensorInput(const std::string& name) const = 0; - virtual bool IsSelectedRowsInput(const std::string& name) const = 0; -}; - struct KernelSignature { std::string name; KernelArgsTuple args; @@ -64,23 +49,23 @@ struct KernelSignature { std::ostream& operator<<(std::ostream& os, KernelSignature signature); -using ArgumentMappingFn = KernelSignature (*)(const ArgumentMappingContext&); - -class OpArgumentMappingFnMap { +// TODO(chenweihang): Add more methods if needed in future +class ArgumentMappingContext { public: - static OpArgumentMappingFnMap& Instance(); + virtual ~ArgumentMappingContext() = default; - bool Has(const std::string& op_type) const; + virtual bool HasInput(const std::string& name) const = 0; + virtual bool HasOutput(const std::string& name) const = 0; - const ArgumentMappingFn& Get(const std::string& op_type) const; + // now we can't use Attribute here, it will cause pten relay on + // boost::variant and BlockDesc + virtual paddle::any Attr(const std::string& name) const = 0; - void Emplace(const std::string& op_type, - const std::string api_name, - ArgumentMappingFn fn); + virtual size_t InputSize(const std::string& name) const = 0; + virtual size_t OutputSize(const std::string& name) const = 0; - private: - 
paddle::flat_hash_map<std::string, std::string> name_map_; - paddle::flat_hash_map<std::string, ArgumentMappingFn> fn_map_; + virtual bool IsDenseTensorInput(const std::string& name) const = 0; + virtual bool IsSelectedRowsInput(const std::string& name) const = 0; }; } // namespace pten diff --git a/paddle/pten/core/kernel_def.h b/paddle/pten/core/kernel_def.h index 875083cfb59e39bfd5e073e69497394dc00e9d7b..3884bb55e47b877dcae744dbffb4663d073291ff 100644 --- a/paddle/pten/core/kernel_def.h +++ b/paddle/pten/core/kernel_def.h @@ -14,16 +14,25 @@ #pragma once +#include <functional> + namespace pten { class Kernel; class KernelKey; class KernelArgsDef; class KernelContext; +class KernelSignature; +class ArgumentMappingContext; +class InferMetaContext; using KernelFn = void (*)(KernelContext* ctx); using KernelArgsDefFn = void (*)(Kernel* kernel); using KernelArgsParseFn = void (*)(const KernelKey& default_key, KernelArgsDef* args_def); +using ArgumentMappingFn = + std::function<KernelSignature(const ArgumentMappingContext&)>; +using InferMetaFn = void (*)(InferMetaContext* ctx); + } // namespace pten diff --git a/paddle/pten/core/macros.h b/paddle/pten/core/macros.h index fec67b1a3dc25d0a66b99bde51c2b33ff5cbc681..20a39fdda2ced8143d7a20a1406d4ad6fe5da80a 100644 --- a/paddle/pten/core/macros.h +++ b/paddle/pten/core/macros.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
diff --git a/paddle/pten/core/meta_tensor.h b/paddle/pten/core/meta_tensor.h index 4273aa6f85b4e584cf5241235592ab6c510db8af..442ff4137de4267e863c169df3dceb4deca2757a 100644 --- a/paddle/pten/core/meta_tensor.h +++ b/paddle/pten/core/meta_tensor.h @@ -47,7 +47,10 @@ class MetaTensor { virtual void share_lod(const MetaTensor& meta_tensor); private: + // Because the lod in compiletime and runtime is different, + // so `LoD` cannot in public methods const LoD& lod() const; + TensorBase* tensor_; }; diff --git a/paddle/pten/ops/compat/scale_args_fn.h b/paddle/pten/ops/compat/scale_args_fn.h index b9a20400f971a0477709fc052cfb4df520bcb4f9..91f0db389d9d5094e6f6d3cf978c4c35590d1d2e 100644 --- a/paddle/pten/ops/compat/scale_args_fn.h +++ b/paddle/pten/ops/compat/scale_args_fn.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/pten/core/arg_map_context.h" +#include "paddle/pten/core/compat/arg_map_context.h" namespace pten {