diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 986b45451fe71c81d0ba9cb4d250cea972bfee68..a2efcdb55cfc75a4f961533d16d454ca6d431990 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) -cc_library(shape_inference_map SRCS shape_inference_map.cc DEPS op_info operator) - cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) -cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info shape_inference_map) +cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 8138ba117aac917c357a1511151bf9be23f444e0..ee02da7b4dcb809dcbf1ae6fae5d7e28a5a5a412 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" -#include "paddle/framework/shape_inference_map.h" namespace paddle { namespace framework { @@ -55,16 +54,10 @@ class OpRegistry { const std::string& grad_op_type) { OperatorRegistrar reg(op_type.c_str()); reg.info.grad_op_type_ = grad_op_type; - auto proto = reg.info.Proto(); - std::cout << "====== " << op_type << " =======" << std::endl; - std::cout << proto.SerializeAsString() << std::endl; - std::cout << "=============" << std::endl; - ShapeInferenceMap::Instance().CreateOpWithKernel(reg.info, op_type); + // register gradient op if (!grad_op_type.empty()) { OperatorRegistrar grad_reg(grad_op_type.c_str()); - ShapeInferenceMap::Instance().CreateOpWithKernel(grad_reg.info, - grad_op_type); } } diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index ac6f238638cfd0ff025dbe4048a0ffc865e2b0e6..8189823c1943e73308f18ee3ca58676b3c4ab636 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -20,6 +20,10 @@ limitations under the License. */ namespace paddle { namespace framework { +class InferShapeContextBase; + +typedef std::function InferShapeFn; + class InferShapeContextBase { public: virtual ~InferShapeContextBase() {} diff --git a/paddle/framework/shape_inference_map.cc b/paddle/framework/shape_inference_map.cc deleted file mode 100644 index bd2b8679841a9a681f87b872ea2aef3d46d041e3..0000000000000000000000000000000000000000 --- a/paddle/framework/shape_inference_map.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/framework/shape_inference_map.h" - -namespace paddle { -namespace framework { - -static VariableNameMap ConvertOpProtoVarsToVarNameMap( - const google::protobuf::RepeatedPtrField& op_proto_vars) { - VariableNameMap ret_val; - for (auto& var : op_proto_vars) { - ret_val[var.name()] = {}; - } - return ret_val; -} - -static ShapeInferenceMap* g_shape_inference_map = nullptr; - -ShapeInferenceMap& ShapeInferenceMap::Instance() { - if (g_shape_inference_map == nullptr) { - g_shape_inference_map = new ShapeInferenceMap(); - } - return *g_shape_inference_map; -} - -void ShapeInferenceMap::CreateOpWithKernel(const OpInfo& op_info, - const std::string& op_type) { - auto proto = op_info.Proto(); - std::cout << "========= " << op_type << " in======" << std::endl; - std::cout << proto.SerializeAsString() << std::endl; - std::cout << "========= " << op_type << " out======" << std::endl; - const VariableNameMap inputs = ConvertOpProtoVarsToVarNameMap(proto.inputs()); - const VariableNameMap outputs = - ConvertOpProtoVarsToVarNameMap(proto.outputs()); - auto* op = op_info.Creator()(op_type, inputs, outputs, {}); - auto* op_with_kernel = dynamic_cast(op); - auto it = op_shape_inference_map_.find(op_type); - if (it != op_shape_inference_map_.end()) { - PADDLE_THROW("OpWithKernel(%s) is already registered for infer_shape", - op_type); - } - if (op_with_kernel != nullptr) { - op_shape_inference_map_[op_type] = op_with_kernel; - } -} - -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/shape_inference_map.h b/paddle/framework/shape_inference_map.h deleted file mode 100644 index 6c7304f6c0ccf3b5e05cdc9315244874ab82a58c..0000000000000000000000000000000000000000 --- a/paddle/framework/shape_inference_map.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include "paddle/framework/op_info.h" -#include "paddle/framework/operator.h" -#include "paddle/framework/shape_inference.h" - -namespace paddle { -namespace framework { - -class ShapeInferenceMap { - public: - static ShapeInferenceMap& Instance(); - - void CreateOpWithKernel(const OpInfo& op_info, const std::string& op_type); - - OperatorWithKernel* GetOpWithKernel(const std::string& op_type) { - auto it = op_shape_inference_map_.find(op_type); - if (it == op_shape_inference_map_.end()) { - return nullptr; - } - return it->second; - } - - private: - ShapeInferenceMap() = default; - DISABLE_COPY_AND_ASSIGN(ShapeInferenceMap); - - std::unordered_map op_shape_inference_map_; -}; - -} // namespace framework -} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index e11bcc0e0f055544a8b6e4bfadfc0204abe43aea..2ad0344c094a65bbf40ced10f26685f40fab0fd3 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -223,15 +223,21 @@ All parameter, weight, gradient are variables in Paddle. desc.InitializationErrorString()); return OpRegistry::CreateOp(desc); }) - .def("infer_shape", - [](const OpDescBind &op_desc, BlockDescBind &block) { - auto &shape_inference_map = ShapeInferenceMap::Instance(); - auto *op = shape_inference_map.GetOpWithKernel(op_desc.Type()); - if (op != nullptr) { - auto ctx = CompileTimeInferShapeContext(op_desc, block); - op->InferShape(&ctx); - } - }) + .def_static("infer_shape", + [](OpDescBind &op_desc, BlockDescBind &block) { + auto op = OpRegistry::CreateOp(*op_desc.Proto()); + auto *op_with_kernel = + dynamic_cast(op.get()); + if (op_with_kernel != nullptr) { + auto ctx = CompileTimeInferShapeContext(op_desc, block); + op_with_kernel->InferShape(&ctx); + } else { + PADDLE_THROW( + "OP(%s) is not type of OperatorWithKernel, " + "should not call this function", + op_desc.Type()); + } + }) .def("backward", [](const OperatorBase &forwardOp, const std::unordered_set &no_grad_vars) { diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py index 56d3a90123fa111cac04bb6cc11d173e736f86cd..ec93aaf84370d461254f68c1b352d991680de698 100644 --- a/python/paddle/v2/framework/tests/test_infer_shape.py +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -10,11 +10,13 @@ class TestInferShape(unittest.TestCase): block = prog.block(0) self.assertIsNotNone(block) + shape = [10, 20] + # prepare input/output x1 = block.new_var("x1") - x1.set_shape([10, 20]) + x1.set_shape(shape) x2 = block.new_var("x2") - x2.set_shape([10, 20]) + x2.set_shape(shape) out = block.new_var("out") @@ -24,6 +26,40 @@ class TestInferShape(unittest.TestCase): sum_op_desc.set_input("X", ["x1", "x2"]) sum_op_desc.set_output("Out", ["out"]) - sum_op = Operator("sum", X=["x1", "x2"], Out="out") - sum_op.infer_shape(sum_op_desc, block) - print(out.shape()) + print(type(sum_op_desc)) + print(type(block)) + core.Operator.infer_shape(sum_op_desc, block) + self.assertEqual(out.shape(), shape) + + def test_mul_op(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + x_shape = [10, 20] + y_shape = [20, 30] + + # prepare input/output + x1 = block.new_var("x") + x1.set_shape(x_shape) + x2 = block.new_var("y") + x2.set_shape(y_shape) + + out = block.new_var("out") + + # prepare the operator + mul_op_desc = block.append_op() + mul_op_desc.set_type("mul") + mul_op_desc.set_input("X", ["x"]) + mul_op_desc.set_input("Y", ["y"]) + mul_op_desc.set_output("Out", ["out"]) + mul_op_desc.set_attr("x_num_col_dims", 1) + mul_op_desc.set_attr("y_num_col_dims", 1) + + core.Operator.infer_shape(mul_op_desc, block) + self.assertEqual(out.shape(), [x_shape[0], y_shape[1]]) + + +if __name__ == '__main__': + unittest.main()