Commit a67e8ea3 authored by Yang Yang

Add AddOp

Parent 005f15b4
@@ -16,7 +16,9 @@ limitations under the License. */
 #include <vector>
 #include "gtest/gtest.h"
 #include "paddle/framework/attribute.h"
+#include "paddle/framework/block_desc.h"
 #include "paddle/framework/grad_op_builder.h"
+#include "paddle/framework/op_desc.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
@@ -24,6 +26,7 @@ USE_OP(elementwise_add);
 USE_OP(gaussian_random);
 USE_OP(feed);
 USE_OP(fetch);
+USE_OP(mul);
 using std::string;
 using namespace paddle::platform;
@@ -32,7 +35,71 @@ using namespace paddle::framework;
 typedef paddle::framework::BlockDesc proto_block;
 typedef paddle::framework::OpDesc proto_op;

-void add_gaussian_random_op(string var_name, std::vector<int>& dim,
+struct SetAttrDescVisitor : public boost::static_visitor<void> {
+  explicit SetAttrDescVisitor(OpDesc::Attr* attr) : attr_(attr) {}
+  mutable OpDesc::Attr* attr_;
+  void operator()(int v) const { attr_->set_i(v); }
+  void operator()(float v) const { attr_->set_f(v); }
+  void operator()(const std::string& v) const { attr_->set_s(v); }
+  void operator()(bool b) const { attr_->set_b(b); }
+  void operator()(const std::vector<int>& v) const {
+    VectorToRepeated(v, attr_->mutable_ints());
+  }
+  void operator()(const std::vector<float>& v) const {
+    VectorToRepeated(v, attr_->mutable_floats());
+  }
+  void operator()(const std::vector<std::string>& v) const {
+    VectorToRepeated(v, attr_->mutable_strings());
+  }
+  void operator()(const std::vector<bool>& v) const {
+    VectorToRepeated(v, attr_->mutable_bools());
+  }
+  void operator()(BlockDesc* desc) const { attr_->set_block_idx(desc->idx()); }
+  void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
+};
+
+void AddOp(const std::string& type, const VariableNameMap& inputs,
+           const VariableNameMap& outputs, AttributeMap attrs,
+           proto_block* block) {
+  // insert output
+  for (auto kv : outputs) {
+    for (auto v : kv.second) {
+      auto var = block->add_vars();
+      var->set_name(v);
+      auto var_lt = var->mutable_lod_tensor();
+      var_lt->set_data_type(paddle::framework::DataType::FP32);
+    }
+  }
+
+  // insert op
+  auto op = block->add_ops();
+  op->set_type(type);
+  for (auto kv : inputs) {
+    auto X = op->add_inputs();
+    X->set_parameter(kv.first);
+    for (auto argu : kv.second) {
+      X->add_arguments(argu);
+    }
+  }
+  for (auto kv : outputs) {
+    auto X = op->add_outputs();
+    X->set_parameter(kv.first);
+    for (auto argu : kv.second) {
+      X->add_arguments(argu);
+    }
+  }
+  for (auto& attr : attrs) {
+    auto* attr_desc = op->add_attrs();
+    attr_desc->set_name(attr.first);
+    attr_desc->set_type(
+        static_cast<paddle::framework::AttrType>(attr.second.which() - 1));
+    SetAttrDescVisitor visitor(attr_desc);
+    boost::apply_visitor(visitor, attr.second);
+  }
+}
+
+void add_gaussian_random_op(string var_name, std::vector<int> dim,
                             proto_block* block) {
   // insert variable
   auto a = block->add_vars();
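Note on the new helper: `AddOp` registers every output variable on the block as an FP32 LoD tensor and then appends an `OpDesc` whose inputs, outputs, and attributes are filled from plain initializer lists; each attribute value is serialized by `SetAttrDescVisitor`, with the variant index shifted by one, presumably because slot 0 of the attribute variant is `boost::blank`. A minimal usage sketch, mirroring the calls that appear in `ExecutorTesterRandom::SetUp` further down (not self-contained: it assumes the surrounding test file and its `init_root_block`, with the dim constants written out):

```cpp
// Sketch only: mirrors the AddOp calls used later in this test.
AddOp("gaussian_random", /*inputs=*/{}, /*outputs=*/{{"Out", {"w1"}}},
      /*attrs=*/{{"dims", std::vector<int>{5, 5}}}, init_root_block);
AddOp("fetch", /*inputs=*/{{"Input", {"w1"}}}, /*outputs=*/{},
      /*attrs=*/{{"dims", std::vector<int>{5, 5}}}, init_root_block);
```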
@@ -91,7 +158,7 @@ void add_feed_op(string var_name, std::vector<int>& dim, int index,
   Out->add_arguments(var_name);
 }

-void add_fetch_op(string var_name, std::vector<int>& dim, int index,
+void add_fetch_op(string var_name, std::vector<int> dim, int index,
                   proto_block* block) {
   // insert variable
   auto a = block->add_vars();
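The signature change from `std::vector<int>&` to `std::vector<int>` (here and in `add_gaussian_random_op` above) is what lets the callers below pass braced initializer lists such as `{input_dim, batch_size}` directly: a temporary vector cannot bind to a non-const lvalue reference, but it can initialize a by-value parameter. A small self-contained illustration (the function names are made up for the example):

```cpp
#include <vector>

// void by_ref(std::vector<int>& dim);  // by_ref({5, 2}) would not compile:
//                                      // a temporary cannot bind to a
//                                      // non-const lvalue reference.
void by_value(std::vector<int> dim) {}  // by_value({5, 2}) is fine.

int main() {
  by_value({5, 2});  // the braced list materializes a temporary vector
  return 0;
}
```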
@@ -125,6 +192,28 @@ void add_fetch_op(string var_name, std::vector<int>& dim, int index,
   Out->add_arguments(var_name);
 }

+void add_mul_op(string X_str, string Y_str, string Out_str,
+                proto_block* block) {
+  // insert variable
+  auto a = block->add_vars();
+  a->set_name(Out_str);
+  auto a_lt = a->mutable_lod_tensor();
+  a_lt->set_data_type(paddle::framework::DataType::FP32);
+
+  // insert op
+  auto op = block->add_ops();
+  op->set_type("mul");
+  auto X = op->add_inputs();
+  X->set_parameter("X");
+  X->add_arguments(X_str);
+  auto Y = op->add_inputs();
+  Y->set_parameter("Y");
+  Y->add_arguments(Y_str);
+  auto Out = op->add_outputs();
+  Out->set_parameter("Out");
+  Out->add_arguments(Out_str);
+}
+
 std::once_flag set_variable_flag;

 // Tensors in feed value variable will only be in CPUPlace
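`add_mul_op` hand-builds the same structure the new `AddOp` helper produces: it adds the output variable `Out_str` to the block and appends a `mul` op with parameters `X`, `Y`, and `Out`. Under that reading, the two `mul` ops created in `SetUp` below could equally be expressed through `AddOp`; the rewrite here is illustrative only, not part of this commit:

```cpp
// Hypothetical AddOp equivalents of the add_mul_op calls in SetUp():
AddOp("mul", {{"X", {"a"}}, {"Y", {"w1"}}}, {{"Out", {"b"}}}, {}, root_block);
AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {}, root_block);
```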
@@ -168,36 +257,37 @@ std::vector<std::vector<T>> get_fetch_variable() {
 class ExecutorTesterRandom : public ::testing::Test {
  public:
   virtual void SetUp() override {
+    int input_dim = 5, batch_size = 2, embed_dim = 5;
+
+    // init pdesc
+    auto init_root_block = init_pdesc_.add_blocks();
+    init_root_block->set_idx(0);
+    init_root_block->set_parent_idx(-1);
+    AddOp("gaussian_random", {}, {{"Out", {"w1"}}},
+          {{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
+    AddOp("gaussian_random", {}, {{"Out", {"w2"}}},
+          {{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
+    AddOp("fetch", {{"Input", {"w1"}}}, {},
+          {{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
+    AddOp("fetch", {{"Input", {"w2"}}}, {},
+          {{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
+
+    // run pdesc
     auto root_block = pdesc_.add_blocks();
     root_block->set_idx(0);
     root_block->set_parent_idx(-1);
-    std::vector<int> dim{2, 3};
-    add_gaussian_random_op("a", dim, root_block);
-    add_gaussian_random_op("b", dim, root_block);
-    auto c = root_block->add_vars();
-    c->set_name("c");
-    auto c_lt = c->mutable_lod_tensor();
-    c_lt->set_data_type(paddle::framework::DataType::FP32);
-    auto op = root_block->add_ops();
-    op->set_type("elementwise_add");
-    auto X = op->add_inputs();
-    X->set_parameter("X");
-    X->add_arguments("a");
-    auto Y = op->add_inputs();
-    Y->set_parameter("Y");
-    Y->add_arguments("b");
-    auto Out = op->add_outputs();
-    Out->set_parameter("Out");
-    Out->add_arguments("c");
-    add_fetch_op("c", dim, 0, root_block);
+    add_gaussian_random_op("a", {batch_size, input_dim}, root_block);
+    add_mul_op("a", "w1", "b", root_block);
+    add_mul_op("b", "w2", "a_out", root_block);
+    add_fetch_op("a_out", {input_dim, batch_size}, 0, root_block);
   }

  protected:
   ProgramDesc pdesc_;
+  ProgramDesc init_pdesc_;
 };

 class ExecutorTesterFeedAndFetch : public ::testing::Test {
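For reference, `ExecutorTesterRandom` now builds two programs: `init_pdesc_` initializes `w1` and `w2` with `gaussian_random` (and fetches them), while `pdesc_` runs the forward pass `a -> mul(w1) -> b -> mul(w2) -> a_out` and fetches `a_out`. Assuming `mul` is a plain matrix multiplication, the shapes chain as follows (a comment-only sketch derived from the dims set in `SetUp`):

```cpp
// batch_size = 2, input_dim = 5, embed_dim = 5
// a     : [batch_size, input_dim]              from gaussian_random
// w1    : [input_dim, embed_dim]               initialized by init_pdesc_
// b     = a * w1  -> [batch_size, embed_dim]
// w2    : [embed_dim, input_dim]               initialized by init_pdesc_
// a_out = b * w2  -> [batch_size, input_dim]   fetched at index 0
```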
@@ -238,6 +328,7 @@ TEST_F(ExecutorTesterRandom, CPU) {
   paddle::memory::Used(cpu_place);

   Executor* executor = new Executor(places);
+  executor->Run(init_pdesc_, GetGlobalScope());
   executor->Run(pdesc_, GetGlobalScope());
   std::vector<std::vector<float>> result = get_fetch_variable<float>();
@@ -295,7 +386,19 @@ TEST_F(ExecutorTesterRandom, GPU) {
   paddle::memory::Used(gpu_place);

   Executor* executor = new Executor(places);
+  LOG(INFO) << "Run Init";
+  executor->Run(init_pdesc_, GetGlobalScope());
+  LOG(INFO) << "Run";
   executor->Run(pdesc_, GetGlobalScope());
+  std::vector<std::vector<float>> result = get_fetch_variable<float>();
+
+  for (auto& vec : result) {
+    for (auto& num : vec) {
+      std::cout << num << " ";
+    }
+    std::cout << std::endl;
+  }
+
   delete executor;
 }
......