未验证 提交 266c6f0d 编写于 作者: C Chen Zhiyang 提交者: GitHub

【New IR】New ir op test v1.0 (#56668)

* add reference of lbfgs

* add reference of lbfgs

* new ir op test v1.0

* fix new ir optest bug1.0

* modify two testcase bug

* add new ir white list & pass test_mean_op.py

* rename white list

* add new_ir_guard

* rename backward.grad as ir_backward.grad

* check place for new ir

* fix test_build_model env bug

* fix test_prim_program backward bug

* change backward to ir_backward in check_appr

---------
Co-authored-by: Nwangruting <wangruting@baidu.com>
上级 2d50a64d
......@@ -478,9 +478,10 @@ if is_compiled_with_cinn():
disable_static()
from .new_ir_utils import _switch_to_new_ir # noqa: F401
from .new_ir_utils import IrChange # noqa: F401
_switch_to_new_ir()
ir_change = IrChange()
ir_change._switch_to_new_ir()
__all__ = [ # noqa
'iinfo',
......
......@@ -707,26 +707,26 @@ def grad(
outputs,
'outputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
'paddle.autograd.backward.grad',
'paddle.autograd.ir_backward.grad',
)
check_type(
inputs,
'inputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
'paddle.autograd.backward.grad',
'paddle.autograd.ir_backward.grad',
)
check_type(
grad_outputs,
'grad_outputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple, type(None)),
'paddle.autograd.backward.grad',
'paddle.autograd.ir_backward.grad',
)
check_type(
no_grad_vars,
'no_grad_vars',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple, set, type(None)),
'paddle.autograd.backward.grad',
'paddle.autograd.ir_backward.grad',
)
outputs = _as_list(outputs)
inputs = _as_list(inputs)
......
......@@ -12,17 +12,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .fluid.wrapped_decorator import signature_safe_contextmanager
class IrChange:
def __init__(self):
old_flag = paddle.fluid.framework.get_flags("FLAGS_enable_new_ir_api")
paddle.fluid.framework.set_flags({"FLAGS_enable_new_ir_api": False})
if not paddle.ir.core._use_new_ir_api():
self.old_Program = paddle.static.Program
self.old_program_guard = paddle.fluid.program_guard
self.old_default_main_program = paddle.static.default_main_program
else:
raise RuntimeError(
"IrChange only init when paddle.ir.core._use_new_ir_api() is false, \
please set FLAGS_enable_new_ir_api = false"
)
paddle.fluid.framework.set_flags(old_flag)
def _switch_to_new_ir(self):
if paddle.ir.core._use_new_ir_api():
paddle.framework.set_flags(
{"FLAGS_enable_new_ir_in_executor": True}
)
paddle.ir.register_paddle_dialect()
paddle.static.Program = paddle.ir.Program
paddle.fluid.Program = paddle.ir.Program
paddle.fluid.program_guard = paddle.ir.core.program_guard
paddle.static.program_guard = paddle.ir.core.program_guard
paddle.framework.default_main_program = (
paddle.ir.core.default_main_program
)
def _switch_to_old_ir(self):
if not paddle.ir.core._use_new_ir_api():
paddle.framework.set_flags(
{"FLAGS_enable_new_ir_in_executor": False}
)
paddle.static.Program = self.old_Program
paddle.fluid.Program = self.old_Program
paddle.fluid.program_guard = self.old_program_guard
paddle.static.program_guard = self.old_program_guard
paddle.framework.default_main_program = (
self.old_default_main_program
)
else:
raise RuntimeError(
"IrChange._switch_to_old_ir only work when paddle.ir.core._use_new_ir_api() is false, \
please set FLAGS_enable_new_ir_api = false"
)
def _switch_to_new_ir():
if paddle.ir.core._use_new_ir_api():
paddle.framework.set_flags({"FLAGS_enable_new_ir_in_executor": True})
paddle.ir.register_paddle_dialect()
paddle.static.Program = paddle.ir.Program
paddle.fluid.Program = paddle.ir.Program
paddle.fluid.program_guard = paddle.ir.core.program_guard
paddle.static.program_guard = paddle.ir.core.program_guard
paddle.framework.default_main_program = (
paddle.ir.core.default_main_program
)
@signature_safe_contextmanager
def _newir_guard():
ir_change = IrChange()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
ir_change._switch_to_new_ir()
try:
yield
finally:
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
ir_change._switch_to_old_ir()
......@@ -98,44 +98,47 @@ def data(name, shape, dtype=None, lod_level=0):
[2.]]], dtype=float32)]
"""
helper = LayerHelper('data', **locals())
check_type(name, 'name', (bytes, str), 'data')
check_type(shape, 'shape', (list, tuple), 'data')
shape = list(shape)
for i in range(len(shape)):
if shape[i] is None:
shape[i] = -1
if dtype:
out = helper.create_global_variable(
name=name,
shape=shape,
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=True,
)
else:
out = helper.create_global_variable(
name=name,
shape=shape,
dtype=paddle.get_default_dtype(),
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=True,
)
dtype = paddle.get_default_dtype()
if paddle.ir.core._use_new_ir_api():
if not dtype:
dtype = paddle.get_default_dtype()
ir_dtype = paddle.ir.core.convert_np_dtype_to_dtype_(dtype)
return paddle._ir_ops.data(name, shape, ir_dtype, core.Place())
else:
helper = LayerHelper('data', **locals())
check_type(name, 'name', (bytes, str), 'data')
check_type(shape, 'shape', (list, tuple), 'data')
shape = list(shape)
for i in range(len(shape)):
if shape[i] is None:
shape[i] = -1
if dtype:
out = helper.create_global_variable(
name=name,
shape=shape,
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=True,
)
else:
out = helper.create_global_variable(
name=name,
shape=shape,
dtype=paddle.get_default_dtype(),
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=True,
)
is_new_ir_mode = os.environ.get("FLAGS_enable_new_ir_in_executor", None)
if evaluate_flag(is_new_ir_mode):
helper = LayerHelper('data', **locals())
......
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
......@@ -29,15 +28,25 @@ class TestBuildModule(unittest.TestCase):
y = paddle.static.data('y', [4, 4], dtype='float32')
divide_out = paddle.divide(x, y)
sum_out = paddle.sum(divide_out)
exe = paddle.static.Executor()
x_feed = np.ones([4, 4], dtype=np.float32) * 10
y_feed = np.ones([4, 4], dtype=np.float32) * 2
(sum_value,) = exe.run(
feed={'x': x_feed, 'y': y_feed}, fetch_list=[sum_out]
main_program,
feed={'x': x_feed, 'y': y_feed},
fetch_list=[sum_out],
)
self.assertEqual(sum_value, 5 * 4 * 4)
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.static.data('x', [4, 4], dtype='float32')
out = paddle.mean(x)
exe = paddle.static.Executor()
x_feed = np.ones([4, 4], dtype=np.float32) * 10
(sum_value,) = exe.run(feed={'x': x_feed}, fetch_list=[out])
self.assertEqual(sum_value, 10)
if __name__ == "__main__":
unittest.main()
......@@ -16,7 +16,7 @@ import unittest
import paddle
from paddle import ir
from paddle.autograd.backward import grad
from paddle.autograd.ir_backward import grad
paddle.enable_static()
......
......@@ -31,6 +31,7 @@ sys.path.append("..")
from white_list import (
check_shape_white_list,
compile_vs_runtime_white_list,
new_ir_python_api_grad_white_list,
no_check_set_white_list,
no_grad_set_white_list,
op_accuracy_white_list,
......@@ -39,6 +40,7 @@ from white_list import (
import paddle
from paddle import fluid
from paddle.autograd.ir_backward import grad as ir_grad
from paddle.fluid import core, unique_name
from paddle.fluid.backward import append_backward
from paddle.fluid.executor import Executor
......@@ -1201,6 +1203,164 @@ class OpTest(unittest.TestCase):
)
return outputs
def get_kernel_signature(self, place, egr_inps=None, egr_oups=None):
with fluid.dygraph.base.guard(place=place):
block = fluid.default_main_program().global_block()
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
# prepare input variable
dygraph_tensor_inputs = (
egr_inps
if egr_inps
else self.append_input_output_for_dygraph(
op_proto, self.inputs, True, False, block
)
)
# prepare output variable
dygraph_tensor_outputs = (
egr_oups
if egr_oups
else self.append_input_output_for_dygraph(
op_proto, self.outputs, False, False, block
)
)
# prepare attributes
attrs_outputs = {}
if hasattr(self, "attrs"):
for attrs_name in self.attrs:
if self.attrs[attrs_name] is not None:
attrs_outputs[attrs_name] = self.attrs[attrs_name]
kernel_sig = OpTestUtils._get_kernel_signature(
self.op_type,
dygraph_tensor_inputs,
dygraph_tensor_outputs,
canonicalize_attrs(attrs_outputs, op_proto),
)
if not kernel_sig or (
len(kernel_sig[0]) == 0
and len(kernel_sig[1]) == 0
and len(kernel_sig[2]) == 0
):
return None
if not hasattr(self, "python_api"):
print(kernel_sig)
assert hasattr(self, "python_api"), (
"Detect there is KernelSignature for `%s` op, please set the `self.python_api` if you set check_dygraph = True"
% self.op_type
)
return kernel_sig
def get_ir_input_attr_dict_and_feed(self, stop_gradient):
attrs_outputs = {}
if hasattr(self, "attrs"):
for attrs_name in self.attrs:
if self.attrs[attrs_name] is not None:
attrs_outputs[attrs_name] = self.attrs[attrs_name]
input_dict = {}
static_inputs = defaultdict(list)
feed = {}
for name, item in self.inputs.items():
if isinstance(item, list):
for tup in item:
dtype = (
"bfloat16"
if OpTestUtils.is_bfloat16_type(tup[1].dtype)
else tup[1].dtype
)
x = paddle.static.data(
name=str(tup[0]), shape=tup[1].shape, dtype=dtype
)
x.stop_gradient = stop_gradient
static_inputs[name].append(x)
feed.update({str(tup[0]): tup[1]})
input_dict.update({str(tup[0]): x})
else:
dtype = (
"bfloat16"
if OpTestUtils.is_bfloat16_type(item.dtype)
else item.dtype
)
x = paddle.static.data(name=name, shape=item.shape, dtype=dtype)
x.stop_gradient = stop_gradient
static_inputs[name].append(x)
feed.update({name: item})
input_dict.update({name: x})
return static_inputs, attrs_outputs, input_dict, feed
def _calc_new_ir_output(
self, place, no_check_set=None, inps=None, oups=None
):
"""set egr_inps and egr_oups = None if you want to create it by yourself."""
def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if hasattr(self, "python_out_sig"):
output_sig = self.python_out_sig
if not isinstance(ret_tuple, (tuple, list)):
ret_tuple = [ret_tuple]
if len(output_sig) == len(ret_tuple):
# [assumption]: we assume {"Out": [Tensor]}
return {a: [b] for a, b in zip(output_sig, ret_tuple)}
else:
# [assumption]: return multi-Tensor in a single output. such as paddle.split()
assert (
len(output_sig) == 1
), "Don't support multi-output with multi-tensor output. (May be you can use set `python_out_sig`, see `test_squeeze2_op` as a example.)"
return {output_sig[0]: ret_tuple}
# get kernel signature
kernel_sig = self.get_kernel_signature(place)
ir_program = paddle.static.Program()
with paddle.static.program_guard(ir_program):
# prepare inps attributes feed
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=True)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret_tuple = self.python_api(*args)
result = construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][i].name = self.python_out_sig_sub_name[
key
][i]
fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly.
# if not, fill the fetch_list by the user configured outputs in test.
if len(fetch_list) == 0:
for var in result.items():
if no_check_set is not None and var in no_check_set:
continue
if isinstance(var[1], list):
for v in var[1]:
fetch_list.append(v)
else:
fetch_list.append(var[1])
# executor run
executor = Executor(place)
(outs,) = executor.run(
ir_program,
feed=feed,
fetch_list=fetch_list,
)
return outs
def _check_ir_output(self, place, program, feed_map, fetch_list, outs):
if os.getenv("FLAGS_NEW_IR_OPTEST") is None:
return
......@@ -2123,6 +2283,114 @@ class OpTest(unittest.TestCase):
return True
return super()._is_skip_name(name)
class NewIRChecker(Checker):
def init(self):
self.checker_name = "new ir checker"
def calculate_output(self):
self.is_python_api_test = True
new_ir_outs = self.op_test._calc_new_ir_output(place)
if new_ir_outs is None:
self.is_python_api_test = False
# missing KernelSignature, fall back to eager middle output.
new_ir_outs = self.op_test._calc_dygraph_output(
place, no_check_set=no_check_set
)
self.outputs = new_ir_outs
if self.op_test.is_compared_with_fp32():
self.op_test.enable_cal_ref_output()
self.is_python_api_test = True
self.ref_outputs = self.op_test._calc_new_ir_output(place)
if self.ref_outputs is None:
self.is_python_api_test = False
# missing KernelSignature, fall back to eager middle output.
self.ref_outputs = self.op_test._calc_dygraph_output(
place, no_check_set=no_check_set
)
self.op_test.disable_cal_ref_output()
def _compare_numpy(self, name, actual_np, expect_np):
expect_np = np.array(expect_np)
assert (
actual_np.shape == expect_np.shape
), "Operator ({}) : Output ({}) shape mismatch, expect shape is {}, but actual shape is {}".format(
self.op_type, name, expect_np.shape, actual_np.shape
)
np.testing.assert_allclose(
actual_np,
expect_np,
atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else rtol,
equal_nan=equal_nan,
err_msg=(
"Operator ("
+ self.op_type
+ ") Output ("
+ name
+ ") has diff at "
+ str(place)
+ " in "
+ self.checker_name
),
)
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
if actual_np.dtype == np.uint16:
self.rtol = 1.0e-2
elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3
else:
self.rtol = 1.0e-5
if self.op_test.is_bfloat16_op():
if actual_np.dtype == np.uint16:
actual_np = convert_uint16_to_float(actual_np)
if expect_np.dtype == np.uint16:
expect_np = convert_uint16_to_float(expect_np)
return actual_np, expect_np
def find_actual_value(self, target_name):
with paddle.ir.core.program_guard(
paddle.ir.core.default_main_program()
):
actual = self.outputs
actual_t = np.array(actual)
return actual, actual_t
def find_expect_value(self, name):
with paddle.ir.core.program_guard(
paddle.ir.core.default_main_program()
):
expect = self.ref_outputs
expect_t = np.array(expect)
return expect, expect_t
def _compare_list(self, name, actual, expect):
"""if expect is a tuple, we need to compare list."""
with paddle.ir.core.program_guard(place=place):
self.op_test.assertListEqual(
actual.value()
.get_tensor()
.recursive_sequence_lengths(),
expect[1],
"Operator ("
+ self.op_type
+ ") Output ("
+ name
+ ") has different lod at "
+ str(place)
+ " in dygraph mode",
)
def _is_skip_name(self, name):
# if in final state and kernel signature don't have name, then skip it.
if (
self.is_python_api_test
and hasattr(self.op_test, "python_out_sig")
and name not in self.op_test.python_out_sig
):
return True
return super()._is_skip_name(name)
# set some flags by the combination of arguments.
if self.is_float16_op():
self.dtype = np.float16
......@@ -2184,6 +2452,21 @@ class OpTest(unittest.TestCase):
dygraph_checker.check()
dygraph_dygraph_outs = dygraph_checker.outputs
if (
self.op_type
in new_ir_python_api_grad_white_list.new_ir_python_api_grad_white_list
):
if (
type(place) is paddle.fluid.libpaddle.CPUPlace
or type(place) is paddle.fluid.libpaddle.CUDAPlace
):
print("New IR checker begins...........")
with paddle.new_ir_utils._newir_guard():
new_ir_checker = NewIRChecker(self, self.outputs)
new_ir_checker.check()
print("New IR checker ends...........")
# Note(zhiqiu): inplace_atol should be only set when op doesn't ensure
# computational consistency.
# For example, group_norm uses AtomicAdd on CUDAPlace, which do not ensure
......@@ -2720,6 +3003,33 @@ class OpTest(unittest.TestCase):
"Gradient Check On %s" % str(place),
atol=atol,
)
# get new ir gradient
if (
self.op_type
in new_ir_python_api_grad_white_list.new_ir_python_api_grad_white_list
):
if (
type(place) is paddle.fluid.libpaddle.CPUPlace
or type(place) is paddle.fluid.libpaddle.CUDAPlace
):
print("New IR gradient begins...........")
with paddle.new_ir_utils._newir_guard():
new_ir_grad = self._get_ir_gradient(
inputs_to_check,
place,
output_names,
user_defined_grad_outputs,
no_grad_set,
)
print("New IR gradient ends...........")
self._assert_is_close(
numeric_grads,
[new_ir_grad],
inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place),
atol=atol,
)
def _find_var_in_dygraph(self, output_vars, name):
if name in output_vars:
......@@ -3065,6 +3375,106 @@ class OpTest(unittest.TestCase):
return res
def _get_ir_gradient(
self,
inputs_to_check,
place,
output_names,
user_defined_grad_outputs=None,
no_grad_set=None,
):
def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if hasattr(self, "python_out_sig"):
output_sig = self.python_out_sig
if not isinstance(ret_tuple, (tuple, list)):
ret_tuple = [ret_tuple]
if len(output_sig) == len(ret_tuple):
# [assumption]: we assume {"Out": [Tensor]}
return {a: [b] for a, b in zip(output_sig, ret_tuple)}
else:
# [assumption]: return multi-Tensor in a single output. such as paddle.split()
assert (
len(output_sig) == 1
), "Don't support multi-output with multi-tensor output. (May be you can use set `python_out_sig`, see `test_squeeze2_op` as a example.)"
return {output_sig[0]: ret_tuple}
# get kernel signature
kernel_sig = self.get_kernel_signature(place)
ir_program = paddle.static.Program()
with paddle.static.program_guard(ir_program):
# prepare inps attributes feed
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret_tuple = self.python_api(*args)
result = construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][i].name = self.python_out_sig_sub_name[
key
][i]
fetch_list = getattr(self, "fetch_list", [])
if len(fetch_list) == 0:
for var in result.items():
if isinstance(var[1], list):
for v in var[1]:
fetch_list.append(v)
else:
fetch_list.append(var[1])
outputs = result
outputs_valid = outputs
grad_inputs = inputs_to_check
if user_defined_grad_outputs is None:
if len(outputs_valid) == 1:
for outputs_valid_key in outputs_valid:
loss = paddle.mean(outputs_valid[outputs_valid_key][0])
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(loss),
inputs=paddle.utils.flatten(static_inputs),
grad_outputs=None,
)
else:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
grad_outputs = []
for grad_out_value in user_defined_grad_outputs:
grad_outputs.append(paddle.to_tensor(grad_out_value))
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(outputs),
inputs=paddle.utils.flatten(static_inputs),
grad_outputs=grad_outputs,
)
fetch_list = list(grad_inputs)
# executor run
executor = paddle.static.Executor()
(outs,) = executor.run(
ir_program,
feed=feed,
fetch_list=fetch_list,
)
return outs
class OpTestTool:
@classmethod
......
......@@ -17,7 +17,7 @@ import unittest
import numpy as np
import paddle
from paddle.autograd.backward import grad
from paddle.autograd.ir_backward import grad
from paddle.decomposition import decompose
from paddle.framework import core
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
new_ir_python_api_grad_white_list = [
"mean",
]
......@@ -87,7 +87,7 @@ API_FILES=("CMakeLists.txt"
"paddle/fluid/prim/api/api.yaml"
"python/paddle/incubate/autograd/composite_rules.py"
"python/paddle/incubate/autograd/primitives.py"
"python/paddle/autograd/backward.py"
"python/paddle/autograd/ir_backward.py"
"python/paddle/autograd/backward_utils.py"
)
......@@ -220,8 +220,8 @@ for API_FILE in ${API_FILES[*]}; do
elif [ "${API_FILE}" == "python/paddle/incubate/autograd/primitives.py" ] || [ "${API_FILE}" == "python/paddle/incubate/autograd/composite_rules.py" ]; then
echo_line="You must have one RD (cyber-pioneer(chenzhuo), xiaoguoguo626807(wangruting), Charles-hit(wanghao), JiabinYang) approval for changing ${API_FILE} , which manages the composite rules.\n"
check_approval 1 cyber-pioneer xiaoguoguo626807 Charles-hit JiabinYang
elif [ "${API_FILE}" == "python/paddle/autograd/backward.py" ] || [ "${API_FILE}" == "python/paddle/autograd/backward_utils.py" ]; then
echo_line="You must be approved by Aurelius84(zhangliujie) or cxxly(chenxiaoxu) or xiaoguoguo626807(wangruting) or changeyoung98(chenzhiyang) for python/paddle/autograd/backward.py or python/paddle/autograd/backward_utils.py changes.\n"
elif [ "${API_FILE}" == "python/paddle/autograd/ir_backward.py" ] || [ "${API_FILE}" == "python/paddle/autograd/backward_utils.py" ]; then
echo_line="You must be approved by Aurelius84(zhangliujie) or cxxly(chenxiaoxu) or xiaoguoguo626807(wangruting) or changeyoung98(chenzhiyang) for python/paddle/autograd/ir_backward.py or python/paddle/autograd/backward_utils.py changes.\n"
check_approval 1 Aurelius84 cxxly xiaoguoguo626807 changeyoung98
else
echo_line="You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93,Aurelius84) approval for ${API_FILE}, which manages the underlying code for fluid.\n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册