diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 9c39430835d37d5dfbe4031f29e5a6216ed8b67f..1db042c6fc8b6c4ea7c3854ea4b1cd016deeb0b6 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,13 +12,15 @@ cc_test(variable_test SRCS variable_test.cc) cc_library(scope SRCS scope.cc) cc_test(scope_test SRCS scope_test.cc DEPS scope) -proto_library(attr_type SRCS attr_type.proto) -proto_library(op_proto SRCS op_proto.proto DEPS attr_type) -proto_library(op_desc SRCS op_desc.proto DEPS attr_type) +proto_library(attribute_proto SRCS attribute.proto) +proto_library(op_proto SRCS op_proto.proto DEPS attribute_proto) +proto_library(op_desc SRCS op_desc.proto DEPS attribute_proto) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) -cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope) +cc_library(attribute SRCS attribute.cc DEPS op_desc op_proto) + +cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope attribute) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator) @@ -26,7 +28,7 @@ cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op) -py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) +py_proto_compile(framework_py_proto SRCS attribute.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc new file mode 100644 index 0000000000000000000000000000000000000000..4c5790693b7e48396e945d09f4fdc72b86aa5978 --- /dev/null +++ b/paddle/framework/attribute.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/attribute.h" + +#include + +namespace paddle { +namespace framework { + +template <> +AttrType AttrTypeID() { + return INT; +} +template <> +AttrType AttrTypeID() { + return FLOAT; +} +template <> +AttrType AttrTypeID() { + return STRING; +} +template <> +AttrType AttrTypeID>() { + return INTS; +} +template <> +AttrType AttrTypeID>() { + return FLOATS; +} +template <> +AttrType AttrTypeID>() { + return STRINGS; +} + +Attribute GetAttrValue(const AttrDesc& attr_desc) { + switch (attr_desc.type()) { + case paddle::framework::AttrType::INT: { + return attr_desc.i(); + } + case paddle::framework::AttrType::FLOAT: { + return attr_desc.f(); + } + case paddle::framework::AttrType::STRING: { + return attr_desc.s(); + } + case paddle::framework::AttrType::INTS: { + std::vector val(attr_desc.ints_size()); + for (int i = 0; i < attr_desc.ints_size(); ++i) { + val[i] = attr_desc.ints(i); + } + return val; + } + case paddle::framework::AttrType::FLOATS: { + std::vector val(attr_desc.floats_size()); + for (int i = 0; i < attr_desc.floats_size(); ++i) { + val[i] = attr_desc.floats(i); + } + return val; + } + case paddle::framework::AttrType::STRINGS: { + std::vector val(attr_desc.strings_size()); + for (int i = 0; i < attr_desc.strings_size(); ++i) { + val[i] = attr_desc.strings(i); + } + return val; + } + } + PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); + return boost::blank(); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attribute.h similarity index 95% rename from paddle/framework/attr_checker.h rename to paddle/framework/attribute.h index ea5614a45f3a77a851358aff80abbc276c9972ba..72a654bda550ee75811f365403f6eeb5284a102e 100644 --- a/paddle/framework/attr_checker.h +++ b/paddle/framework/attribute.h @@ -6,6 +6,9 @@ #include #include #include + +#include "paddle/framework/attribute.pb.h" +#include "paddle/framework/op_desc.pb.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -14,8 +17,14 @@ namespace framework { typedef boost::variant, std::vector, std::vector> Attribute; + typedef std::unordered_map AttributeMap; +template +AttrType AttrTypeID(); + +Attribute GetAttrValue(const AttrDesc& attr_desc); + // check whether a value(attribute) fit a certain limit template class LargerThanChecker { diff --git a/paddle/framework/attr_type.proto b/paddle/framework/attribute.proto similarity index 100% rename from paddle/framework/attr_type.proto rename to paddle/framework/attribute.proto diff --git a/paddle/framework/op_desc.proto b/paddle/framework/op_desc.proto index 89497f3c16bc28aa93b25a83c1f2eccafdf1c5b4..5954dd89155ec5d5e99a33f4688b705780a6582d 100644 --- a/paddle/framework/op_desc.proto +++ b/paddle/framework/op_desc.proto @@ -15,7 +15,7 @@ limitations under the License. */ syntax="proto2"; package paddle.framework; -import "attr_type.proto"; +import "attribute.proto"; // AttrDesc is used to describe Attributes of an Operator. It contain's // name, type, and value of Attribute. diff --git a/paddle/framework/op_proto.proto b/paddle/framework/op_proto.proto index 366c84e53dc29e41eefbaef0a6452e01c4fe37bd..60661cf7a8c7721b894fe648a7b8ca0f8279cf23 100644 --- a/paddle/framework/op_proto.proto +++ b/paddle/framework/op_proto.proto @@ -21,7 +21,7 @@ limitations under the License. */ syntax="proto2"; package paddle.framework; -import "attr_type.proto"; +import "attribute.proto"; // Attribute protocol message for 3rd-party language binding. // It will store the Op support what attribute and what type. diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 1d14535c50b542733663a6900a8b5f2033290ea6..1caa02a2a1d046778f875d04eeaef957be741302 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -14,37 +14,8 @@ limitations under the License. */ #include -namespace paddle { -namespace framework { - -template <> -void AttrTypeHelper::SetAttrType(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::INT); -} - -template <> -void AttrTypeHelper::SetAttrType(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::FLOAT); -} - -template <> -void AttrTypeHelper::SetAttrType(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::STRING); -} +#include -template <> -void AttrTypeHelper::SetAttrType>(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::INTS); -} - -template <> -void AttrTypeHelper::SetAttrType>(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::FLOATS); -} - -template <> -void AttrTypeHelper::SetAttrType>(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::STRINGS); -} -} // namespace framework +namespace paddle { +namespace framework {} // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 24ce7930f110f9b7b398f879713158d96c7712da..6c26183818a9d6996e3d3ce2af74ba36f4711eca 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -19,7 +19,7 @@ limitations under the License. */ #include #include #include -#include "paddle/framework/attr_checker.h" +#include "paddle/framework/attribute.h" #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/scope.h" @@ -27,49 +27,6 @@ limitations under the License. */ namespace paddle { namespace framework { -// helper class to set attribute type -struct AttrTypeHelper { - template - static void SetAttrType(AttrProto* attr); - - static Attribute GetAttrValue(const AttrDesc& attr_desc) { - switch (attr_desc.type()) { - case paddle::framework::AttrType::INT: { - return attr_desc.i(); - } - case paddle::framework::AttrType::FLOAT: { - return attr_desc.f(); - } - case paddle::framework::AttrType::STRING: { - return attr_desc.s(); - } - case paddle::framework::AttrType::INTS: { - std::vector val(attr_desc.ints_size()); - for (int i = 0; i < attr_desc.ints_size(); ++i) { - val[i] = attr_desc.ints(i); - } - return val; - } - case paddle::framework::AttrType::FLOATS: { - std::vector val(attr_desc.floats_size()); - for (int i = 0; i < attr_desc.floats_size(); ++i) { - val[i] = attr_desc.floats(i); - } - return val; - } - case paddle::framework::AttrType::STRINGS: { - std::vector val(attr_desc.strings_size()); - for (int i = 0; i < attr_desc.strings_size(); ++i) { - val[i] = attr_desc.strings(i); - } - return val; - } - } - PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); - return boost::blank(); - } -}; - // this class not only make proto but also init attribute checkers. class OpProtoAndCheckerMaker { public: @@ -136,7 +93,7 @@ class OpProtoAndCheckerMaker { *attr->mutable_name() = name; *attr->mutable_comment() = comment; attr->set_generated(generated); - AttrTypeHelper::SetAttrType(attr); + attr->set_type(AttrTypeID()); return op_checker_->AddAttrChecker(name); } @@ -297,7 +254,7 @@ class OpRegistry { AttributeMap attrs; for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = AttrTypeHelper::GetAttrValue(attr); + attrs[attr.name()] = GetAttrValue(attr); } return CreateOp(op_desc.type(), inputs, outputs, attrs); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 6786ad080fd0fd26a572735290f6ac6d9fdab857..d42e21c0a235791db42076555d0568ff8f4acbe2 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -20,7 +20,7 @@ limitations under the License. */ #include #include -#include "paddle/framework/attr_checker.h" +#include "paddle/framework/attribute.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h"