You need to sign in or sign up before continuing.
提交 c83ea1cd 编写于 作者: Y Yang Yang

remove hardcode add_XX_op

上级 a67e8ea3
......@@ -99,121 +99,6 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
}
}
void add_gaussian_random_op(string var_name, std::vector<int> dim,
proto_block* block) {
// insert variable
auto a = block->add_vars();
a->set_name(var_name);
auto a_lt = a->mutable_lod_tensor();
a_lt->set_data_type(paddle::framework::DataType::FP32);
for (int i : dim) {
a_lt->add_dims(i);
}
// insert operation
auto op = block->add_ops();
op->set_type("gaussian_random");
auto dims = op->add_attrs();
dims->set_name("dims");
dims->set_type(paddle::framework::AttrType::INTS);
for (int i : dim) {
dims->add_ints(i);
}
auto Out = op->add_outputs();
Out->set_parameter("Out");
Out->add_arguments(var_name);
}
void add_feed_op(string var_name, std::vector<int>& dim, int index,
proto_block* block) {
// insert variable
auto a = block->add_vars();
a->set_name(var_name);
auto a_lt = a->mutable_lod_tensor();
a_lt->set_data_type(paddle::framework::DataType::FP32);
for (int i : dim) {
a_lt->add_dims(i);
}
// insert operation
auto op = block->add_ops();
op->set_type("feed");
// set dims attr
auto dims = op->add_attrs();
dims->set_name("dims");
dims->set_type(paddle::framework::AttrType::INTS);
for (int i : dim) {
dims->add_ints(i);
}
// set col attr
auto col = op->add_attrs();
col->set_name("col");
col->set_type(paddle::framework::AttrType::INT);
col->set_i(index);
auto Out = op->add_outputs();
Out->set_parameter("Out");
Out->add_arguments(var_name);
}
void add_fetch_op(string var_name, std::vector<int> dim, int index,
proto_block* block) {
// insert variable
auto a = block->add_vars();
a->set_name(var_name);
auto a_lt = a->mutable_lod_tensor();
a_lt->set_data_type(paddle::framework::DataType::FP32);
for (int i : dim) {
a_lt->add_dims(i);
}
// insert operation
auto op = block->add_ops();
op->set_type("fetch");
// set dims attr
auto dims = op->add_attrs();
dims->set_name("dims");
dims->set_type(paddle::framework::AttrType::INTS);
for (int i : dim) {
dims->add_ints(i);
}
// set col attr
auto col = op->add_attrs();
col->set_name("col");
col->set_type(paddle::framework::AttrType::INT);
col->set_i(index);
auto Out = op->add_inputs();
Out->set_parameter("Input");
Out->add_arguments(var_name);
}
void add_mul_op(string X_str, string Y_str, string Out_str,
proto_block* block) {
// insert variable
auto a = block->add_vars();
a->set_name(Out_str);
auto a_lt = a->mutable_lod_tensor();
a_lt->set_data_type(paddle::framework::DataType::FP32);
// insert op
auto op = block->add_ops();
op->set_type("mul");
auto X = op->add_inputs();
X->set_parameter("X");
X->add_arguments(X_str);
auto Y = op->add_inputs();
Y->set_parameter("Y");
Y->add_arguments(Y_str);
auto Out = op->add_outputs();
Out->set_parameter("Out");
Out->add_arguments(Out_str);
}
std::once_flag set_variable_flag;
// Tensors in feed value variable will only be in CPUPlace
......@@ -268,21 +153,27 @@ class ExecutorTesterRandom : public ::testing::Test {
AddOp("gaussian_random", {}, {{"Out", {"w2"}}},
{{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {},
{{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
{{"dims", std::vector<int>{input_dim, embed_dim}}, {"col", 0}},
init_root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {},
{{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
{{"dims", std::vector<int>{embed_dim, input_dim}}, {"col", 1}},
init_root_block);
// run pdesc
auto root_block = pdesc_.add_blocks();
root_block->set_idx(0);
root_block->set_parent_idx(-1);
add_gaussian_random_op("a", {batch_size, input_dim}, root_block);
add_mul_op("a", "w1", "b", root_block);
add_mul_op("b", "w2", "a_out", root_block);
AddOp("gaussian_random", {}, {{"Out", {"a"}}},
{{"dims", std::vector<int>{batch_size, input_dim}}}, root_block);
AddOp("mul", {{"X", {"a"}}, {"Y", {"w1"}}}, {{"Out", {"b"}}}, {},
root_block);
AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {},
root_block);
add_fetch_op("a_out", {input_dim, batch_size}, 0, root_block);
AddOp("fetch", {{"Input", {"a_out"}}}, {},
{{"dims", std::vector<int>{input_dim, batch_size}}, {"col", 1}},
root_block);
}
protected:
......@@ -299,10 +190,14 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test {
std::vector<int> dim{6};
add_feed_op("a", dim, 0, root_block);
add_feed_op("b", dim, 1, root_block);
add_fetch_op("a", dim, 0, root_block);
add_fetch_op("b", dim, 1, root_block);
AddOp("feed", {}, {{"Out", {"a"}}}, {{"dims", dim}, {"col", 0}},
root_block);
AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}},
root_block);
AddOp("fetch", {{"Input", {"a"}}}, {}, {{"dims", dim}, {"col", 0}},
root_block);
AddOp("fetch", {{"Input", {"b"}}}, {}, {{"dims", dim}, {"col", 1}},
root_block);
std::vector<float> vec1 = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
std::vector<float> vec2 = {4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册