提交 f8b5d54c 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #4569 from jacquesqiao/add_compile_time_infershape

Add compile time infershape
...@@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc ...@@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
......
...@@ -34,6 +34,10 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const { ...@@ -34,6 +34,10 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const {
return it->second.get(); return it->second.get();
} }
// Reports whether a variable called `name` is registered in this block's
// variable table (vars_).
bool BlockDescBind::HasVar(const std::string &name) const {
  const auto found = vars_.find(name);
  return found != vars_.end();
}
std::vector<VarDescBind *> BlockDescBind::AllVars() const { std::vector<VarDescBind *> BlockDescBind::AllVars() const {
std::vector<VarDescBind *> res; std::vector<VarDescBind *> res;
for (const auto &p : vars_) { for (const auto &p : vars_) {
......
...@@ -43,6 +43,8 @@ class BlockDescBind { ...@@ -43,6 +43,8 @@ class BlockDescBind {
VarDescBind *Var(const std::string &name_bytes) const; VarDescBind *Var(const std::string &name_bytes) const;
bool HasVar(const std::string &var_name) const;
std::vector<VarDescBind *> AllVars() const; std::vector<VarDescBind *> AllVars() const;
BlockDescBind *ParentBlock() const; BlockDescBind *ParentBlock() const;
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "op_info.h" #include "op_info.h"
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/data_type.h" #include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
...@@ -317,26 +318,122 @@ class ExecutionContext : public InferShapeContext { ...@@ -317,26 +318,122 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
}; };
// Shape-inference context that operates on descriptions rather than runtime
// variables: operator input/output slots are looked up in an OpDescBind, and
// variable shapes are read from / written to the variables of a BlockDescBind.
class CompileTimeInferShapeContext : public InferShapeContextBase {
 public:
  // Both `op` and `block` are stored by reference, so they must outlive this
  // context.
  CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
      : op_(op), block_(block) {}

  // Returns true when the variable bound to input slot `name` exists in the
  // block. Enforces that exactly one variable is bound to the slot.
  bool HasInput(const std::string& name) const override {
    const std::vector<std::string>& input_names = op_.Input(name);
    auto length = input_names.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
    return block_.HasVar(input_names[0]);
  }

  // Returns true when the variable bound to output slot `name` exists in the
  // block. Enforces that exactly one variable is bound to the slot.
  bool HasOutput(const std::string& name) const override {
    const std::vector<std::string>& output_names = op_.Output(name);
    auto length = output_names.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have only one value, "
                      "but it have %d now",
                      name, length);
    return block_.HasVar(output_names[0]);
  }

  // Returns true when every variable bound to input slot `name` exists in the
  // block. Enforces that the slot is not empty.
  bool HasInputs(const std::string& name) const override {
    const std::vector<std::string>& input_names = op_.Input(name);
    PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name);
    for (auto& input : input_names) {
      if (!block_.HasVar(input)) return false;
    }
    return true;
  }

  // Returns true when every variable bound to output slot `name` exists in
  // the block. Enforces that the slot is not empty.
  bool HasOutputs(const std::string& name) const override {
    const std::vector<std::string>& output_names = op_.Output(name);
    // The message names "Outputs": this check is about an output slot, not an
    // input slot.
    PADDLE_ENFORCE(!output_names.empty(), "Outputs(%s) length is 0", name);
    for (auto& output : output_names) {
      if (!block_.HasVar(output)) return false;
    }
    return true;
  }

  // Returns the shape of the single variable bound to input slot `name`.
  // NOTE(review): GetInputsDim is not defined in this class; presumably it is
  // provided by InferShapeContextBase -- confirm.
  DDim GetInputDim(const std::string& name) const override {
    std::vector<DDim> ddims = GetInputsDim(name);
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
    return ddims[0];
  }

  // Sets the shape for input slot `name` by forwarding a one-element list to
  // SetInputsDim.
  void SetInputDim(const std::string& name, const DDim& dim) override {
    SetInputsDim(name, {dim});
  }

  // Returns the shape of the single variable bound to output slot `name`.
  DDim GetOutputDim(const std::string& name) const override {
    std::vector<DDim> ddims = GetOutputsDim(name);
    auto length = ddims.size();
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output(%s) should have 1 value, "
                      "but it has %d now",
                      name, length);
    return ddims[0];
  }

  // Sets the shape for output slot `name` by forwarding a one-element list to
  // SetOutputsDim.
  void SetOutputDim(const std::string& name, const DDim& dim) override {
    SetOutputsDim(name, {dim});
  }

  // Exposes the attribute map of the operator description.
  AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }

  // Names of the variables bound to input slot `name`.
  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
    return op_.Input(name);
  }

  // Names of the variables bound to output slot `name`.
  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
    return op_.Output(name);
  }

 private:
  // Reads the shape recorded on the block variable called `name`.
  DDim GetDim(const std::string& name) const override {
    return framework::make_ddim(block_.Var(name)->Shape());
  }

  // Writes `dim` onto the block variable called `name`. The block is held by
  // const reference; the write goes through the non-const pointer that
  // BlockDescBind::Var() returns.
  void SetDim(const std::string& name, const DDim& dim) override {
    block_.Var(name)->SetShape(framework::vectorize(dim));
  }

  const OpDescBind& op_;
  const BlockDescBind& block_;
};
class RuntimeInferShapeContext : public InferShapeContextBase { class RuntimeInferShapeContext : public InferShapeContextBase {
public: public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const { bool HasInput(const std::string& name) const override {
auto ipt = op_.Input(name); auto ipt = op_.Input(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
bool HasOutput(const std::string& name) const { bool HasOutput(const std::string& name) const override {
auto ipt = op_.Output(name); auto ipt = op_.Output(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
bool HasInputs(const std::string& name) const { bool HasInputs(const std::string& name) const override {
auto inputs = op_.Inputs(name); auto inputs = op_.Inputs(name);
if (inputs.size() == 0UL) { if (inputs.empty()) {
return false; return false;
} }
for (auto& input : inputs) { for (auto& input : inputs) {
...@@ -347,9 +444,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -347,9 +444,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return true; return true;
} }
bool HasOutputs(const std::string& name) const { bool HasOutputs(const std::string& name) const override {
auto outputs = op_.Outputs(name); auto outputs = op_.Outputs(name);
if (outputs.size() == 0UL) { if (outputs.empty()) {
return false; return false;
} }
for (auto& output : outputs) { for (auto& output : outputs) {
...@@ -360,29 +457,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -360,29 +457,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return true; return true;
} }
DDim GetInputDim(const std::string& name) const { DDim GetInputDim(const std::string& name) const override {
return GetDim(op_.Input(name)); return GetDim(op_.Input(name));
} }
void SetInputDim(const std::string& name, const DDim& dim) { void SetInputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Input(name), dim); SetDim(op_.Input(name), dim);
} }
DDim GetOutputDim(const std::string& name) const { DDim GetOutputDim(const std::string& name) const override {
return GetDim(op_.Output(name)); return GetDim(op_.Output(name));
} }
void SetOutputDim(const std::string& name, const DDim& dim) { void SetOutputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Output(name), dim); SetDim(op_.Output(name), dim);
} }
AttrReader Attrs() const { return AttrReader(op_.Attrs()); } AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(const std::string& name) const { const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Inputs(name); return op_.Inputs(name);
} }
const std::vector<std::string>& Outputs(const std::string& name) const { const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Outputs(name); return op_.Outputs(name);
} }
...@@ -403,11 +502,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -403,11 +502,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return t; return t;
} }
DDim GetDim(const std::string& name) const { DDim GetDim(const std::string& name) const override {
return GetTensor<false>(name)->dims(); return GetTensor<false>(name)->dims();
} }
void SetDim(const std::string& name, const DDim& dim) { void SetDim(const std::string& name, const DDim& dim) override {
GetTensor<true>(name)->Resize(dim); GetTensor<true>(name)->Resize(dim);
} }
...@@ -513,9 +612,9 @@ class OperatorWithKernel : public OperatorBase { ...@@ -513,9 +612,9 @@ class OperatorWithKernel : public OperatorBase {
}); });
} }
protected:
virtual void InferShape(InferShapeContextBase* ctx) const = 0; virtual void InferShape(InferShapeContextBase* ctx) const = 0;
protected:
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. Defaultly all input data must be
// same. // same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const { virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
......
...@@ -19,6 +19,9 @@ limitations under the License. */ ...@@ -19,6 +19,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// TODO(longfei): Once both CompileTimeInferShapeContext and
// RuntimeInferShapeContext are merged, we can rename InferShapeContextBase to
// InferShapeContext, replacing the current InferShapeContext.
class InferShapeContextBase { class InferShapeContextBase {
public: public:
virtual ~InferShapeContextBase() {} virtual ~InferShapeContextBase() {}
......
...@@ -230,6 +230,21 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -230,6 +230,21 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString()); desc.InitializationErrorString());
return OpRegistry::CreateOp(desc); return OpRegistry::CreateOp(desc);
}) })
.def_static("infer_shape",
            // Runs shape inference for `op_desc` against the variables of
            // `block`. An operator is instantiated from the description's
            // proto only to reach its InferShape; operators that are not
            // OperatorWithKernel are rejected.
            [](OpDescBind &op_desc, BlockDescBind &block) {
              auto created_op = OpRegistry::CreateOp(*op_desc.Proto());
              auto *kernel_op =
                  dynamic_cast<OperatorWithKernel *>(created_op.get());
              if (kernel_op == nullptr) {
                PADDLE_THROW(
                    "OP(%s) is not type of OperatorWithKernel, "
                    "should not call this function",
                    op_desc.Type());
              } else {
                CompileTimeInferShapeContext shape_ctx(op_desc, block);
                kernel_op->InferShape(&shape_ctx);
              }
            })
