提交 d5ee05e6 编写于 作者: B baojun-nervana

Replaced VarIsTensor

test=develop
上级 e6bd53be
...@@ -283,7 +283,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) { ...@@ -283,7 +283,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
for (auto& var_name_item : op->Inputs()) { for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name); auto* var = scope_.FindVar(var_name);
if (var && VarIsTensor(*var)) { if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto sp = Ddim2Shape(tensor_pd->dims()); auto sp = Ddim2Shape(tensor_pd->dims());
if (std::find(var_in_.begin(), var_in_.end(), var_name) != if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
...@@ -305,7 +305,7 @@ void NgraphOperator::BuildNgNode() { ...@@ -305,7 +305,7 @@ void NgraphOperator::BuildNgNode() {
for (auto& var_name : var_out_) { for (auto& var_name : var_out_) {
if (var_node_map_->find(var_name) == var_node_map_->end()) { if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name); auto* var = scope_.FindVar(var_name);
if (var && VarIsTensor(*var)) { if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims(); auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim); auto ng_shape = Ddim2Shape(ddim);
...@@ -433,7 +433,7 @@ std::shared_ptr<std::string> NgraphOperator::GetCacheKey() { ...@@ -433,7 +433,7 @@ std::shared_ptr<std::string> NgraphOperator::GetCacheKey() {
for (auto& var_name : var_out_) { for (auto& var_name : var_out_) {
auto* var = scope_.FindVar(var_name); auto* var = scope_.FindVar(var_name);
if (var && VarIsTensor(*var)) { if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims(); auto& ddim = tensor_pd->dims();
for (int i = 0; i < ddim.size(); ++i) { for (int i = 0; i < ddim.size(); ++i) {
...@@ -469,7 +469,7 @@ void NgraphOperator::Run(const Scope& scope, ...@@ -469,7 +469,7 @@ void NgraphOperator::Run(const Scope& scope,
auto sp = var_node_map_->at(vi)->get_shape(); auto sp = var_node_map_->at(vi)->get_shape();
std::shared_ptr<ngraph::runtime::Tensor> ti; std::shared_ptr<ngraph::runtime::Tensor> ti;
auto* var = scope.FindVar(vi); auto* var = scope.FindVar(vi);
if (var && VarIsTensor(*var)) { if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()), PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor"); "Ensure ngraph tensor layout align with paddle tensor");
...@@ -518,7 +518,7 @@ void NgraphOperator::Run(const Scope& scope, ...@@ -518,7 +518,7 @@ void NgraphOperator::Run(const Scope& scope,
auto var_name = var_out_[i]; auto var_name = var_out_[i];
auto* var = scope.FindVar(var_name); auto* var = scope.FindVar(var_name);
std::shared_ptr<ngraph::runtime::Tensor> to; std::shared_ptr<ngraph::runtime::Tensor> to;
if (var && VarIsTensor(*var)) { if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var); auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
auto dd = tensor_pd->dims(); auto dd = tensor_pd->dims();
ngraph::Shape sp = Ddim2Shape(dd); ngraph::Shape sp = Ddim2Shape(dd);
......
...@@ -355,7 +355,7 @@ void OperatorBase::GenerateTemporaryNames() { ...@@ -355,7 +355,7 @@ void OperatorBase::GenerateTemporaryNames() {
} }
} }
bool VarIsTensor(const Variable& var) { static bool VarIsTensor(const Variable& var) {
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>(); return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
} }
......
...@@ -64,7 +64,6 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -64,7 +64,6 @@ inline std::string GradVarName(const std::string& var_name) {
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
bool VarIsTensor(const Variable& var);
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册