From 5917e09cde86401005261914964eca4ef54de193 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 4 Oct 2017 15:26:19 -0700 Subject: [PATCH] tmp work --- paddle/framework/op_registry.h | 4 +++ paddle/framework/operator.h | 2 +- paddle/framework/shape_inference_map.cc | 9 ++++-- paddle/framework/shape_inference_map.h | 8 ----- paddle/pybind/pybind.cc | 9 ++++++ .../v2/framework/tests/test_infer_shape.py | 29 +++++++++++++++++++ 6 files changed, 49 insertions(+), 12 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_infer_shape.py diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index f04b6c503..8138ba117 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -55,6 +55,10 @@ class OpRegistry { const std::string& grad_op_type) { OperatorRegistrar reg(op_type.c_str()); reg.info.grad_op_type_ = grad_op_type; + auto proto = reg.info.Proto(); + std::cout << "====== " << op_type << " =======" << std::endl; + std::cout << proto.SerializeAsString() << std::endl; + std::cout << "=============" << std::endl; ShapeInferenceMap::Instance().CreateOpWithKernel(reg.info, op_type); // register gradient op if (!grad_op_type.empty()) { diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 99f721cc6..458404af6 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -598,9 +598,9 @@ class OperatorWithKernel : public OperatorBase { }); } - protected: virtual void InferShape(InferShapeContextBase* ctx) const = 0; + protected: // indicate kernel DataType by input data. Defaultly all input data must be // same. virtual DataType IndicateDataType(const ExecutionContext& ctx) const { diff --git a/paddle/framework/shape_inference_map.cc b/paddle/framework/shape_inference_map.cc index 1a2703722..bd2b86798 100644 --- a/paddle/framework/shape_inference_map.cc +++ b/paddle/framework/shape_inference_map.cc @@ -37,10 +37,13 @@ ShapeInferenceMap& ShapeInferenceMap::Instance() { void ShapeInferenceMap::CreateOpWithKernel(const OpInfo& op_info, const std::string& op_type) { - const VariableNameMap inputs = - ConvertOpProtoVarsToVarNameMap(op_info.Proto().inputs()); + auto proto = op_info.Proto(); + std::cout << "========= " << op_type << " in======" << std::endl; + std::cout << proto.SerializeAsString() << std::endl; + std::cout << "========= " << op_type << " out======" << std::endl; + const VariableNameMap inputs = ConvertOpProtoVarsToVarNameMap(proto.inputs()); const VariableNameMap outputs = - ConvertOpProtoVarsToVarNameMap(op_info.Proto().outputs()); + ConvertOpProtoVarsToVarNameMap(proto.outputs()); auto* op = op_info.Creator()(op_type, inputs, outputs, {}); auto* op_with_kernel = dynamic_cast(op); auto it = op_shape_inference_map_.find(op_type); diff --git a/paddle/framework/shape_inference_map.h b/paddle/framework/shape_inference_map.h index fb1266902..6c7304f6c 100644 --- a/paddle/framework/shape_inference_map.h +++ b/paddle/framework/shape_inference_map.h @@ -27,14 +27,6 @@ class ShapeInferenceMap { public: static ShapeInferenceMap& Instance(); - const OperatorBase* GetOperator(const std::string& op_type) { - auto it = op_shape_inference_map_.find(op_type); - if (it == op_shape_inference_map_.end()) { - PADDLE_THROW("op with kernel for Op(%s) is not registered", op_type); - } - return it->second; - } - void CreateOpWithKernel(const OpInfo& op_info, const std::string& op_type); OperatorWithKernel* GetOpWithKernel(const std::string& op_type) { diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index f4121e9d7..e11bcc0e0 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -223,6 +223,15 @@ All parameter, weight, gradient are variables in Paddle. desc.InitializationErrorString()); return OpRegistry::CreateOp(desc); }) + .def("infer_shape", + [](const OpDescBind &op_desc, BlockDescBind &block) { + auto &shape_inference_map = ShapeInferenceMap::Instance(); + auto *op = shape_inference_map.GetOpWithKernel(op_desc.Type()); + if (op != nullptr) { + auto ctx = CompileTimeInferShapeContext(op_desc, block); + op->InferShape(&ctx); + } + }) .def("backward", [](const OperatorBase &forwardOp, const std::unordered_set &no_grad_vars) { diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py new file mode 100644 index 000000000..56d3a9012 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -0,0 +1,29 @@ +import unittest +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + + +class TestInferShape(unittest.TestCase): + def test_sum_op(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + # prepare input/output + x1 = block.new_var("x1") + x1.set_shape([10, 20]) + x2 = block.new_var("x2") + x2.set_shape([10, 20]) + + out = block.new_var("out") + + # prepare the operator + sum_op_desc = block.append_op() + sum_op_desc.set_type("sum") + sum_op_desc.set_input("X", ["x1", "x2"]) + sum_op_desc.set_output("Out", ["out"]) + + sum_op = Operator("sum", X=["x1", "x2"], Out="out") + sum_op.infer_shape(sum_op_desc, block) + print(out.shape()) -- GitLab