未验证 提交 c077de3c 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Rearranged Eager AutoCodeGen directory structure (#37812)

* Rearranged Eager AutoCodeGen directory structure

* Removed USE_OP in Eager AutoCodeGen
上级 809ba964
...@@ -30,26 +30,63 @@ ...@@ -30,26 +30,63 @@
DEFINE_bool(generate_all, false, DEFINE_bool(generate_all, false,
"Generate all operators currently registered in Paddle"); "Generate all operators currently registered in Paddle");
static std::unordered_map<std::string, paddle::framework::AttributeMap>
operators_with_attrs = {};
static std::unordered_set<std::string> operators_to_skip = { static std::unordered_set<std::string> operators_to_skip = {
"fused_elemwise_add_activation", // No Default Attr "pull_sparse", "pull_box_extended_sparse", "pull_sparse_v2",
"fused_elemwise_activation", // No Default Attr "pull_box_sparse", "fused_attention", "diag_v2",
"reverse", // Attr Error
"flip", // Attr Error
"cast", // Attr Error
"sum",
"minus", // Multiple ops_
"pull_sparse",
"pull_box_extended_sparse",
"pull_sparse_v2",
"pull_box_sparse",
"fused_attention",
"diag_v2",
"transfer_dtype",
"c_split"}; "c_split"};
static std::unordered_set<std::string> operators_to_codegen = {}; static std::unordered_set<std::string> operators_to_codegen = {};
static std::unordered_set<std::string> skipped_operators = {}; static std::unordered_set<std::string> skipped_operators = {};
static void PrepareAttrMapForOps() {
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
functor_list;
// Handle "fused_elemwise_activation"
operators_with_attrs["fused_elemwise_activation"] = {};
operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
functor_list;
// Handle "reverse"
std::vector<int> axis = {0};
operators_with_attrs["reverse"] = {};
operators_with_attrs["reverse"]["axis"] = axis;
// Handle "flip"
operators_with_attrs["flip"] = {};
operators_with_attrs["flip"]["axis"] = axis;
// Handle "cast"
operators_with_attrs["cast"] = {};
operators_with_attrs["cast"]["out_dtype"] = 5;
operators_with_attrs["cast"]["in_dtype"] = 5;
// Handle "transfer_dtype"
operators_with_attrs["transfer_dtype"] = {};
operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
}
static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
std::string line;
std::ifstream op_list_file(op_list_path);
if (op_list_file.is_open()) {
while (getline(op_list_file, line)) {
operators_to_codegen.insert(line);
}
op_list_file.close();
} else {
PADDLE_THROW(
paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
}
}
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -573,12 +610,21 @@ static bool CollectInformationFromOpInfo( ...@@ -573,12 +610,21 @@ static bool CollectInformationFromOpInfo(
paddle::framework::AttributeMap default_attrs; paddle::framework::AttributeMap default_attrs;
auto* attr_checker = op_info.Checker(); auto* attr_checker = op_info.Checker();
if (attr_checker) { if (attr_checker) {
VLOG(6) << "Checking AttributeMap Settings";
attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true); attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
default_attrs = attr_checker->GetDefaultAttrMap(); default_attrs = attr_checker->GetDefaultAttrMap();
VLOG(6) << "AttributeMap Checking Passed";
} else { } else {
VLOG(6) << "Detected Null Attribute Checker, use empty default_attrs"; VLOG(6) << "Detected Null Attribute Checker, use empty default_attrs";
} }
if (operators_with_attrs.count(op_type)) {
VLOG(6) << "Found operator " << op_type << " using special AttributeMap";
attrs = operators_with_attrs[op_type];
// default_attrs.insert(operators_with_attrs[op_type].begin(),
// operators_with_attrs[op_type].end());
}
VLOG(6) << "Prepared Default Attributes Map, size = " << default_attrs.size(); VLOG(6) << "Prepared Default Attributes Map, size = " << default_attrs.size();
/* ---------------------------- */ /* ---------------------------- */
...@@ -851,18 +897,6 @@ static std::string GenerateGradNodeCreationContent( ...@@ -851,18 +897,6 @@ static std::string GenerateGradNodeCreationContent(
return grad_node_creation_body_str; return grad_node_creation_body_str;
} }
static std::string AppendUseOp(const std::string& op_type) {
// [Generation] Append USE_OP
const char* USE_OP_TEMPLATE = "USE_OP(%s);\n";
std::string return_str = paddle::string::Sprintf(USE_OP_TEMPLATE, op_type);
// Special Ops
if (op_type == "reduce_sum")
return_str += paddle::string::Sprintf(USE_OP_TEMPLATE, "reduce_sum_grad");
return return_str;
}
/* -------------------------------- */ /* -------------------------------- */
/* --------- CodeGen: Forward ----- */ /* --------- CodeGen: Forward ----- */
/* -------------------------------- */ /* -------------------------------- */
...@@ -1110,9 +1144,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1110,9 +1144,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
FWD_FUNCTION_TEMPLATE, function_proto_return_type_str, function_name, FWD_FUNCTION_TEMPLATE, function_proto_return_type_str, function_name,
dygraph_function_args_str, generated_function_body); dygraph_function_args_str, generated_function_body);
// [Generation] Append USE_OP
fwd_function_str += AppendUseOp(op_type);
// [Generation] Generate forward functions header // [Generation] Generate forward functions header
const char* FWD_HEADER_TEMPLATE = "%s %s(%s);\n"; const char* FWD_HEADER_TEMPLATE = "%s %s(%s);\n";
std::string dygraph_function_declaration_str = paddle::string::Sprintf( std::string dygraph_function_declaration_str = paddle::string::Sprintf(
...@@ -1480,34 +1511,31 @@ static void GenerateForwardHFile(const std::string& output_dir, ...@@ -1480,34 +1511,31 @@ static void GenerateForwardHFile(const std::string& output_dir,
forward_header_stream.close(); forward_header_stream.close();
} }
static void GenerateForwardDygraphFile(const std::string& op_type, static void GenerateForwardDygraphFile(const std::string& output_dir,
const std::string& output_dir,
const std::string& fwd_function_str) { const std::string& fwd_function_str) {
std::string forwards_dir = output_dir + "/forwards/"; std::string forwards_dir = output_dir + "/forwards/";
std::string node_h_filename = op_type + "_node.h"; std::string forward_cc_filename = "dygraph_forward_functions.cc";
std::string forward_cc_filename = op_type + "_dygraph.cc";
std::string forward_cc_path = forwards_dir + forward_cc_filename; std::string forward_cc_path = forwards_dir + forward_cc_filename;
const char* FORWARD_INCLUDE_TEMPLATE = const char* FORWARD_INCLUDE_TEMPLATE =
"#include " "#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/" "\"paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h\"\n" "dygraph_forward_api.h\"\n"
"#include " "#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/nodes/%s\"\n\n" "\"paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h\"\n\n"
"#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n" "#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n"
"#include \"paddle/fluid/eager/legacy/op_runner.h\"\n"; "#include \"paddle/fluid/eager/legacy/op_runner.h\"\n";
std::string forward_cc_include_str = std::string forward_cc_include_str =
paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE, node_h_filename); paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE);
std::ofstream forward_cc_stream(forward_cc_path, std::ios::out); std::ofstream forward_cc_stream(forward_cc_path, std::ios::out);
forward_cc_stream << forward_cc_include_str; forward_cc_stream << forward_cc_include_str;
forward_cc_stream << fwd_function_str; forward_cc_stream << fwd_function_str;
forward_cc_stream.close(); forward_cc_stream.close();
} }
static void GenerateNodeHFile(const std::string& op_type, static void GenerateNodeHFile(const std::string& output_dir,
const std::string& output_dir,
const std::string& grad_node_str) { const std::string& grad_node_str) {
std::string nodes_dir = output_dir + "/nodes/"; std::string nodes_dir = output_dir + "/nodes/";
std::string node_h_filename = op_type + "_node.h"; std::string node_h_filename = "nodes.h";
std::string node_h_path = nodes_dir + node_h_filename; std::string node_h_path = nodes_dir + node_h_filename;
std::string node_h_include_str = std::string node_h_include_str =
"#pragma once\n" "#pragma once\n"
...@@ -1520,12 +1548,10 @@ static void GenerateNodeHFile(const std::string& op_type, ...@@ -1520,12 +1548,10 @@ static void GenerateNodeHFile(const std::string& op_type,
node_h_stream.close(); node_h_stream.close();
} }
static void GenerateNodeCCFile(const std::string& op_type, static void GenerateNodeCCFile(const std::string& output_dir,
const std::string& output_dir,
const std::string& grad_function_str) { const std::string& grad_function_str) {
std::string nodes_dir = output_dir + "/nodes/"; std::string nodes_dir = output_dir + "/nodes/";
std::string node_h_filename = op_type + "_node.h"; std::string node_cc_filename = "nodes.cc";
std::string node_cc_filename = op_type + "_node.cc";
std::string node_cc_path = nodes_dir + node_cc_filename; std::string node_cc_path = nodes_dir + node_cc_filename;
const char* NODE_CC_INCLUDE_TEMPLATE = const char* NODE_CC_INCLUDE_TEMPLATE =
"#include \"glog/logging.h\"\n" "#include \"glog/logging.h\"\n"
...@@ -1535,9 +1561,9 @@ static void GenerateNodeCCFile(const std::string& op_type, ...@@ -1535,9 +1561,9 @@ static void GenerateNodeCCFile(const std::string& op_type,
"#include \"paddle/fluid/eager/utils.h\"\n" "#include \"paddle/fluid/eager/utils.h\"\n"
"#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n" "#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n"
"#include " "#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/nodes/%s\"\n\n"; "\"paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h\"\n\n";
std::string node_cc_include_str = std::string node_cc_include_str =
paddle::string::Sprintf(NODE_CC_INCLUDE_TEMPLATE, node_h_filename); paddle::string::Sprintf(NODE_CC_INCLUDE_TEMPLATE);
std::ofstream node_cc_stream(node_cc_path, std::ios::out); std::ofstream node_cc_stream(node_cc_path, std::ios::out);
node_cc_stream << node_cc_include_str; node_cc_stream << node_cc_include_str;
node_cc_stream << grad_function_str; node_cc_stream << grad_function_str;
...@@ -1558,6 +1584,9 @@ static std::string GenerateDygraphHFileIncludes() { ...@@ -1558,6 +1584,9 @@ static std::string GenerateDygraphHFileIncludes() {
static void DygraphCodeGeneration(const std::string& output_dir) { static void DygraphCodeGeneration(const std::string& output_dir) {
std::string dygraph_forward_api_str = GenerateDygraphHFileIncludes(); std::string dygraph_forward_api_str = GenerateDygraphHFileIncludes();
std::string fwd_function_str = "";
std::string grad_node_h_str = "";
std::string grad_node_cc_str = "";
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map(); auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
...@@ -1603,7 +1632,7 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -1603,7 +1632,7 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
/* --------------------------- */ /* --------------------------- */
/* --------- CodeGen --------- */ /* --------- CodeGen --------- */
/* --------------------------- */ /* --------------------------- */
/* ---- xxx_dygraph.cc ---- */ /* ---- forward_dygraph_functions.cc ---- */
VLOG(6) << "-------- GenerateForwardFunctionContents -------"; VLOG(6) << "-------- GenerateForwardFunctionContents -------";
std::pair<std::string, std::string> body_and_declaration = std::pair<std::string, std::string> body_and_declaration =
GenerateForwardFunctionContents( GenerateForwardFunctionContents(
...@@ -1611,56 +1640,53 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -1611,56 +1640,53 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map, grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map,
grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars, grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars,
out_vars); out_vars);
std::string fwd_function_str = body_and_declaration.first; fwd_function_str += body_and_declaration.first + "\n";
GenerateForwardDygraphFile(op_type, output_dir, fwd_function_str);
/* ---- dygraph_forward_api.h ---- */ /* ---- dygraph_forward_api.h ---- */
std::string fwd_function_declare_str = body_and_declaration.second; std::string fwd_function_declare_str = body_and_declaration.second;
dygraph_forward_api_str += fwd_function_declare_str; dygraph_forward_api_str += fwd_function_declare_str;
/* ---- xxx_node.h ---- */ /* ---- nodes.h ---- */
VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
std::string grad_node_h_str = GenerateGradNodeHeaderContents( grad_node_h_str +=
grad_ins_fwd_slotname_map, op_type, in_vars, out_vars); GenerateGradNodeHeaderContents(grad_ins_fwd_slotname_map, op_type,
GenerateNodeHFile(op_type, output_dir, grad_node_h_str); in_vars, out_vars) +
"\n";
/* ---- xxx_node.cc ---- */ /* ---- nodes.cc ---- */
VLOG(6) << "-------- GenerateGradNodeCCContents -------"; VLOG(6) << "-------- GenerateGradNodeCCContents -------";
std::string grad_node_cc_str = GenerateGradNodeCCContents( grad_node_cc_str += GenerateGradNodeCCContents(
grad_op_types, fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, grad_op_types, fwd_inputs_name_pos_map,
grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map, fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars, grad_ins_grad_slotname_map, grad_outs_slotname_map,
out_vars); grad_ins, grad_outs, op_type, in_vars, out_vars) +
GenerateNodeCCFile(op_type, output_dir, grad_node_cc_str); "\n";
VLOG(6) << op_type << ": Finished Generation"; VLOG(6) << op_type << ": Finished Generating Op: " << op_type;
} }
/* ---- dygraph_forward_function.cc ---- */
VLOG(6) << "-------- GenerateDygraphForwardCCFile -------";
GenerateForwardDygraphFile(output_dir, fwd_function_str);
/* ---- dygraph_forward_api.h ---- */ /* ---- dygraph_forward_api.h ---- */
VLOG(6) << "-------- GenerateForwardHFile -------"; VLOG(6) << "-------- GenerateForwardHFile -------";
GenerateForwardHFile(output_dir, dygraph_forward_api_str); GenerateForwardHFile(output_dir, dygraph_forward_api_str);
/* ---- nodes.h ---- */
VLOG(6) << "-------- GenerateNodeHFile -------";
GenerateNodeHFile(output_dir, grad_node_h_str);
/* ---- nodes.cc ---- */
VLOG(6) << "-------- GenerateNodeCCFile -------";
GenerateNodeCCFile(output_dir, grad_node_cc_str);
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
std::string line;
std::ifstream op_list_file(op_list_path);
if (op_list_file.is_open()) {
while (getline(op_list_file, line)) {
operators_to_codegen.insert(line);
}
op_list_file.close();
} else {
PADDLE_THROW(
paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc != 3) { if (argc != 3) {
std::cerr << "argc must be 2" << std::endl; std::cerr << "argc must be 3" << std::endl;
return -1; return -1;
} }
...@@ -1668,6 +1694,8 @@ int main(int argc, char* argv[]) { ...@@ -1668,6 +1694,8 @@ int main(int argc, char* argv[]) {
std::string op_list_path = argv[2]; std::string op_list_path = argv[2];
CollectOperatorsToCodeGen(op_list_path); CollectOperatorsToCodeGen(op_list_path);
PrepareAttrMapForOps();
paddle::framework::DygraphCodeGeneration(eager_root); paddle::framework::DygraphCodeGeneration(eager_root);
return 0; return 0;
......
...@@ -18,12 +18,6 @@ import os ...@@ -18,12 +18,6 @@ import os
if __name__ == "__main__": if __name__ == "__main__":
assert len(sys.argv) == 2 assert len(sys.argv) == 2
eager_dir = sys.argv[1] eager_dir = sys.argv[1]
op_list = []
with open(f"{eager_dir}/auto_code_generator/op_list.txt", "r") as f:
for line in f:
line = str(line.strip())
op_list.append(line)
""" """
paddle/fluid/eager paddle/fluid/eager
|- generated |- generated
...@@ -31,15 +25,15 @@ if __name__ == "__main__": ...@@ -31,15 +25,15 @@ if __name__ == "__main__":
| | "add_subdirectory(forwards), add_subdirectory(nodes)" | | "add_subdirectory(forwards), add_subdirectory(nodes)"
| |
| |- forwards | |- forwards
| |- op_name + "_dygraph.cc" | |- "dygraph_forward_functions.cc"
| |- CMakeLists.txt | |- CMakeLists.txt
| | "cc_library(dygraph_function SRCS op_name+"_dygraph.cc" DEPS ${eager_deps} ${fluid_deps} GLOB_OP_LIB)" | | "cc_library(dygraph_function SRCS dygraph_forward_functions.cc DEPS ${eager_deps} ${fluid_deps} GLOB_OP_LIB)"
| |
| |- nodes | |- nodes
| |- op_name + "_node.cc" | |- "nodes.cc"
| |- op_name + "_node.h" | |- "nodes.h"
| |- CMakeLists.txt | |- CMakeLists.txt
| | "cc_library(dygraph_node SRCS op_name+"_node.cc" DEPS ${eager_deps} ${fluid_deps})" | | "cc_library(dygraph_node SRCS nodes.cc DEPS ${eager_deps} ${fluid_deps})"
| |
| |- dygraph_forward_api.h | |- dygraph_forward_api.h
""" """
...@@ -56,10 +50,10 @@ if __name__ == "__main__": ...@@ -56,10 +50,10 @@ if __name__ == "__main__":
dygraph_forward_api_h_path = os.path.join(generated_dir, dygraph_forward_api_h_path = os.path.join(generated_dir,
"dygraph_forward_api.h") "dygraph_forward_api.h")
empty_files = [dygraph_forward_api_h_path] empty_files = [dygraph_forward_api_h_path]
for op_name in op_list: empty_files.append(
empty_files.append(os.path.join(forwards_dir, op_name + "_dygraph.cc")) os.path.join(forwards_dir, "dygraph_forward_functions.cc"))
empty_files.append(os.path.join(nodes_dir, op_name + "_node.cc")) empty_files.append(os.path.join(nodes_dir, "nodes.cc"))
empty_files.append(os.path.join(nodes_dir, op_name + "_node.h")) empty_files.append(os.path.join(nodes_dir, "nodes.h"))
for path in empty_files: for path in empty_files:
if not os.path.exists(path): if not os.path.exists(path):
...@@ -73,14 +67,14 @@ if __name__ == "__main__": ...@@ -73,14 +67,14 @@ if __name__ == "__main__":
with open(nodes_level_cmakelist_path, "w") as f: with open(nodes_level_cmakelist_path, "w") as f:
f.write( f.write(
"cc_library(dygraph_node SRCS %s DEPS ${eager_deps} ${fluid_deps})\n" "cc_library(dygraph_node SRCS nodes.cc DEPS ${eager_deps} ${fluid_deps})\n"
% " ".join([op_name + '_node.cc' for op_name in op_list])) )
f.write("add_dependencies(dygraph_node eager_codegen)") f.write("add_dependencies(dygraph_node eager_codegen)")
with open(forwards_level_cmakelist_path, "w") as f: with open(forwards_level_cmakelist_path, "w") as f:
f.write( f.write(
"cc_library(dygraph_function SRCS %s DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB})\n" "cc_library(dygraph_function SRCS dygraph_forward_functions.cc DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})\n"
% " ".join([op_name + '_dygraph.cc' for op_name in op_list])) )
f.write("add_dependencies(dygraph_function eager_codegen)") f.write("add_dependencies(dygraph_function eager_codegen)")
with open(generated_level_cmakelist_path, "w") as f: with open(generated_level_cmakelist_path, "w") as f:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// Eager Dygraph // Eager Dygraph
#include <paddle/fluid/framework/op_registry.h>
#include <chrono> #include <chrono>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -178,3 +179,8 @@ TEST(Benchmark, EagerIntermediateMLPCPU) { ...@@ -178,3 +179,8 @@ TEST(Benchmark, EagerIntermediateMLPCPU) {
} }
} }
} }
USE_OP(scale);
USE_OP(elementwise_add);
USE_OP(matmul_v2);
USE_OP(reduce_sum);
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
// Eager Dygraph // Eager Dygraph
#include <paddle/fluid/framework/op_registry.h>
#include <chrono> #include <chrono>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -188,4 +189,10 @@ TEST(Benchmark, EagerIntermediateMLPCUDA) { ...@@ -188,4 +189,10 @@ TEST(Benchmark, EagerIntermediateMLPCUDA) {
} }
} }
USE_OP(scale);
USE_OP(matmul_v2);
USE_OP(reduce_sum);
USE_OP(reduce_sum_grad);
USE_OP(elementwise_add);
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP #endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
...@@ -217,5 +217,6 @@ TEST(Benchmark, FluidMLPCPU) { ...@@ -217,5 +217,6 @@ TEST(Benchmark, FluidMLPCPU) {
} // namespace paddle } // namespace paddle
USE_OP(scale); USE_OP(scale);
USE_OP(elementwise_add);
USE_OP(matmul_v2); USE_OP(matmul_v2);
USE_OP(reduce_sum); USE_OP(reduce_sum);
...@@ -254,5 +254,6 @@ USE_OP(scale); ...@@ -254,5 +254,6 @@ USE_OP(scale);
USE_OP(matmul_v2); USE_OP(matmul_v2);
USE_OP(reduce_sum); USE_OP(reduce_sum);
USE_OP(reduce_sum_grad); USE_OP(reduce_sum_grad);
USE_OP(elementwise_add);
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP #endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
...@@ -89,4 +89,39 @@ TEST(Generated, Matmul_v2) { ...@@ -89,4 +89,39 @@ TEST(Generated, Matmul_v2) {
eager_test::CompareGradVariableWithValue<float>(Y, 3.0 * 4); eager_test::CompareGradVariableWithValue<float>(Y, 3.0 * 4);
} }
TEST(Generated, ElementwiseAdd) {
// Prepare Device Contexts
eager_test::InitEnv(paddle::platform::CPUPlace());
auto tracer = std::make_shared<paddle::imperative::Tracer>();
paddle::imperative::SetCurrentTracer(tracer);
// 1. Prepare Input
paddle::framework::DDim ddimX = paddle::framework::make_ddim({4, 16});
egr::EagerTensor X = egr_utils_api::CreateTensorWithValue(
ddimX, paddle::platform::CPUPlace(), pten::DataType::FLOAT32,
pten::DataLayout::NCHW, 3.0, true);
egr_utils_api::RetainGradForTensor(X);
paddle::framework::DDim ddimY = paddle::framework::make_ddim({4, 16});
egr::EagerTensor Y = egr_utils_api::CreateTensorWithValue(
ddimY, paddle::platform::CPUPlace(), pten::DataType::FLOAT32,
pten::DataLayout::NCHW, 2.0, true);
egr_utils_api::RetainGradForTensor(Y);
auto output_tensor = elementwise_add_dygraph_function(X, Y, {});
eager_test::CompareVariableWithValue<float>(output_tensor, 5);
std::vector<egr::EagerTensor> target_tensors = {output_tensor};
RunBackward(target_tensors, {});
eager_test::CompareGradVariableWithValue<float>(X, 1.0);
eager_test::CompareGradVariableWithValue<float>(Y, 1.0);
}
} // namespace egr } // namespace egr
USE_OP(sigmoid);
USE_OP(elementwise_add);
USE_OP(matmul_v2);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册