diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index 3cb78b040430a16cd58f1d3ea156ff51a8a8033e..92a50d8bb1b0887bbbb3dfb469558b25676855ca 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -18,6 +18,7 @@ import paddle from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layers.utils import flatten from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose +import paddle.fluid.core as core paddle.enable_static() @@ -363,6 +364,46 @@ class TestFillZerosLikeOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestFillAnyLikeOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'fill_any_like' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = { + 'X': X, + } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['fill_any_like', 'fill_constant_p'] + self.out_map = {0: self.output['Out']} + + +class TestFillAnyLikeOrig2Prim2(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'fill_any_like' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = { + 'X': X, + } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'dtype': paddle.float32, 'value': 5} + + self.orig2prim_args = (X, ) + self.all_ops = ['fill_any_like', 'fill_constant_p'] + self.out_map = {0: self.output['Out']} + + class TestSumOrig2Prim(TestElementWiseAddOrig2Prim): def init_data(self): diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index 326f61088171a9dbe17e41b11febf22da41dcfc6..954bdf0cb1c868f02abbaa0d23995b2bfe2b1f72 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -29,6 +29,8 @@ from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, lookup_orig2prim, lookup_prim2orig, lookup_transpose, op_position_inputs, op_position_output) from .utils import INT_DTYPE_2_STRING, get_input_var_list, get_output_var_list +from paddle.fluid.data_feeder import convert_dtype +from paddle.fluid.framework import convert_np_dtype_to_dtype_ def _orig2prim(op, *args): @@ -66,6 +68,7 @@ elementwise_sub elementwise_mul tanh fill_zeros_like +fill_any_like sum index_select scale @@ -195,6 +198,16 @@ def fill_zeros_like_orig2prim(op, x): return fill_const(value=0.0, shape=x.shape, dtype=x.dtype) +@REGISTER_ORIG2PRIM('fill_any_like') +def fill_any_like_orig2prim(op, x): + if op.attr('dtype') == -1: + return fill_const(value=op.attr('value'), shape=x.shape, dtype=x.dtype) + return fill_const(value=op.attr('value'), + shape=x.shape, + dtype=convert_np_dtype_to_dtype_( + convert_dtype(INT_DTYPE_2_STRING[op.attr('dtype')]))) + + @REGISTER_ORIG2PRIM('sum') def sum_orig2prim(op, xs): x0 = xs[0]