提交 831927d5 编写于 作者: Y Yang Yang(Tony) 提交者: GitHub

Merge pull request #4738 from tonyyang-svail/prune_impl

Prune implementation
......@@ -44,6 +44,9 @@ cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_co
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor)
cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place)
......
......@@ -55,6 +55,7 @@ message OpDesc {
repeated Var inputs = 1;
repeated Var outputs = 2;
repeated Attr attrs = 4;
optional bool is_target = 5 [ default = false ];
};
// OpProto describes a C++ framework::OperatorBase derived class.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/prune.h"
#include <algorithm>
#include <set>
#include <string>
#include <vector>
#include <glog/logging.h>
namespace paddle {
namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
bool HasDependentVar(const OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (dependent_vars.count(argu) != 0) {
return true;
}
}
}
return false;
}
bool IsTarget(const OpDesc& op_desc) {
if (op_desc.has_is_target()) {
return op_desc.is_target();
}
return false;
}
void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
// TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op
auto& block = input.blocks(block_id);
auto& ops = block.ops();
bool expect_feed = true;
for (auto& op_desc : ops) {
PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
"All FeedOps are at the beginning of the ProgramDesc");
expect_feed = (op_desc.type() == kFeedOpType);
}
bool expect_fetch = true;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
"All FetchOps must at the end of the ProgramDesc");
expect_fetch = (op_desc.type() == kFetchOpType);
}
std::set<std::string> dependent_vars;
std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
// insert its input to the dependency graph
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
dependent_vars.insert(argu);
}
}
should_run.push_back(true);
} else {
should_run.push_back(false);
}
}
// since we are traversing the ProgramDesc in reverse order
// we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end());
output = input;
auto* op_field = output.mutable_blocks(block_id)->mutable_ops();
op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) {
*op_field->Add() = input.blocks(block_id).ops(i);
}
}
}
void Prune(const ProgramDesc& input, ProgramDesc& output) {
prune_impl(input, output, 0);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc& output);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/prune.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/net_op.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/program_desc.h"
#include <gtest/gtest.h>
namespace f = paddle::framework;
namespace ops = paddle::operators;
void AddOp(const std::string &type, const f::VariableNameMap &inputs,
const f::VariableNameMap &outputs, f::AttributeMap attrs,
paddle::framework::BlockDescBind *block) {
// insert output
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(paddle::framework::DataType::FP32);
}
}
// insert op
auto op = block->AppendOp();
op->SetType(type);
for (auto &kv : inputs) {
op->SetInput(kv.first, kv.second);
}
for (auto &kv : outputs) {
op->SetOutput(kv.first, kv.second);
}
op->SetAttrMap(attrs);
}
TEST(Prune, one_operator) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
}
TEST(Prune, forward) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, {}, block);
AddOp("one_one", {{"input", {"c"}}}, {{"output", {"d"}}}, {}, block);
AddOp("one_one", {{"input", {"d"}}}, {{"output", {"e"}}}, {}, block);
f::ProgramDesc *pdesc = program.Proto();
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
}
}
TEST(Prune, multi_input_op) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, {}, block);
AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, {}, block);
AddOp("one_one", {{"input", {"a2"}}}, {{"output", {"b2"}}}, {}, block);
AddOp("three_one", {{"input", {"b0", "b1", "b2"}}}, {{"output", {"c"}}}, {},
block);
f::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
}
TEST(Prune, multi_output_op) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, {}, block);
f::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
}
TEST(Prune, multi_target) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, {}, block);
f::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册