- api : abs
  backward : abs_grad
  extra :
    attrs : [bool use_cudnn = false, bool use_mkldnn = false]

- api : addmm
  backward : addmm_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : affine_grid
  backward : affine_grid_grad
  extra :
    attrs : [bool use_cudnn = true]

- api : angle
  backward : angle_grad
  extra :
    attrs : [bool use_cudnn = false, bool use_mkldnn = false]

- api : atan2
  inputs :
    {x : X1, y : X2}
  outputs :
    out : Out

- api : batch_norm
  backward : batch_norm_grad
  extra :
    attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]

- api : bernoulli
  inputs :
    x : X
  outputs :
    out : Out

- api : bicubic_interp (bicubic_interp_v2)
  backward : bicubic_interp_grad (bicubic_interp_v2_grad)
  extra :
    attrs : [bool use_mkldnn = false]

- api : bilinear_interp (bilinear_interp_v2)
  backward : bilinear_interp_grad (bilinear_interp_v2_grad)
  extra :
    attrs : [bool use_mkldnn = false]

- api : cholesky
  inputs :
    x : X
  outputs :
    out : Out

- api : cholesky_solve
  inputs :
    {x : X, y : Y}
  outputs :
    out : Out

- api : clip
  backward : clip_grad
  extra :
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]

- api : concat
  backward : concat_grad
  extra :
    attrs : [bool use_mkldnn = false, bool use_quantizer = false, str mkldnn_data_type = "float32"]

- api : conv2d
  backward : conv2d_grad
  extra :
    attrs : [bool is_test = false, bool use_cudnn = true, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
             bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
             str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
             bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
             float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
             int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]

- api : conv2d_fusion
  extra :
    attrs : [bool is_test = false, bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
             bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
             str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
             bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
             float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
             int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]

- api : conv2d_transpose
  backward : conv2d_transpose_grad
  extra :
    attrs : [bool is_test = false, bool use_cudnn = true, bool use_mkldnn = false, bool force_fp32_output = false,
             str mkldnn_data_type = "float32", bool fuse_relu = false,
             str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f,
             int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]

- api : conv3d
  backward : conv3d_grad
  extra :
    attrs : [bool is_test = false, bool use_cudnn = true, bool use_mkldnn = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
             str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f,
             bool use_addto = false, bool fuse_residual_connection = false, bool force_fp32_output = false,
             int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]

- api : conv3d_transpose
  backward : conv3d_transpose_grad
  extra :
    attrs : [bool use_cudnn = true, bool use_mkldnn = false, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]

- api : cross
  inputs :
    {x : X, y : Y}
  attrs :
    axis : dim
  outputs :
    out : Out

- api : data_norm
  backward : data_norm_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : depthwise_conv2d
  backward : depthwise_conv2d_grad
  extra :
    attrs : [bool is_test = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
             bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
             str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
             bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
             float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
             int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]

- api : depthwise_conv2d_transpose
  backward : depthwise_conv2d_transpose_grad
  extra :
    attrs : [bool is_test = false, bool use_cudnn = false, bool use_mkldnn = false, bool force_fp32_output = false,
             str mkldnn_data_type = "float32", bool fuse_relu = false,
             str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f,
             int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]

- api : diag (diag_v2)
  backward : diag_grad (diag_v2_grad)
  inputs :
    x : X
  outputs :
    out : Out

- api : diagonal
  inputs :
    x : Input
  outputs :
    out : Out

- api : digamma
  inputs :
    x : X
  outputs :
    out : Out

- api : dist
  inputs :
    {x : X, y : Y}
  outputs :
    out : Out

- api : dot
  inputs :
    {x : X, y : Y}
  outputs :
    out : Out

- api : dropout
  backward : dropout_grad
  extra :
    attrs : [bool fix_seed = false, int seed = 0]

- api : dropout_nd
  backward : dropout_nd_grad
  extra :
    attrs : [bool fix_seed = false, int seed = 0]

- api : erf
  inputs :
    x : X
  outputs :
    out : Out

