diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 7cbd77ec1f597437f0824c7b42ef71f6bd638366..454e8d54d4cb7e0195addb6e7ffbc132b8e29c50 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,13 +12,15 @@ cc_test(variable_test SRCS variable_test.cc) cc_library(scope SRCS scope.cc) cc_test(scope_test SRCS scope_test.cc DEPS scope) +cc_library(attribute SRCS attribute.cc) + proto_library(attribute_proto SRCS attribute.proto) proto_library(op_proto SRCS op_proto.proto DEPS attribute_proto) proto_library(op_desc SRCS op_desc.proto DEPS attribute_proto) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) -cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope) +cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope attribute) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc new file mode 100644 index 0000000000000000000000000000000000000000..4c5790693b7e48396e945d09f4fdc72b86aa5978 --- /dev/null +++ b/paddle/framework/attribute.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/attribute.h" + +#include + +namespace paddle { +namespace framework { + +template <> +AttrType AttrTypeID() { + return INT; +} +template <> +AttrType AttrTypeID() { + return FLOAT; +} +template <> +AttrType AttrTypeID() { + return STRING; +} +template <> +AttrType AttrTypeID>() { + return INTS; +} +template <> +AttrType AttrTypeID>() { + return FLOATS; +} +template <> +AttrType AttrTypeID>() { + return STRINGS; +} + +Attribute GetAttrValue(const AttrDesc& attr_desc) { + switch (attr_desc.type()) { + case paddle::framework::AttrType::INT: { + return attr_desc.i(); + } + case paddle::framework::AttrType::FLOAT: { + return attr_desc.f(); + } + case paddle::framework::AttrType::STRING: { + return attr_desc.s(); + } + case paddle::framework::AttrType::INTS: { + std::vector val(attr_desc.ints_size()); + for (int i = 0; i < attr_desc.ints_size(); ++i) { + val[i] = attr_desc.ints(i); + } + return val; + } + case paddle::framework::AttrType::FLOATS: { + std::vector val(attr_desc.floats_size()); + for (int i = 0; i < attr_desc.floats_size(); ++i) { + val[i] = attr_desc.floats(i); + } + return val; + } + case paddle::framework::AttrType::STRINGS: { + std::vector val(attr_desc.strings_size()); + for (int i = 0; i < attr_desc.strings_size(); ++i) { + val[i] = attr_desc.strings(i); + } + return val; + } + } + PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); + return boost::blank(); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 8f3e898ec552df7e102b1dad8c654be08f92e889..fcd8eceb3528f1510da323dc99de80b5672d2573 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -27,12 +27,6 @@ limitations under the License. */ namespace paddle { namespace framework { -// helper class to set attribute type -struct AttrTypeHelper { - template - static void SetAttrType(AttrProto* attr); -}; - // this class not only make proto but also init attribute checkers. class OpProtoAndCheckerMaker { public: