diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index 2d33cd5b1812ada8fca118c0e0f616cfbe511dd1..125ebed82de8b25b0a2c20ca7b76560966313566 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -56,8 +56,9 @@ class BackwardAPI(BaseAPI): # check the attributes of backward for attr in self.attrs['names']: - assert attr in fw_attrs['names'] and self.attrs['attr_info'][attr][0] == fw_attrs['attr_info'][attr][0], \ - f"{self.api} : Attribute error: The attribute({attr}) of backward isn't consistent with forward api. \ + assert (attr in fw_attrs['names'] and self.attrs['attr_info'][attr][0] == fw_attrs['attr_info'][attr][0]) or \ + self.attrs['attr_info'][attr][1] is not None, \ + f"{self.api} : Attribute error: The attribute({attr}) of backward isn't consistent with forward api or doesn't have default value. \ Please check the args of {self.api} in yaml." # check the output of backward