diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 5d394132b7f3ddd36cebeb45f5602f13d9acdf35..a2efcdb55cfc75a4f961533d16d454ca6d431990 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) -cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 9570aedfdda332b797a8f348e0f6cf81bb2aee2f..01f50e1393606044fb20d5f782fadede46b744e3 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -34,6 +34,10 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const { return it->second.get(); } +bool BlockDescBind::HasVar(const std::string &name) const { + return vars_.find(name) != vars_.end(); +} + std::vector BlockDescBind::AllVars() const { std::vector res; for (const auto &p : vars_) { diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 59513ede33ebb41acbeb0e1acab66be5947a9a3c..b646ad5f3b67e22b933a56d88e4a1bf6e74a124e 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -43,6 +43,8 @@ class BlockDescBind { VarDescBind *Var(const std::string &name_bytes) const; + bool HasVar(const std::string &var_name) const; + std::vector AllVars() const; BlockDescBind *ParentBlock() const; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 73e53a4176db32b7cbfd79c088dadfc23037213f..d7bc9c9ffb9d5e0a7d8ea309a50623da440820da 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -22,6 +22,7 @@ limitations under the License. */ #include "op_info.h" #include "paddle/framework/attribute.h" +#include "paddle/framework/block_desc.h" #include "paddle/framework/data_type.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" @@ -317,26 +318,122 @@ class ExecutionContext : public InferShapeContext { const platform::DeviceContext& device_context_; }; +class CompileTimeInferShapeContext : public InferShapeContextBase { + public: + CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block) + : op_(op), block_(block) {} + + bool HasInput(const std::string& name) const override { + const std::vector& input_names = op_.Input(name); + auto length = input_names.size(); + PADDLE_ENFORCE_EQ(length, 1UL, + "Input(%s) should have only one value, " + "but it have %d now", + name, length); + return block_.HasVar(input_names[0]); + } + + bool HasOutput(const std::string& name) const override { + const std::vector& output_names = op_.Output(name); + auto length = output_names.size(); + PADDLE_ENFORCE_EQ(length, 1UL, + "Output(%s) should have only one value, " + "but it have %d now", + name, length); + return block_.HasVar(output_names[0]); + } + + bool HasInputs(const std::string& name) const override { + const std::vector& input_names = op_.Input(name); + PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name); + for (auto& input : input_names) { + if (!block_.HasVar(input)) return false; + } + return true; + } + + bool HasOutputs(const std::string& name) const override { + const std::vector& output_names = op_.Output(name); + PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name); + for (auto& output : output_names) { + if (!block_.HasVar(output)) return false; + } + return true; + } + + DDim GetInputDim(const std::string& name) const override { + std::vector ddims = GetInputsDim(name); + auto length = ddims.size(); + PADDLE_ENFORCE_EQ(length, 1UL, + "Input(%s) should have 1 value, " + "but it has %d now", + name, length); + return ddims[0]; + } + + void SetInputDim(const std::string& name, const DDim& dim) override { + SetInputsDim(name, {dim}); + } + + DDim GetOutputDim(const std::string& name) const override { + std::vector ddims = GetOutputsDim(name); + auto length = ddims.size(); + PADDLE_ENFORCE_EQ(length, 1UL, + "Output(%s) should have 1 value, " + "but it has %d now", + name, length); + return ddims[0]; + } + + void SetOutputDim(const std::string& name, const DDim& dim) override { + SetOutputsDim(name, {dim}); + } + + AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); } + + const std::vector& Inputs( + const std::string& name) const override { + return op_.Input(name); + } + + const std::vector& Outputs( + const std::string& name) const override { + return op_.Output(name); + } + + private: + DDim GetDim(const std::string& name) const override { + return framework::make_ddim(block_.Var(name)->Shape()); + } + + void SetDim(const std::string& name, const DDim& dim) override { + block_.Var(name)->SetShape(framework::vectorize(dim)); + } + + const OpDescBind& op_; + const BlockDescBind& block_; +}; + class RuntimeInferShapeContext : public InferShapeContextBase { public: RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) : op_(op), scope_(scope) {} - bool HasInput(const std::string& name) const { + bool HasInput(const std::string& name) const override { auto ipt = op_.Input(name); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; } - bool HasOutput(const std::string& name) const { + bool HasOutput(const std::string& name) const override { auto ipt = op_.Output(name); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; } - bool HasInputs(const std::string& name) const { + bool HasInputs(const std::string& name) const override { auto inputs = op_.Inputs(name); - if (inputs.size() == 0UL) { + if (inputs.empty()) { return false; } for (auto& input : inputs) { @@ -347,9 +444,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase { return true; } - bool HasOutputs(const std::string& name) const { + bool HasOutputs(const std::string& name) const override { auto outputs = op_.Outputs(name); - if (outputs.size() == 0UL) { + if (outputs.empty()) { return false; } for (auto& output : outputs) { @@ -360,29 +457,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase { return true; } - DDim GetInputDim(const std::string& name) const { + DDim GetInputDim(const std::string& name) const override { return GetDim(op_.Input(name)); } - void SetInputDim(const std::string& name, const DDim& dim) { + void SetInputDim(const std::string& name, const DDim& dim) override { SetDim(op_.Input(name), dim); } - DDim GetOutputDim(const std::string& name) const { + DDim GetOutputDim(const std::string& name) const override { return GetDim(op_.Output(name)); } - void SetOutputDim(const std::string& name, const DDim& dim) { + void SetOutputDim(const std::string& name, const DDim& dim) override { SetDim(op_.Output(name), dim); } - AttrReader Attrs() const { return AttrReader(op_.Attrs()); } + AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } - const std::vector& Inputs(const std::string& name) const { + const std::vector& Inputs( + const std::string& name) const override { return op_.Inputs(name); } - const std::vector& Outputs(const std::string& name) const { + const std::vector& Outputs( + const std::string& name) const override { return op_.Outputs(name); } @@ -403,11 +502,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase { return t; } - DDim GetDim(const std::string& name) const { + DDim GetDim(const std::string& name) const override { return GetTensor(name)->dims(); } - void SetDim(const std::string& name, const DDim& dim) { + void SetDim(const std::string& name, const DDim& dim) override { GetTensor(name)->Resize(dim); } @@ -513,9 +612,9 @@ class OperatorWithKernel : public OperatorBase { }); } - protected: virtual void InferShape(InferShapeContextBase* ctx) const = 0; + protected: // indicate kernel DataType by input data. Defaultly all input data must be // same. virtual DataType IndicateDataType(const ExecutionContext& ctx) const { diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index bc8af0eb3ec7e8685eb7d2734e9b8f75372d1309..74e0371e328114294d7f85932b1e551c21ff5b97 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -19,6 +19,9 @@ limitations under the License. */ namespace paddle { namespace framework { +// TODO(longfei): Once after both CompileTimeInferShapeContext and +// RuntimeInferShapeContext get merged, we can rename InferShapeContextBase into +// InferShapeContext so to replace the current InferShapeContext. class InferShapeContextBase { public: virtual ~InferShapeContextBase() {} diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index cff54b174134879a3779c7738cfc3b43a074f8d7..38ba450447386b44ee8abe71c3c8b6427bbc398c 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -230,6 +230,21 @@ All parameter, weight, gradient are variables in Paddle. desc.InitializationErrorString()); return OpRegistry::CreateOp(desc); }) + .def_static("infer_shape", + [](OpDescBind &op_desc, BlockDescBind &block) { + auto op = OpRegistry::CreateOp(*op_desc.Proto()); + auto *op_with_kernel = + dynamic_cast(op.get()); + if (op_with_kernel != nullptr) { + auto ctx = CompileTimeInferShapeContext(op_desc, block); + op_with_kernel->InferShape(&ctx); + } else { + PADDLE_THROW( + "OP(%s) is not type of OperatorWithKernel, " + "should not call this function", + op_desc.Type()); + } + }) .def("backward", [](const OperatorBase &forwardOp, const std::unordered_set &no_grad_vars) { diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..b38ec9c03740a2e69f1247c094ce56ab43fa8e32 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -0,0 +1,63 @@ +import unittest +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + + +class TestInferShape(unittest.TestCase): + def test_sum_op(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + shape = [10, 20] + + # prepare input/output + x1 = block.new_var("x1") + x1.set_shape(shape) + x2 = block.new_var("x2") + x2.set_shape(shape) + + out = block.new_var("out") + + # prepare the operator + sum_op_desc = block.append_op() + sum_op_desc.set_type("sum") + sum_op_desc.set_input("X", ["x1", "x2"]) + sum_op_desc.set_output("Out", ["out"]) + + core.Operator.infer_shape(sum_op_desc, block) + self.assertEqual(out.shape(), shape) + + def test_mul_op(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + x_shape = [10, 20] + y_shape = [20, 30] + + # prepare input/output + x1 = block.new_var("x") + x1.set_shape(x_shape) + x2 = block.new_var("y") + x2.set_shape(y_shape) + + out = block.new_var("out") + + # prepare the operator + mul_op_desc = block.append_op() + mul_op_desc.set_type("mul") + mul_op_desc.set_input("X", ["x"]) + mul_op_desc.set_input("Y", ["y"]) + mul_op_desc.set_output("Out", ["out"]) + mul_op_desc.set_attr("x_num_col_dims", 1) + mul_op_desc.set_attr("y_num_col_dims", 1) + + core.Operator.infer_shape(mul_op_desc, block) + self.assertEqual(out.shape(), [x_shape[0], y_shape[1]]) + + +if __name__ == '__main__': + unittest.main()