Unverified commit 94194275, authored by pangyoki, committed by GitHub

add double_grad and triple_grad inplace info in backward.yaml (#43124)

* add double_grad and triple_grad inplace info in backward.yaml

* only generate inplace api in forward
Parent: 462ae005
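All of the entries added in the diff below use the same `input -> output` mapping syntax that the yaml code generator already parses for forward APIs. A minimal sketch of that parsing step, using the same regex that appears in the generator code in this diff; the sample string mirrors the `add_double_grad` entry added below and is illustrative, not an excerpt of the generator itself:

```python
import re

# Hypothetical value of an "inplace" entry, mirroring the ones added below.
inplace_spec = "grad_x_grad -> grad_out_grad"

# Same regex the generator uses to split the pair.
result = re.search(r"(?P<in>\w+)\s*->\s*(?P<out>\w+)", inplace_spec)

# The map is keyed by output name and points at the input whose buffer
# that output may reuse.
inplace_map = {result.group('out'): result.group('in')}
print(inplace_map)  # {'grad_out_grad': 'grad_x_grad'}
```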
...@@ -48,8 +48,7 @@ class BaseAPI(object):
            'func']) == 1 or not self.kernel['func'][1].endswith(
                '_sr') else True
        self.data_transform = self.parse_data_transform(api_item_yaml)
-        self.inplace_map, self.view_map = self.parse_inplace_and_view(
-            api_item_yaml)
+        self.inplace_map, self.view_map = {}, {}

    def get_api_name(self, api_item_yaml):
        return api_item_yaml['api']
...@@ -303,31 +302,6 @@ class BaseAPI(object):
        return data_transform

-    def parse_inplace_and_view(self, api_item_yaml):
-        inplace_map, view_map = {}, {}
-        for mode in ['inplace', 'view']:
-            if mode in api_item_yaml:
-                if mode == 'inplace':
-                    inplace_map = {}
-                else:
-                    view_map = {}
-                in_out_mapping_list = api_item_yaml[mode].split(',')
-                for item in in_out_mapping_list:
-                    result = re.search(r"(?P<in>\w+)\s*->\s*(?P<out>\w+)", item)
-                    in_val = result.group('in')
-                    out_val = result.group('out')
-                    assert in_val in self.inputs['names'], \
-                        f"{self.api} : {mode} input error: the input var name('{in_val}') is not found in the input args of {self.api}."
-                    assert out_val in self.outputs['names'], \
-                        f"{self.api} : {mode} output error: the output var name('{out_val}') is not found in the output args of {self.api}."
-                    if mode == 'inplace':
-                        inplace_map[out_val] = in_val
-                    else:
-                        view_map[out_val] = in_val
-        return inplace_map, view_map

    # Override by child class
    def get_return_type(self, inplace_flag=False):
        return None
...
...@@ -30,6 +30,8 @@ class ForwardAPI(BaseAPI):
        super(ForwardAPI, self).__init__(api_item_yaml)
        self.is_dygraph_api, self.intermediate_outs = self.parse_intermediate(
            api_item_yaml)
+        self.inplace_map, self.view_map = self.parse_inplace_and_view(
+            api_item_yaml)

    def get_api_func_name(self):
        if self.is_dygraph_api:
...@@ -47,6 +49,31 @@ class ForwardAPI(BaseAPI):
        else:
            return False, []

+    def parse_inplace_and_view(self, api_item_yaml):
+        inplace_map, view_map = {}, {}
+        for mode in ['inplace', 'view']:
+            if mode in api_item_yaml:
+                if mode == 'inplace':
+                    inplace_map = {}
+                else:
+                    view_map = {}
+                in_out_mapping_list = api_item_yaml[mode].split(',')
+                for item in in_out_mapping_list:
+                    result = re.search(r"(?P<in>\w+)\s*->\s*(?P<out>\w+)", item)
+                    in_val = result.group('in')
+                    out_val = result.group('out')
+                    assert in_val in self.inputs['names'], \
+                        f"{self.api} : {mode} input error: the input var name('{in_val}') is not found in the input args of {self.api}."
+                    assert out_val in self.outputs['names'], \
+                        f"{self.api} : {mode} output error: the output var name('{out_val}') is not found in the output args of {self.api}."
+                    if mode == 'inplace':
+                        inplace_map[out_val] = in_val
+                    else:
+                        view_map[out_val] = in_val
+        return inplace_map, view_map

    def get_return_type_with_intermediate(self, inplace_flag=False):
        out_type_list = []
        for i, out_type in enumerate(self.outputs['types']):
...
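With this move, `parse_inplace_and_view` is only invoked for forward APIs, so generated inplace API variants come only from the forward yaml, while `BaseAPI` now starts with empty maps. A self-contained sketch (not the real `ForwardAPI` class) of what the relocated parser produces for a hypothetical forward entry; the entry, input, and output names are illustrative:

```python
import re

# Hypothetical forward api entry with both inplace and view annotations.
api_item_yaml = {
    'api': 'reshape',
    'inplace': '(x -> out)',
    'view': '(x -> out)',
}
inputs = {'names': ['x']}
outputs = {'names': ['out', 'xshape']}

def parse_inplace_and_view(item):
    inplace_map, view_map = {}, {}
    for mode in ['inplace', 'view']:
        if mode in item:
            for pair in item[mode].split(','):
                result = re.search(r"(?P<in>\w+)\s*->\s*(?P<out>\w+)", pair)
                in_val, out_val = result.group('in'), result.group('out')
                # Mapped names must exist in the declared inputs/outputs.
                assert in_val in inputs['names']
                assert out_val in outputs['names']
                (inplace_map if mode == 'inplace' else view_map)[out_val] = in_val
    return inplace_map, view_map

print(parse_inplace_and_view(api_item_yaml))
# ({'out': 'x'}, {'out': 'x'})
```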
...@@ -56,6 +56,7 @@
    func : add_double_grad
  optional : grad_x_grad, grad_y_grad
  backward : add_triple_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : add_grad
  forward : add (Tensor x, Tensor y) -> Tensor(out)
...@@ -86,6 +87,7 @@
    param : [grad_grad_x, grad_grad_y]
  kernel :
    func : add_triple_grad
+  inplace : (grad_grad_out_grad -> grad_grad_x_grad)

- backward_api : addmm_grad
  forward : addmm (Tensor input, Tensor x, Tensor y, float alpha, float beta) -> Tensor(out)
...@@ -193,6 +195,7 @@
    func : batch_norm_grad_grad
    data_type : x
  optional : out_mean, out_variance
+  inplace : (grad_out -> grad_out_grad)

- backward_api : batch_norm_grad
  forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
...@@ -261,6 +264,7 @@
    param : [x, x]
  kernel :
    func : celu_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : celu_grad
  forward : celu(Tensor x, float alpha) -> Tensor(out)
...@@ -532,6 +536,7 @@
    func : divide_double_grad
    data_type : out
  optional : grad_x_grad, grad_y_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : divide_grad
  forward : divide (Tensor x, Tensor y) -> Tensor(out)
...@@ -596,6 +601,7 @@
    param : [x, x]
  kernel :
    func : elu_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : elu_grad
  forward : elu (Tensor x, float alpha) -> Tensor(out)
...@@ -947,6 +953,7 @@
    param : [grad_x_grad]
  kernel :
    func : leaky_relu_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : leaky_relu_grad
  forward : leaky_relu (Tensor x, float alpha) -> Tensor(out)
...@@ -1022,6 +1029,7 @@
    param : [x, x]
  kernel :
    func : log_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : log_grad
  forward : log (Tensor x) -> Tensor(out)
...@@ -1310,6 +1318,7 @@
    func : multiply_double_grad
  optional : grad_x_grad, grad_y_grad
  backward : multiply_triple_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : multiply_grad
  forward : multiply (Tensor x, Tensor y) -> Tensor(out)
...@@ -1557,6 +1566,7 @@
    param : [out]
  kernel :
    func : relu_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : relu_grad
  forward : relu (Tensor x) -> Tensor(out)
...@@ -1580,6 +1590,7 @@
  kernel :
    func : reshape_double_grad
  no_need_buffer : grad_out
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : reshape_grad
  forward : reshape (Tensor x, IntArray shape) -> Tensor(out), Tensor(xshape)
...@@ -1654,6 +1665,7 @@
    param : [out, out]
  kernel :
    func : rsqrt_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : rsqrt_grad
  forward : rsqrt (Tensor x) -> Tensor(out)
...@@ -1753,6 +1765,7 @@
  kernel :
    func : sigmoid_double_grad
  backward : sigmoid_triple_grad
+  inplace : (grad_x_grad -> fwd_grad_out_grad)

- backward_api : sigmoid_grad
  forward : sigmoid (Tensor x) -> Tensor(out)
...@@ -1776,6 +1789,7 @@
  kernel :
    func : sigmoid_triple_grad
  optional : grad_grad_out_grad
+  inplace : (grad_grad_x -> fwd_grad_out_grad)

- backward_api : silu_grad
  forward : silu (Tensor x) -> Tensor(out)
...@@ -1859,6 +1873,7 @@
    param : [out, out]
  kernel :
    func : sqrt_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : sqrt_grad
  forward : sqrt (Tensor x) -> Tensor(out)
...@@ -1881,6 +1896,7 @@
    param : [x, x]
  kernel :
    func : square_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : square_grad
  forward : square (Tensor x) -> Tensor(out)
...@@ -1946,6 +1962,7 @@
    func : subtract_double_grad
  optional : grad_x_grad, grad_y_grad
  no_need_buffer : y, grad_out
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : subtract_grad
  forward : subtract (Tensor x, Tensor y) -> Tensor(out)
...@@ -2027,6 +2044,7 @@
  kernel :
    func : tanh_double_grad
  backward : tanh_triple_grad
+  inplace : (grad_x_grad -> grad_out_grad)

- backward_api : tanh_grad
  forward : tanh (Tensor x) -> Tensor(out)
...@@ -2060,6 +2078,7 @@
    param : [out, out, grad_x_grad_forward]
  kernel :
    func : tanh_triple_grad
+  inplace : (grad_x_grad_forward -> grad_out_forward_grad)

- backward_api : thresholded_relu_grad
  forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
...
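Each annotation above declares a single output-to-input aliasing pair for a higher-order backward kernel. Written as dictionaries in the same `output -> input` orientation the parser uses, a few of the entries read as follows; this is an illustrative summary only, and per the second commit note the forward generator is the only one that turns such mappings into inplace API variants:

```python
# Illustrative output -> input aliasing pairs declared by the entries above.
backward_inplace_info = {
    'add_double_grad':     {'grad_out_grad': 'grad_x_grad'},
    'add_triple_grad':     {'grad_grad_x_grad': 'grad_grad_out_grad'},
    'sigmoid_double_grad': {'fwd_grad_out_grad': 'grad_x_grad'},
    'sigmoid_triple_grad': {'fwd_grad_out_grad': 'grad_grad_x'},
    'tanh_triple_grad':    {'grad_out_forward_grad': 'grad_x_grad_forward'},
}
```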