提交 c176bbe4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!952 Simplify the `ZeroLikeFillZero` optimization pass

Merge pull request !952 from thlinh/dev_May6th_improve_zero_fill_like_zero
......@@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ =
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
// ops eliminate
item_tuple_eliminate_ =
......
......@@ -30,6 +30,7 @@
namespace mindspore {
namespace opt {
namespace irpass {
class SpecialOpEliminater {
public:
SpecialOpEliminater()
......@@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor {
if (y_ == nullptr || node->func_graph() == nullptr) {
return nullptr;
}
if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
}
abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
std::memset(data, 0, mem_size);
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
auto new_cnode = NewValueNode(new_tensor_ptr);
new_cnode->set_abstract(new_tensor_ptr->ToAbstract());
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
return new_cnode;
}
void Visit(const AnfNodePtr &node) override { y_ = node; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册