Commit caf4b937 authored by baojun-nervana

Added RunInferShape

test=develop
Parent 1d19eb2b
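In short: OperatorBase gains a RunInferShape hook with an empty default body, OperatorWithKernel overrides it to build a RuntimeInferShapeContext and call InferShape, and the nGraph engine can now trigger shape inference through an OperatorBase pointer instead of a dynamic_pointer_cast. A minimal, self-contained sketch of that dispatch pattern (toy stand-in types, not the actual Paddle classes):

    #include <iostream>
    #include <memory>

    // Toy stand-ins for the framework types touched by this commit; this only
    // illustrates the dispatch pattern, it is not Paddle code.
    struct Scope {};
    namespace platform { struct Place {}; }

    struct OperatorBase {
      virtual ~OperatorBase() = default;
      // Empty default: operators without kernels can be asked for shape
      // inference and simply do nothing.
      virtual void RunInferShape(const Scope&, const platform::Place&) const {}
    };

    struct OperatorWithKernel : OperatorBase {
      void RunInferShape(const Scope&, const platform::Place&) const override {
        // In the real commit this builds a RuntimeInferShapeContext and calls
        // InferShape(&ctx); here we only log.
        std::cout << "shape inference ran\n";
      }
    };

    int main() {
      Scope scope;
      platform::Place place;
      std::shared_ptr<OperatorBase> op = std::make_shared<OperatorWithKernel>();
      // No dynamic_pointer_cast<OperatorWithKernel> is needed at the call site.
      op->RunInferShape(scope, place);
      return 0;
    }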
@@ -278,43 +278,26 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphOperator::backend_ =
     ngraph::runtime::Backend::create("CPU");
 
 void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
-  RuntimeInferShapeContext infer_shape_ctx(*op, scope_);
-  std::shared_ptr<OperatorWithKernel> op_k =
-      std::dynamic_pointer_cast<OperatorWithKernel>(op);
-  op_k->InferShape(&infer_shape_ctx);
+  op->RunInferShape(scope_, place_);
   for (auto& var_name_item : op->Inputs()) {
-    std::vector<ngraph::Shape> vshape;
-    auto& var_prm_name = var_name_item.first;
-    auto var_name_size = var_name_item.second.size();
-    if (var_name_size == 1) {
-      auto dim = infer_shape_ctx.GetInputDim(var_prm_name);
-      vshape.push_back(Ddim2Shape(dim));
-    } else if (var_name_item.second.size() > 1) {
-      auto vdim = infer_shape_ctx.GetInputsDim(var_prm_name);
-      PADDLE_ENFORCE_EQ(vdim.size(), var_name_item.second.size(),
-                        "Need dim info for each var");
-      for (auto& dim : vdim) {
-        vshape.push_back(Ddim2Shape(dim));
-      }
-    } else {
-      // 0 size : conv2d Bias
-    }
-    for (size_t i = 0; i < var_name_item.second.size(); ++i) {
-      auto var_name = var_name_item.second.at(i);
+    for (auto& var_name : var_name_item.second) {
       auto* var = scope_.FindVar(var_name);
       if (var && VarIsTensor(*var)) {
+        auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
+        auto sp = Ddim2Shape(tensor_pd->dims());
         if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
             var_in_.end()) {
           if (var_node_map_->find(var_name) == var_node_map_->end()) {
             auto ng_type = var_type_map_.at(var_name);
-            auto prm = std::make_shared<ngraph::op::Parameter>(
-                ng_type, vshape.at(i), true);
+            auto prm =
+                std::make_shared<ngraph::op::Parameter>(ng_type, sp, true);
             (*var_node_map_)[var_name] = prm;
             (*var_in_node_map_)[var_name] = prm;
           }
         }
       }
     }
   }
 }
 
 void NgraphOperator::BuildNgNode() {
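With the hook in place, GetNgInputShape asks the operator to run shape inference first and then reads each input's shape straight from the tensor already held by the scope (tensor_pd->dims()), so the earlier per-parameter GetInputDim/GetInputsDim bookkeeping is gone. A hedged, self-contained sketch of that flow with toy types (the helper and member names follow the diff, but their implementations here are assumptions):

    #include <cstddef>
    #include <cstdint>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    // Toy stand-ins: a DDim is a list of signed extents, an nGraph shape a
    // list of size_t. The real types differ; this only shows the data flow.
    using ToyDDim = std::vector<int64_t>;
    using ToyShape = std::vector<std::size_t>;

    // Assumed behaviour of the Ddim2Shape helper seen in the diff: an
    // element-wise copy of the (non-negative) extents.
    ToyShape Ddim2Shape(const ToyDDim& dims) {
      return ToyShape(dims.begin(), dims.end());
    }

    // After shape inference has filled in dims, build one parameter shape per
    // graph input that is not mapped yet (mirrors the var_in_ membership and
    // var_node_map_ deduplication checks in the new loop).
    std::map<std::string, ToyShape> CollectInputShapes(
        const std::map<std::string, ToyDDim>& scope_dims,   // var name -> dims
        const std::set<std::string>& graph_inputs,          // akin to var_in_
        std::map<std::string, ToyShape>* var_node_map) {    // akin to var_node_map_
      std::map<std::string, ToyShape> created;
      for (const auto& kv : scope_dims) {
        if (graph_inputs.count(kv.first) && !var_node_map->count(kv.first)) {
          auto shape = Ddim2Shape(kv.second);
          (*var_node_map)[kv.first] = shape;
          created[kv.first] = shape;
        }
      }
      return created;
    }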
@@ -355,7 +355,7 @@ void OperatorBase::GenerateTemporaryNames() {
   }
 }
 
-static bool VarIsTensor(const Variable& var) {
+bool VarIsTensor(const Variable& var) {
   return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
 }
@@ -695,6 +695,12 @@ static void CheckTensorNANOrInf(const std::string& name,
                  "Tensor %s contains NAN", name);
 }
 
+void OperatorWithKernel::RunInferShape(const Scope& scope,
+                                       const platform::Place& place) const {
+  RuntimeInferShapeContext infer_shape_ctx(*this, scope);
+  this->InferShape(&infer_shape_ctx);
+}
+
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
   RuntimeInferShapeContext infer_shape_ctx(*this, scope);
@@ -64,6 +64,7 @@ inline std::string GradVarName(const std::string& var_name) {
 }
 
 proto::VarType::Type GetDataTypeOfVar(const Variable* var);
+bool VarIsTensor(const Variable& var);
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
 Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
@@ -128,6 +129,8 @@ class OperatorBase {
   virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
 
   void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
+  virtual void RunInferShape(const Scope& scope,
+                             const platform::Place& place) const {}
 
  protected:
   std::string type_;
@@ -348,6 +351,9 @@ class OperatorWithKernel : public OperatorBase {
     OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
   }
 
+  void RunInferShape(const Scope& scope,
+                     const platform::Place& place) const override;
+
  protected:
   virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
   virtual OpKernelType GetKernelTypeForVar(