diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index a2efcdb55cfc75a4f961533d16d454ca6d431990..986b45451fe71c81d0ba9cb4d250cea972bfee68 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -26,8 +26,10 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) +cc_library(shape_inference_map SRCS shape_inference_map.cc DEPS op_info operator) + cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) -cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info) +cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info shape_inference_map) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 4ee2c7d27561c3855059c42e1604f353c65bfa41..f04b6c503a95e64c700ab0d9b258e9118c35260b 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" +#include "paddle/framework/shape_inference_map.h" namespace paddle { namespace framework { @@ -54,9 +55,12 @@ class OpRegistry { const std::string& grad_op_type) { OperatorRegistrar reg(op_type.c_str()); reg.info.grad_op_type_ = grad_op_type; + ShapeInferenceMap::Instance().CreateOpWithKernel(reg.info, op_type); // register gradient op if (!grad_op_type.empty()) { OperatorRegistrar grad_reg(grad_op_type.c_str()); + ShapeInferenceMap::Instance().CreateOpWithKernel(grad_reg.info, + grad_op_type); } } diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index bc8af0eb3ec7e8685eb7d2734e9b8f75372d1309..ac6f238638cfd0ff025dbe4048a0ffc865e2b0e6 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/framework/attribute.h" #include "paddle/framework/ddim.h" namespace paddle { diff --git a/paddle/framework/shape_inference_map.cc b/paddle/framework/shape_inference_map.cc new file mode 100644 index 0000000000000000000000000000000000000000..1a27037221a9e53f513c32a83aa9f63a3866420d --- /dev/null +++ b/paddle/framework/shape_inference_map.cc @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/shape_inference_map.h" + +namespace paddle { +namespace framework { + +static VariableNameMap ConvertOpProtoVarsToVarNameMap( + const google::protobuf::RepeatedPtrField& op_proto_vars) { + VariableNameMap ret_val; + for (auto& var : op_proto_vars) { + ret_val[var.name()] = {}; + } + return ret_val; +} + +static ShapeInferenceMap* g_shape_inference_map = nullptr; + +ShapeInferenceMap& ShapeInferenceMap::Instance() { + if (g_shape_inference_map == nullptr) { + g_shape_inference_map = new ShapeInferenceMap(); + } + return *g_shape_inference_map; +} + +void ShapeInferenceMap::CreateOpWithKernel(const OpInfo& op_info, + const std::string& op_type) { + const VariableNameMap inputs = + ConvertOpProtoVarsToVarNameMap(op_info.Proto().inputs()); + const VariableNameMap outputs = + ConvertOpProtoVarsToVarNameMap(op_info.Proto().outputs()); + auto* op = op_info.Creator()(op_type, inputs, outputs, {}); + auto* op_with_kernel = dynamic_cast(op); + auto it = op_shape_inference_map_.find(op_type); + if (it != op_shape_inference_map_.end()) { + PADDLE_THROW("OpWithKernel(%s) is already registered for infer_shape", + op_type); + } + if (op_with_kernel != nullptr) { + op_shape_inference_map_[op_type] = op_with_kernel; + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/shape_inference_map.h b/paddle/framework/shape_inference_map.h new file mode 100644 index 0000000000000000000000000000000000000000..fb126690268b1a4ad9635df5d3eeb4b00479e6a7 --- /dev/null +++ b/paddle/framework/shape_inference_map.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/framework/op_info.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/shape_inference.h" + +namespace paddle { +namespace framework { + +class ShapeInferenceMap { + public: + static ShapeInferenceMap& Instance(); + + const OperatorBase* GetOperator(const std::string& op_type) { + auto it = op_shape_inference_map_.find(op_type); + if (it == op_shape_inference_map_.end()) { + PADDLE_THROW("op with kernel for Op(%s) is not registered", op_type); + } + return it->second; + } + + void CreateOpWithKernel(const OpInfo& op_info, const std::string& op_type); + + OperatorWithKernel* GetOpWithKernel(const std::string& op_type) { + auto it = op_shape_inference_map_.find(op_type); + if (it == op_shape_inference_map_.end()) { + return nullptr; + } + return it->second; + } + + private: + ShapeInferenceMap() = default; + DISABLE_COPY_AND_ASSIGN(ShapeInferenceMap); + + std::unordered_map op_shape_inference_map_; +}; + +} // namespace framework +} // namespace paddle