提交 475f62f6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!499 pynative support topk and print op

Merge pull request !499 from JoyLvliang/pynative-support-topk-and-print
......@@ -135,10 +135,11 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
}
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
const session::KernelGraph *graph) {
session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
// assign memory for input nodes
RunOpAssignInputMemory(input_tensors, graph);
AssignStaticMemoryValueNode(graph);
for (const auto &cnode : graph->execution_order()) {
// assign memory for output nodes
RunOpAssignOutputMemory(cnode);
......
......@@ -46,7 +46,7 @@ class KernelRuntime {
virtual ~KernelRuntime();
virtual bool Init() = 0;
virtual void AssignMemory(session::KernelGraph *graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
virtual bool Run(session::KernelGraph *graph);
virtual bool DumpData(session::KernelGraph *graph);
virtual bool RunTask(const session::KernelGraph *graph);
......
......@@ -222,6 +222,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
optimizer->AddPassManager(ir_fusion_pm);
(void)optimizer->Optimize(kernel_graph);
......
......@@ -25,7 +25,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \
.partial_flag(True) \
.input(0, "ref", False, "required", "all") \
.input(1, "value", False, "required", "all") \
.output(0, "output_ref", False, "required", "all") \
.output(0, "ref", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
......
......@@ -210,6 +210,10 @@ class Print(PrimitiveWithInfer):
def __init__(self):
pass
def __call__(self, *args):
for arg in args:
print(arg)
def infer_shape(self, *inputs):
return [1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册