diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index f04b6c503a95e64c700ab0d9b258e9118c35260b..8138ba117aac917c357a1511151bf9be23f444e0 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -55,6 +55,10 @@ class OpRegistry { const std::string& grad_op_type) { OperatorRegistrar reg(op_type.c_str()); reg.info.grad_op_type_ = grad_op_type; + auto proto = reg.info.Proto(); + std::cout << "====== " << op_type << " =======" << std::endl; + std::cout << proto.SerializeAsString() << std::endl; + std::cout << "=============" << std::endl; ShapeInferenceMap::Instance().CreateOpWithKernel(reg.info, op_type); // register gradient op if (!grad_op_type.empty()) { diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 99f721cc671e5b62912b9c201db9d6b2da097003..458404af6d3c45e745801cc6a0049751e3941214 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -598,9 +598,9 @@ class OperatorWithKernel : public OperatorBase { }); } - protected: virtual void InferShape(InferShapeContextBase* ctx) const = 0; + protected: // indicate kernel DataType by input data. Defaultly all input data must be // same. virtual DataType IndicateDataType(const ExecutionContext& ctx) const { diff --git a/paddle/framework/shape_inference_map.cc b/paddle/framework/shape_inference_map.cc index 1a27037221a9e53f513c32a83aa9f63a3866420d..bd2b8679841a9a681f87b872ea2aef3d46d041e3 100644 --- a/paddle/framework/shape_inference_map.cc +++ b/paddle/framework/shape_inference_map.cc @@ -37,10 +37,13 @@ ShapeInferenceMap& ShapeInferenceMap::Instance() { void ShapeInferenceMap::CreateOpWithKernel(const OpInfo& op_info, const std::string& op_type) { - const VariableNameMap inputs = - ConvertOpProtoVarsToVarNameMap(op_info.Proto().inputs()); + auto proto = op_info.Proto(); + std::cout << "========= " << op_type << " in======" << std::endl; + std::cout << proto.SerializeAsString() << std::endl; + std::cout << "========= " << op_type << " out======" << std::endl; + const VariableNameMap inputs = ConvertOpProtoVarsToVarNameMap(proto.inputs()); const VariableNameMap outputs = - ConvertOpProtoVarsToVarNameMap(op_info.Proto().outputs()); + ConvertOpProtoVarsToVarNameMap(proto.outputs()); auto* op = op_info.Creator()(op_type, inputs, outputs, {}); auto* op_with_kernel = dynamic_cast(op); auto it = op_shape_inference_map_.find(op_type); diff --git a/paddle/framework/shape_inference_map.h b/paddle/framework/shape_inference_map.h index fb126690268b1a4ad9635df5d3eeb4b00479e6a7..6c7304f6c0ccf3b5e05cdc9315244874ab82a58c 100644 --- a/paddle/framework/shape_inference_map.h +++ b/paddle/framework/shape_inference_map.h @@ -27,14 +27,6 @@ class ShapeInferenceMap { public: static ShapeInferenceMap& Instance(); - const OperatorBase* GetOperator(const std::string& op_type) { - auto it = op_shape_inference_map_.find(op_type); - if (it == op_shape_inference_map_.end()) { - PADDLE_THROW("op with kernel for Op(%s) is not registered", op_type); - } - return it->second; - } - void CreateOpWithKernel(const OpInfo& op_info, const std::string& op_type); OperatorWithKernel* GetOpWithKernel(const std::string& op_type) { diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index f4121e9d71824296770f86c1e94c096f767dec0a..e11bcc0e0f055544a8b6e4bfadfc0204abe43aea 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -223,6 +223,15 @@ All parameter, weight, gradient are variables in Paddle. desc.InitializationErrorString()); return OpRegistry::CreateOp(desc); }) + .def("infer_shape", + [](const OpDescBind &op_desc, BlockDescBind &block) { + auto &shape_inference_map = ShapeInferenceMap::Instance(); + auto *op = shape_inference_map.GetOpWithKernel(op_desc.Type()); + if (op != nullptr) { + auto ctx = CompileTimeInferShapeContext(op_desc, block); + op->InferShape(&ctx); + } + }) .def("backward", [](const OperatorBase &forwardOp, const std::unordered_set &no_grad_vars) { diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..56d3a90123fa111cac04bb6cc11d173e736f86cd --- /dev/null +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -0,0 +1,29 @@ +import unittest +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + + +class TestInferShape(unittest.TestCase): + def test_sum_op(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + # prepare input/output + x1 = block.new_var("x1") + x1.set_shape([10, 20]) + x2 = block.new_var("x2") + x2.set_shape([10, 20]) + + out = block.new_var("out") + + # prepare the operator + sum_op_desc = block.append_op() + sum_op_desc.set_type("sum") + sum_op_desc.set_input("X", ["x1", "x2"]) + sum_op_desc.set_output("Out", ["out"]) + + sum_op = Operator("sum", X=["x1", "x2"], Out="out") + sum_op.infer_shape(sum_op_desc, block) + print(out.shape())