提交 a67fc4be 编写于 作者: P phlrain

add yaml config; test=develop

上级 f7765991
......@@ -202,7 +202,7 @@ def ParseYamlArgs(string):
default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None
assert arg_type in yaml_types_mapping.keys()
assert arg_type in yaml_types_mapping.keys(), arg_type
arg_type = yaml_types_mapping[arg_type]
if "Tensor" in arg_type:
assert default_value is None
......@@ -1126,7 +1126,7 @@ if __name__ == "__main__":
fwd_returns_str = fwd_api['output']
bwd_api_name = fwd_api['backward']
assert bwd_api_name in grad_api_dict.keys()
assert bwd_api_name in grad_api_dict.keys(), bwd_api_name
bwd_api = grad_api_dict[bwd_api_name]
assert 'args' in bwd_api.keys()
......
......@@ -17,6 +17,9 @@ limitations under the License. */
#include <tuple>
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
......
......@@ -147,6 +147,7 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
"""
......
......@@ -98,6 +98,7 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册