- api : erfinv
  inputs :
    x : X
  outputs :
    out : Out

- api : fft_c2c
  inputs: {x: X}
  outputs: {out: Out}

- api : fft_c2r
  inputs: {x: X}
  outputs: {out: Out}

- api : fft_r2c
  inputs: {x: X}
  outputs: {out: Out}

- api : frobenius_norm
  backward : frobenius_norm_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : gelu
  backward : gelu_grad
  extra :
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool use_cudnn = false]

- api : grid_sampler
  backward : grid_sampler_grad
  extra :
    attrs : [bool use_cudnn = true]

- api : gru
  backward : gru_grad
  extra :
    attrs : [bool is_test = false]

- api : inplace_abn
  backward : inplace_abn_grad
  extra :
    attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]

- api : layer_norm
  backward : layer_norm_grad
  extra :
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]

- api : lgamma
  inputs :
    x : X
  outputs :
    out : Out

- api : linear_interp (linear_interp_v2)
  backward : linear_interp_grad (linear_interp_v2_grad)
  extra :
    attrs : [bool use_mkldnn = false]

- api : log_softmax
  backward : log_softmax_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : lrn
  backward : lrn_grad
  extra :
    attrs : [bool use_mkldnn = false, bool is_test = false]

- api : matmul (matmul_v2)
  backward : matmul_grad (matmul_v2_grad)
  extra :
    attrs : [bool use_mkldnn = false, 'int[] fused_reshape_Out = {}', 'int[] fused_transpose_Out = {}',
             str mkldnn_data_type = "float32", 'int[] fused_reshape_X = {}', 'int[] fused_reshape_Y = {}',
             'int[] fused_transpose_X = {}', 'int[] fused_transpose_Y = {}']

- api : mv
  inputs :
    {x : X, vec : Vec}
  outputs :
    out : Out

- api : nearest_interp (nearest_interp_v2)
  backward : nearest_interp_grad (nearest_interp_v2_grad)
  extra :
    attrs : [bool use_mkldnn = false]

- api : pad2d
  backward : pad2d_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : pad3d
  backward : pad3d_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : partial_sum
  backward : partial_sum_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : poisson
  inputs :
    x : X
  outputs :
    out : Out

- api : reduce_all
  extra :
    attrs : [bool use_mkldnn = false]

- api : reduce_amax
  backward : reduce_amax_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : reduce_amin
  backward : reduce_amin_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : reduce_any
  extra :
    attrs : [bool use_mkldnn = false]

- api : reduce_max
  backward : reduce_max_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : reduce_mean
  backward : reduce_mean_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : reduce_min
  backward : reduce_min_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : reduce_prod
  backward : reduce_prod_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : reduce_sum
  backward : reduce_sum_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : renorm
  backward : renorm_grad
  extra :
    attrs : [bool use_mkldnn = false, bool use_cudnn = false]

- api : rnn
  backward : rnn_grad
  extra :
    attrs : [bool is_test = false]

- api : seed
  extra :
    attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false]

- api : shape
  extra :
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]

- api : shuffle_channel
  backward : shuffle_channel_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : slice
  backward : slice_grad
  extra :
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]

- api : softmax
  backward : softmax_grad
  extra :
    attrs : [bool use_cudnn = false, bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]

- api : prelu
  backward : prelu_grad
  extra :
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]

- api : solve
  inputs :
    {x : X, y : Y}
  outputs :
    out : Out

- api : squeeze (squeeze2)
  backward : squeeze_grad (squeeze2_grad)
  extra :
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]

- api : stack
  backward : stack_grad
  extra :
    attrs : [bool use_mkldnn = false]

- api : sync_batch_norm
  backward : sync_batch_norm_grad
  extra :
    attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]

- api : trace
  inputs :
    x : Input
  outputs :
    out : Out

- api : trilinear_interp (trilinear_interp_v2)
  backward : trilinear_interp_grad (trilinear_interp_v2_grad)
  extra :
    attrs : [bool use_mkldnn = false]

- api : trunc
  inputs :
    x : X
  outputs :
    out : Out