    fix(jit): fix jit grad · bc95e873
    Committed by Megvii Engine Team
    a) fix shape mismatch when taking the grad of a JITExecutor that contains a Dimshuffle
    b) avoid redundant computation in the grad of JITExecutor
    c) do not pass unused vars as inputs to the grad of JITExecutor, to save device memory
    d) traverse the internal graph only once in the JITExecutor ctor instead of traversing
       the whole graph in each call of setup_args()
    e) expand the gradient graph into the original graph if all inputs are const
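    The ideas behind items c) and d) can be illustrated with a minimal C++ sketch. It assumes a
    simplified internal graph whose nodes are stored in topological order. `Node`, `ExecutorSketch`
    and `used_placeholder_ids` are hypothetical names, not MegEngine's actual JITExecutor API, and
    the `setup_args()` signature here is invented. The constructor scans the graph once and caches
    what later calls need. A separate reverse sweep finds which inputs an output actually depends on,
    so unused inputs need not be kept alive.

    ```cpp
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <string>
    #include <utility>
    #include <vector>

    // Hypothetical, simplified node of a fused internal graph. Nodes are stored
    // in topological order and refer to their producers by index.
    struct Node {
        std::string op;                   // e.g. "Placeholder", "Add", "Dimshuffle"
        std::vector<std::size_t> inputs;  // indices of producer nodes
        std::size_t placeholder_id = 0;   // only meaningful when op == "Placeholder"
    };

    // Hypothetical executor (item d): the graph is traversed once in the ctor and
    // the results are cached, so setup_args() never walks the graph again.
    class ExecutorSketch {
    public:
        explicit ExecutorSketch(std::vector<Node> graph) : m_graph(std::move(graph)) {
            for (std::size_t i = 0; i < m_graph.size(); ++i) {
                const Node& n = m_graph[i];
                if (n.op == "Placeholder") {
                    m_placeholders.push_back(i);
                } else if (n.op == "Dimshuffle") {
                    m_has_dimshuffle = true;
                }
            }
        }

        // Binds each cached placeholder to the caller's argument; no traversal here.
        std::vector<int> setup_args(const std::vector<int>& inputs) const {
            std::vector<int> args;
            args.reserve(m_placeholders.size());
            for (std::size_t idx : m_placeholders) {
                args.push_back(inputs.at(m_graph[idx].placeholder_id));
            }
            return args;
        }

        bool has_dimshuffle() const { return m_has_dimshuffle; }
        std::size_t num_placeholders() const { return m_placeholders.size(); }

    private:
        std::vector<Node> m_graph;
        std::vector<std::size_t> m_placeholders;
        bool m_has_dimshuffle = false;
    };

    // Hypothetical helper (item c): returns the sorted placeholder ids that node
    // `out` depends on, so unused inputs can be dropped from the grad's inputs.
    std::vector<std::size_t> used_placeholder_ids(const std::vector<Node>& graph,
                                                  std::size_t out) {
        std::vector<bool> needed(graph.size(), false);
        needed.at(out) = true;
        std::vector<std::size_t> ids;
        // Producers always precede consumers, so one reverse sweep is enough.
        for (std::size_t i = out + 1; i-- > 0;) {
            if (!needed[i]) continue;
            const Node& n = graph[i];
            if (n.op == "Placeholder") ids.push_back(n.placeholder_id);
            for (std::size_t in : n.inputs) needed.at(in) = true;
        }
        std::sort(ids.begin(), ids.end());
        return ids;
    }
    ```

    For a graph with three placeholders where only the first two feed the output,
    `used_placeholder_ids` returns `{0, 1}`, so the third input would not be passed on.
    Caching in the constructor trades a little memory for avoiding a full traversal on
    every call.
    
    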
    
    GitOrigin-RevId: ba6a2b29e975c7f63a21785efad87dbda76143d4