.def("backward", .def("backward",
[](const OperatorBase &forwardOp, [](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) { const std::unordered_set<std::string> &no_grad_vars) {
......
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestInferShape(unittest.TestCase):
    """Checks that core.Operator.infer_shape fills in output variable shapes."""

    def test_sum_op(self):
        """The output of `sum` takes the shape shared by its two inputs."""
        prog = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(prog)
        block = prog.block(0)
        self.assertIsNotNone(block)

        shape = [10, 20]

        # Input variables, each given the same shape; the output variable is
        # created without one.
        input_vars = []
        for var_name in ("x1", "x2"):
            input_var = block.new_var(var_name)
            input_var.set_shape(shape)
            input_vars.append(input_var)
        out = block.new_var("out")

        # Describe the operator and wire it to the variables by name.
        op_desc = block.append_op()
        op_desc.set_type("sum")
        op_desc.set_input("X", ["x1", "x2"])
        op_desc.set_output("Out", ["out"])

        core.Operator.infer_shape(op_desc, block)
        self.assertEqual(out.shape(), shape)

    def test_mul_op(self):
        """The output of `mul` is [rows of X, columns of Y]."""
        prog = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(prog)
        block = prog.block(0)
        self.assertIsNotNone(block)

        x_shape = [10, 20]
        y_shape = [20, 30]

        # Input variables with their shapes; the output variable is created
        # without one.
        lhs = block.new_var("x")
        lhs.set_shape(x_shape)
        rhs = block.new_var("y")
        rhs.set_shape(y_shape)
        out = block.new_var("out")

        # Describe the operator, its slots and its attributes.
        op_desc = block.append_op()
        op_desc.set_type("mul")
        op_desc.set_input("X", ["x"])
        op_desc.set_input("Y", ["y"])
        op_desc.set_output("Out", ["out"])
        op_desc.set_attr("x_num_col_dims", 1)
        op_desc.set_attr("y_num_col_dims", 1)

        core.Operator.infer_shape(op_desc, block)
        self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册