fix(jit): fix jit grad

a) fix shape mismatch when taking the grad of a JITExecutor that includes Dimshuffle
b) avoid redundant computation in the grad of JITExecutor
c) do not pass unused vars as inputs to the grad of JITExecutor, to save device memory
d) traverse the internal graph only once in the JITExecutor ctor instead of traversing the whole graph in each call of setup_args()
e) expand the gradient graph into the original graph if all inputs are const

GitOrigin-RevId: ba6a2b29e975c7f63a21785efad87dbda76143d4