Commit d550380e, authored by: Q qiaolongfei

add CompileTimeInferShapeContext

Parent commit: 31bdb3f3
...@@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc ...@@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
......
...@@ -34,6 +34,11 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const { ...@@ -34,6 +34,11 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const {
return it->second.get(); return it->second.get();
} }
// Returns true when a variable named `name` has been declared in this block.
bool BlockDescBind::HasVar(const std::string &name) const {
  return vars_.find(name) != vars_.end();
}
std::vector<VarDescBind *> BlockDescBind::AllVars() const { std::vector<VarDescBind *> BlockDescBind::AllVars() const {
std::vector<VarDescBind *> res; std::vector<VarDescBind *> res;
for (const auto &p : vars_) { for (const auto &p : vars_) {
......
...@@ -45,6 +45,8 @@ class BlockDescBind { ...@@ -45,6 +45,8 @@ class BlockDescBind {
VarDescBind *Var(const std::string &name_bytes) const; VarDescBind *Var(const std::string &name_bytes) const;
bool HasVar(const std::string &var_name) const;
std::vector<VarDescBind *> AllVars() const; std::vector<VarDescBind *> AllVars() const;
BlockDescBind *ParentBlock() const; BlockDescBind *ParentBlock() const;
......
...@@ -319,100 +319,82 @@ class ExecutionContext : public InferShapeContext { ...@@ -319,100 +319,82 @@ class ExecutionContext : public InferShapeContext {
class CompileTimeInferShapeContext : public InferShapeContextBase { class CompileTimeInferShapeContext : public InferShapeContextBase {
public: public:
CompileTimeInferShapeContext(const OperatorBase& op, const Scope& scope) CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
: op_(op), scope_(scope) {} : op_(op), block_(block) {}
bool HasInput(const std::string& name) const { bool HasInput(const std::string& name) const {
auto ipt = op_.Input(name); const std::vector<std::string>& input_names = op_.Input(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); PADDLE_ENFORCE_EQ(input_names.size(), 1UL, "Inputs(%s) length is not 1",
return var != nullptr; name);
return block_.HasVar(input_names[0]);
} }
bool HasOutput(const std::string& name) const { bool HasOutput(const std::string& name) const {
auto ipt = op_.Output(name); const std::vector<std::string>& output_names = op_.Output(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); PADDLE_ENFORCE_EQ(output_names.size(), 1UL, "Outputs(%s) length is not 1",
return var != nullptr; name);
return block_.HasVar(output_names[0]);
} }
bool HasInputs(const std::string& name) const { bool HasInputs(const std::string& name) const {
auto inputs = op_.Inputs(name); const std::vector<std::string>& input_names = op_.Input(name);
if (inputs.size() == 0UL) { PADDLE_ENFORCE_GT(input_names.size(), 0UL, "Inputs(%s) length is 0", name);
return false; for (auto& input : input_names) {
} if (!block_.HasVar(input)) return false;
for (auto& input : inputs) {
if (scope_.FindVar(input) == nullptr) {
return false;
}
} }
return true; return true;
} }
bool HasOutputs(const std::string& name) const { bool HasOutputs(const std::string& name) const {
auto outputs = op_.Outputs(name); const std::vector<std::string>& output_names = op_.Output(name);
if (outputs.size() == 0UL) { PADDLE_ENFORCE_GT(output_names.size(), 0UL, "Inputs(%s) length is 0", name);
return false; for (auto& output : output_names) {
} if (!block_.HasVar(name)) return false;
for (auto& output : outputs) {
if (scope_.FindVar(output) == nullptr) {
return false;
}
} }
return true; return true;
} }
DDim GetInputDim(const std::string& name) const { DDim GetInputDim(const std::string& name) const {
return GetDim(op_.Input(name)); std::vector<DDim> ddims = GetInputsDim(name);
PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Inputs(%s) length is not 1", name);
return ddims[0];
} }
void SetInputDim(const std::string& name, const DDim& dim) { void SetInputDim(const std::string& name, const DDim& dim) {
SetDim(op_.Input(name), dim); SetInputsDim(name, {dim});
} }
DDim GetOutputDim(const std::string& name) const { DDim GetOutputDim(const std::string& name) const {
return GetDim(op_.Output(name)); std::vector<DDim> ddims = GetOutputsDim(name);
PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Outputs(%s) length is not 1", name);
return ddims[0];
} }
void SetOutputDim(const std::string& name, const DDim& dim) { void SetOutputDim(const std::string& name, const DDim& dim) {
SetDim(op_.Output(name), dim); SetOutputsDim(name, {dim});
} }
AttrReader Attrs() const { return AttrReader(op_.Attrs()); } AttrReader Attrs() const { return AttrReader(op_.GetAttrMap()); }
const std::vector<std::string>& Inputs(const std::string& name) const { const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name); return op_.Input(name);
} }
const std::vector<std::string>& Outputs(const std::string& name) const { const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name); return op_.Output(name);
} }
private: private:
template <bool Allocate>
Tensor* GetTensor(const std::string& name) const {
Tensor* t = nullptr;
auto* var = scope_.FindVar(name);
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
if (Allocate) {
t = var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable(%s) should be tensor", name);
}
} else {
t = GetTensorFromVar(scope_.FindVar(name));
}
return t;
}
DDim GetDim(const std::string& name) const { DDim GetDim(const std::string& name) const {
return GetTensor<false>(name)->dims(); return framework::make_ddim(block_.Var(name)->Shape());
} }
void SetDim(const std::string& name, const DDim& dim) { void SetDim(const std::string& name, const DDim& dim) {
GetTensor<true>(name)->Resize(dim); block_.Var(name)->SetShape(framework::vectorize(dim));
} }
const OperatorBase& op_; const OpDescBind& op_;
const Scope& scope_; const BlockDescBind& block_;
}; };
class RuntimeInferShapeContext : public InferShapeContextBase { class RuntimeInferShapeContext : public InferShapeContextBase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册