Added improvement for ZeroLikeFillZero optimization pass: The old algorithm...

Added improvement for ZeroLikeFillZero optimization pass: The old algorithm convert the PrimitivePy op (with 3 nodes) into a
new subtree with 9 nodes and after that a Renormalize pass is needed to simplified it back to a tensor. The new algorithm will
create the tensor while visiting in the pass, therefore only a single node is created and no Renormalize will be needed for this pass
(if other passes requires then Renormalize will still be called but no further infer is needed for the created tensor)
Signed-off-by: NHoai Linh Tran h00472437 <hoai.linh.tran@huawei.com>

Code review
上级 e53b43cb
......@@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ =
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
// ops eliminate
item_tuple_eliminate_ =
......
......@@ -30,6 +30,7 @@
namespace mindspore {
namespace opt {
namespace irpass {
class SpecialOpEliminater {
public:
SpecialOpEliminater()
......@@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor {
if (y_ == nullptr || node->func_graph() == nullptr) {
return nullptr;
}
if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
}
abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
std::memset(data, 0, mem_size);
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
auto new_cnode = NewValueNode(new_tensor_ptr);
new_cnode->set_abstract(new_tensor_ptr->ToAbstract());
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
return new_cnode;
}
void Visit(const AnfNodePtr &node) override { y_ = node; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册