From ab9545aa95fb482e7b51b58e0abe2191c9ef3bea Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 4 Oct 2017 00:44:07 -0700 Subject: [PATCH] add shape_inference_map --- paddle/framework/CMakeLists.txt | 4 +- paddle/framework/op_registry.h | 4 ++ paddle/framework/shape_inference.h | 1 + paddle/framework/shape_inference_map.cc | 57 +++++++++++++++++++++++++ paddle/framework/shape_inference_map.h | 56 ++++++++++++++++++++++++ 5 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 paddle/framework/shape_inference_map.cc create mode 100644 paddle/framework/shape_inference_map.h diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index a2efcdb55cf..986b45451fe 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -26,8 +26,10 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) +cc_library(shape_inference_map SRCS shape_inference_map.cc DEPS op_info operator) + cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) -cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info) +cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info shape_inference_map) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 4ee2c7d2756..f04b6c503a9 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" +#include "paddle/framework/shape_inference_map.h" namespace paddle { namespace framework { @@ -54,9 +55,12 @@ class OpRegistry { const std::string& grad_op_type) { OperatorRegistrar reg(op_type.c_str()); reg.info.grad_op_type_ = grad_op_type; + ShapeInferenceMap::Instance().CreateOpWithKernel(reg.info, op_type); // register gradient op if (!grad_op_type.empty()) { OperatorRegistrar grad_reg(grad_op_type.c_str()); + ShapeInferenceMap::Instance().CreateOpWithKernel(grad_reg.info, + grad_op_type); } } diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index bc8af0eb3ec..ac6f238638c 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/framework/attribute.h" #include "paddle/framework/ddim.h" namespace paddle { diff --git a/paddle/framework/shape_inference_map.cc b/paddle/framework/shape_inference_map.cc new file mode 100644 index 00000000000..1a27037221a --- /dev/null +++ b/paddle/framework/shape_inference_map.cc @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/shape_inference_map.h" + +namespace paddle { +namespace framework { + +static VariableNameMap ConvertOpProtoVarsToVarNameMap( + const google::protobuf::RepeatedPtrField& op_proto_vars) { + VariableNameMap ret_val; + for (auto& var : op_proto_vars) { + ret_val[var.name()] = {}; + } + return ret_val; +} + +static ShapeInferenceMap* g_shape_inference_map = nullptr; + +ShapeInferenceMap& ShapeInferenceMap::Instance() { + if (g_shape_inference_map == nullptr) { + g_shape_inference_map = new ShapeInferenceMap(); + } + return *g_shape_inference_map; +} + +void ShapeInferenceMap::CreateOpWithKernel(const OpInfo& op_info, + const std::string& op_type) { + const VariableNameMap inputs = + ConvertOpProtoVarsToVarNameMap(op_info.Proto().inputs()); + const VariableNameMap outputs = + ConvertOpProtoVarsToVarNameMap(op_info.Proto().outputs()); + auto* op = op_info.Creator()(op_type, inputs, outputs, {}); + auto* op_with_kernel = dynamic_cast(op); + auto it = op_shape_inference_map_.find(op_type); + if (it != op_shape_inference_map_.end()) { + PADDLE_THROW("OpWithKernel(%s) is already registered for infer_shape", + op_type); + } + if (op_with_kernel != nullptr) { + op_shape_inference_map_[op_type] = op_with_kernel; + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/shape_inference_map.h b/paddle/framework/shape_inference_map.h new file mode 100644 index 00000000000..fb126690268 --- /dev/null +++ b/paddle/framework/shape_inference_map.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/framework/op_info.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/shape_inference.h" + +namespace paddle { +namespace framework { + +class ShapeInferenceMap { + public: + static ShapeInferenceMap& Instance(); + + const OperatorBase* GetOperator(const std::string& op_type) { + auto it = op_shape_inference_map_.find(op_type); + if (it == op_shape_inference_map_.end()) { + PADDLE_THROW("op with kernel for Op(%s) is not registered", op_type); + } + return it->second; + } + + void CreateOpWithKernel(const OpInfo& op_info, const std::string& op_type); + + OperatorWithKernel* GetOpWithKernel(const std::string& op_type) { + auto it = op_shape_inference_map_.find(op_type); + if (it == op_shape_inference_map_.end()) { + return nullptr; + } + return it->second; + } + + private: + ShapeInferenceMap() = default; + DISABLE_COPY_AND_ASSIGN(ShapeInferenceMap); + + std::unordered_map op_shape_inference_map_; +}; + +} // namespace framework +} // namespace paddle -- GitLab