提交 d5ee05e6 编写于 作者: B baojun-nervana

Replaced VarIsTensor

test=develop
上级 e6bd53be
......@@ -283,7 +283,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name);
if (var && VarIsTensor(*var)) {
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto sp = Ddim2Shape(tensor_pd->dims());
if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
......@@ -305,7 +305,7 @@ void NgraphOperator::BuildNgNode() {
for (auto& var_name : var_out_) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name);
if (var && VarIsTensor(*var)) {
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim);
......@@ -433,7 +433,7 @@ std::shared_ptr<std::string> NgraphOperator::GetCacheKey() {
for (auto& var_name : var_out_) {
auto* var = scope_.FindVar(var_name);
if (var && VarIsTensor(*var)) {
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
for (int i = 0; i < ddim.size(); ++i) {
......@@ -469,7 +469,7 @@ void NgraphOperator::Run(const Scope& scope,
auto sp = var_node_map_->at(vi)->get_shape();
std::shared_ptr<ngraph::runtime::Tensor> ti;
auto* var = scope.FindVar(vi);
if (var && VarIsTensor(*var)) {
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor");
......@@ -518,7 +518,7 @@ void NgraphOperator::Run(const Scope& scope,
auto var_name = var_out_[i];
auto* var = scope.FindVar(var_name);
std::shared_ptr<ngraph::runtime::Tensor> to;
if (var && VarIsTensor(*var)) {
if (var && var->IsType<LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
auto dd = tensor_pd->dims();
ngraph::Shape sp = Ddim2Shape(dd);
......
......@@ -355,7 +355,7 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
bool VarIsTensor(const Variable& var) {
static bool VarIsTensor(const Variable& var) {
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
}
......
......@@ -64,7 +64,6 @@ inline std::string GradVarName(const std::string& var_name) {
}
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
bool VarIsTensor(const Variable& var);
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册