提交 932402c1 编写于 作者: Y Yang Yang

debug for sum

上级 e655d291
......@@ -378,6 +378,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
backward_descs[dup_op[i]]->Rename(out_name, new_name);
sum_op_inputs.emplace_back(new_name);
}
LOG(INFO) << "fuck " << sum_op_inputs.size();
std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
"sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
......
......@@ -74,7 +74,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
std::vector<bool> should_run = Prune(pdesc, block_id);
PADDLE_ENFORCE_EQ(should_run.size(), block.ops_size());
for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) {
// if (should_run[i]) {
if (true) {
for (auto& var : block.ops(i).outputs()) {
for (auto& argu : var.arguments()) {
if (local_scope.FindVar(argu) == nullptr) {
......@@ -82,7 +83,17 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
}
}
}
LOG(INFO) << block.ops(i).type();
if (block.ops(i).type() == "sum") {
LOG(INFO) << "Here";
for (auto& var : block.ops(i).inputs()) {
for (auto& argu : var.arguments()) {
LOG(INFO) << var.parameter() << " " << argu;
}
}
}
auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i));
LOG(INFO) << op->DebugString();
op->Run(local_scope, *device);
}
}
......
......@@ -30,6 +30,7 @@ USE_OP(gaussian_random);
USE_OP(feed);
USE_OP(fetch);
USE_OP(mul);
USE_OP(sum);
USE_OP(squared_l2_distance);
using std::string;
......@@ -104,40 +105,63 @@ class ExecutorTesterRandom : public ::testing::Test {
virtual void SetUp() override {
int input_dim = 5, batch_size = 2, embed_dim = 5;
auto temp_root_block = pdesc_.add_blocks();
temp_root_block->set_idx(0);
temp_root_block->set_parent_idx(-1);
paddle::framework::ProgramDescBind& program =
paddle::framework::ProgramDescBind::Instance(&pdesc_);
paddle::framework::BlockDescBind* root_block = program.Block(0);
auto temp_init_root_block = init_pdesc_.add_blocks();
temp_init_root_block->set_idx(0);
temp_init_root_block->set_parent_idx(-1);
paddle::framework::ProgramDescBind& init_program =
paddle::framework::ProgramDescBind::Instance(&init_pdesc_);
paddle::framework::BlockDescBind* init_root_block = init_program.Block(0);
// block[0]
AddOp("gaussian_random", {}, {{"Out", {"w1"}}},
{{"dims", std::vector<int>{input_dim, embed_dim}}}, root_block);
{{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
AddOp("gaussian_random", {}, {{"Out", {"w2"}}},
{{"dims", std::vector<int>{embed_dim, input_dim}}}, root_block);
{{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {},
{{"dims", std::vector<int>{input_dim, embed_dim}}, {"col", 0}},
root_block);
init_root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {},
{{"dims", std::vector<int>{embed_dim, input_dim}}, {"col", 1}},
root_block);
init_root_block);
// flush
init_program.Proto();
auto temp_root_block = pdesc_.add_blocks();
temp_root_block->set_idx(0);
temp_root_block->set_parent_idx(-1);
paddle::framework::ProgramDescBind& program =
paddle::framework::ProgramDescBind::Instance(&pdesc_);
paddle::framework::BlockDescBind* root_block = program.Block(0);
// block[1]
paddle::framework::BlockDescBind* run_block =
program.AppendBlock(*root_block);
AddOp("gaussian_random", {}, {{"Out", {"a"}}},
{{"dims", std::vector<int>{batch_size, input_dim}}}, run_block);
{{"dims", std::vector<int>{batch_size, input_dim}}}, root_block);
AddOp("mul", {{"X", {"a"}}, {"Y", {"w1"}}}, {{"Out", {"b"}}}, {},
run_block);
root_block);
AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {},
run_block);
root_block);
AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}},
{{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {},
run_block);
AddOp("fetch", {{"Input", {"l2_distance"}}}, {},
{{"dims", std::vector<int>{batch_size}}, {"col", 1}}, run_block);
root_block);
AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}},
{{"dims", std::vector<int>{batch_size, 1}}}, root_block);
AppendBackward(program, {});
program.Proto();
for (auto& op : pdesc_.blocks(0).ops()) {
if (op.type() == "sum") {
LOG(INFO) << "Here";
for (auto& var : op.inputs()) {
for (auto& argu : var.arguments()) {
LOG(INFO) << var.parameter() << " " << argu;
}
}
}
}
AddOp("fetch", {{"Input", {"l2_distance"}}}, {},
{{"dims", std::vector<int>{batch_size}}, {"col", 1}}, root_block);
// flush
program.Proto();
......@@ -146,6 +170,7 @@ class ExecutorTesterRandom : public ::testing::Test {
}
protected:
ProgramDesc init_pdesc_;
ProgramDesc pdesc_;
};
......@@ -200,8 +225,8 @@ TEST_F(ExecutorTesterRandom, CPU) {
std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, GetGlobalScope(), 0);
executor->Run(pdesc_, GetGlobalScope(), 0);
executor->Run(pdesc_, GetGlobalScope(), 1);
std::vector<std::vector<float>> result = GetFetchVariable<float>();
}
......@@ -248,8 +273,8 @@ TEST_F(ExecutorTesterRandom, GPU) {
std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, GetGlobalScope(), 0);
executor->Run(pdesc_, GetGlobalScope(), 0);
executor->Run(pdesc_, GetGlobalScope(), 1);
std::vector<std::vector<float>> result = GetFetchVariable<float>();
}
......
......@@ -22,7 +22,7 @@ class FeedOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
typedef std::vector<framework::Tensor> FeedInputs;
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
int col = ctx->Attrs().Get<int>("col");
......
......@@ -22,7 +22,7 @@ class FetchOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
typedef std::vector<framework::Tensor> FetchOutputs;
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
int col = ctx->Attrs().Get<int>("col");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册