Unverified commit dcfb6038, authored by Yiqun Liu, committed by GitHub

Enable the detection of subgraph composed of grad ops (#21223)

* Add the first implementation of fusion_group op #19621 (#3)

* Add the dynamic load of nvrtc, and support runtime compiling of CUDA kernel using nvrtc.
test=develop

* Call CUDA driver api to launch the kernel compiled by nvrtc.
test=develop

* Disable for mac and windows.
test=develop

* Refine the code to support manually specified num_threads and workload_per_thread.
test=develop

* Refine the CUDA kernel to support large dims.
test=develop

* Add DeviceCodePool to manage all device codes.

* Add the first implementation of fusion_group op.

* Add unit-test for fusion_group op.

* Add the check of result.

* Add the check of nvrtc in unit-test.
test=develop

* Add comment to explain the inputs, outputs and features of fusion_group op.
test=develop

* Disable fusion_group op for mac and windows.
test=develop

* Make the compiling of device code return status instead of hanging up.
test=develop

* Add the check of whether there is CUDA driver library, and do not core dump when failing to call the CUDA driver API.

* Unify fusion_group_op's input and output names.
test=develop

* Add the check of CUDA driver library in unittest.
test=develop

* Enable generating code for a given subgraph. #21126 (#4)

* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearrangement of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop

* Enable the detection of subgraph of grad ops.

* Generate code for detected subgraph in fusion_group_pass.

* Add an option in BuildStrategy to enable fusion_group_pass and add unittest.
test=develop

* Fix a bug when checking whether the shapes of all inputs are the same.

* Add debug information.

* Remove subgraph_detector from inference/analysis to the common framework/ir directory. (#5)

test=develop

* Call subgraph_detector in fusion_group pass.
test=develop

* Disable fusion_group when WITH_GPU is OFF.
test=develop

* Refine all PADDLE_ENFORCE messages.
test=develop

* Fix the case that some inputs are not defined in grad ops, and set op_role for fused op.
test=develop

* Follow review comments.
test=develop
Parent 50af6b5d
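For context before the diff: the feature this PR adds is driven from Python through BuildStrategy. Below is a minimal, hypothetical sketch of turning it on with the Fluid 1.x API of this commit; the toy network, tensor shapes, and feed data are illustrative, not taken from the PR. As the build_strategy.cc hunk further down shows, the pass is GPU-only and is skipped with a warning otherwise.

    import numpy as np
    import paddle.fluid as fluid

    # A tiny program whose elementwise ops are candidates for fusion.
    x = fluid.layers.data(name='x', shape=[32], dtype='float32')
    y = fluid.layers.data(name='y', shape=[32], dtype='float32')
    z = fluid.layers.sigmoid(x + y) * fluid.layers.tanh(y)
    loss = fluid.layers.reduce_mean(z)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    place = fluid.CUDAPlace(0)  # fusion_group_pass is only supported on GPU
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_auto_fusion = True  # the option added by this commit

    train_program = fluid.compiler.CompiledProgram(
        fluid.default_main_program()).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)

    feed = {
        'x': np.random.random((4, 32)).astype('float32'),
        'y': np.random.random((4, 32)).astype('float32'),
    }
    exe.run(train_program, feed=feed, fetch_list=[loss.name])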
@@ -64,7 +64,14 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
 cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
-set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
+        sequential_execution_pass
+        modify_op_lock_and_record_event_pass
+        all_reduce_deps_pass
+        reference_count_pass
+        eager_deletion_pass
+        buffer_shared_inplace_op_pass
+        buffer_shared_cross_op_memory_reuse_pass)
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
@@ -91,23 +98,22 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
     DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
 cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
+set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
+        multi_devices_graph_print_pass multi_devices_graph_check_pass
+        fuse_elewise_add_act_pass fuse_bn_act_pass
+        multi_batch_merge_pass
+        fuse_relu_depthwise_conv_pass
+        lock_free_optimize_pass
+        coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
+        fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
+        sync_batch_norm_pass runtime_context_cache_pass)
+if(WITH_GPU)
+    set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
+endif()
 if(WITH_NGRAPH)
-    set(NGRAPH_BS_DEPS ngraph)
-else()
-    set(NGRAPH_BS_DEPS)
+    set(IR_PASS_DEPS ${IR_PASS_DEPS} ngraph)
 endif()
-cc_library(build_strategy SRCS build_strategy.cc DEPS
-        graph_viz_pass multi_devices_graph_pass
-        multi_devices_graph_print_pass multi_devices_graph_check_pass
-        fuse_elewise_add_act_pass fuse_bn_act_pass multi_batch_merge_pass
-        fuse_relu_depthwise_conv_pass
-        lock_free_optimize_pass
-        coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
-        fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
-        sync_batch_norm_pass runtime_context_cache_pass
-        pass_builder
-        ${NGRAPH_BS_DEPS})
+cc_library(build_strategy SRCS build_strategy.cc DEPS pass_builder ${IR_PASS_DEPS})
 if (WITH_MKLDNN)
     target_link_libraries(build_strategy mkldnn_placement_pass)
......
@@ -165,9 +165,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   void AppendOpFusePasses() {
     AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
                         "fuse_relu_depthwise_conv_pass");
+    AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
+#ifdef PADDLE_WITH_CUDA
+    AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
+#endif
     AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
                         "fuse_elewise_add_act_pass");
-    AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
     // for single card training, fuse_all_reduce_ops is unnecessary.
     // coalesce_grad_tensor_pass should be before of MultiDevPass.
     AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
@@ -370,6 +373,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                      "GPU, skipped.";
         continue;
       }
+    } else if (pass->Type() == "fusion_group_pass") {
+      pass->Set<bool>("use_gpu", new bool(use_cuda));
+      if (!use_cuda) {
+        LOG(WARNING) << "fusion_group_pass is only supported on GPU, skipped.";
+        continue;
+      }
     } else if (pass->Type() == "fuse_bn_act_pass") {
       if (!use_cuda) {
         LOG(WARNING) << "fuse_bn_act_pass is only supported on "
@@ -427,3 +436,6 @@ USE_PASS(mkldnn_placement_pass);
 #ifdef PADDLE_WITH_NGRAPH
 USE_PASS(ngraph_subgraph_pass);
 #endif
+#ifdef PADDLE_WITH_CUDA
+USE_PASS(fusion_group_pass);
+#endif
@@ -86,8 +86,9 @@ struct BuildStrategy {
   // Operator fusion
   // TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
   // cycle.
-  bool fuse_elewise_add_act_ops_{false};
   bool fuse_bn_act_ops_{false};
+  bool fuse_elewise_add_act_ops_{false};
+  bool enable_auto_fusion_{false};
   // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
   // should not be sparse types
   boost::optional<bool> fuse_all_optimizer_ops_{false};
......
@@ -6,7 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
 add_subdirectory(fuse_optimizer_ops_pass)
 add_subdirectory(memory_optimize_pass)
 add_subdirectory(multi_devices_graph_pass)
-if(NOT APPLE AND NOT WIN32)
+if(NOT APPLE AND NOT WIN32 AND WITH_GPU)
   add_subdirectory(fusion_group)
 endif()
......
-cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
+cc_library(code_generator
+  SRCS operation.cc code_generator.cc code_generator_helper.cc
+  DEPS graph subgraph_detector)
 if(WITH_GPU)
   cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
 endif()
 cc_library(fusion_group_pass
   SRCS fusion_group_pass.cc elementwise_group_detector.cc
-  DEPS graph_pattern_detector pass code_generator)
+  DEPS subgraph_detector fuse_pass_base code_generator device_code)
 cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass)
@@ -33,7 +33,7 @@ CodeGenerator::CodeGenerator() {
 std::string CodeGenerator::Generate(SubGraph* subgraph) {
   std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
-  return Generate(subgraph->func_name, expressions);
+  return Generate(subgraph->GetFuncName(), expressions);
 }
 
 static bool HasInput(Node* n, std::string name) {
......
@@ -227,7 +227,7 @@ std::vector<fusion_group::OperationExpression> TestMain(
   std::string code_str = code_generator.Generate(subgraph);
   VLOG(3) << code_str;
 
-  TestMainImpl(subgraph->func_name, code_str, cpu_tensors, n, input_ids,
+  TestMainImpl(subgraph->GetFuncName(), code_str, cpu_tensors, n, input_ids,
                output_ids);
 
   // Need to check the accuracy according to expressions.
......
@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
+#include <string>
+#include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/ir/fusion_group/operation.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/subgraph_detector.h"
 
 namespace paddle {
 namespace framework {
@@ -26,20 +29,22 @@ static std::unordered_set<std::string> unary_op_types;
 
 static std::unordered_set<std::string>& GetBinaryOpTypes() {
   if (binary_op_types.empty()) {
-    binary_op_types = OperationMap::Instance().Find(0, 2);
+    binary_op_types =
+        OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
   }
   return binary_op_types;
 }
 
 static std::unordered_set<std::string>& GetUnaryOpTypes() {
   if (unary_op_types.empty()) {
-    unary_op_types = OperationMap::Instance().Find(0, 1);
+    unary_op_types =
+        OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
   }
   return unary_op_types;
 }
 
 static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
-                          Node* n) {
+                          const Node* n) {
   if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) {
     auto iter = op_types.find(n->Op()->Type());
     if (iter != op_types.end()) {
@@ -49,114 +54,63 @@ static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
   return false;
 }
 
-static bool IsBinaryOp(Node* n) {
-  if (IsSpecifiedOp(GetBinaryOpTypes(), n) && n->inputs.size() == 2U) {
-    auto* x = n->inputs[0];
-    auto* y = n->inputs[1];
-
-    std::vector<int64_t> x_shape;
-    std::vector<int64_t> y_shape;
-    if (x && x->IsVar() && x->Var()) {
-      x_shape = x->Var()->GetShape();
-    }
-    if (y && y->IsVar() && y->Var()) {
-      y_shape = y->Var()->GetShape();
-    }
-    if (x_shape.size() == 0U || x_shape.size() != y_shape.size()) {
-      return false;
-    }
-    for (size_t i = 0; i < x_shape.size(); ++i) {
-      if (x_shape[i] != y_shape[i]) {
-        return false;
-      }
-    }
-    return true;
-  }
-  return false;
-}
-
-static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(GetUnaryOpTypes(), n); }
-
-bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) {
-  return IsBinaryOp(n) || IsUnaryOp(n);
-}
-
-bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n,
-                                                      std::string name) {
-  if (n && n->IsVar() && n->Var()) {
-    for (auto* op : n->outputs) {
-      if (IsElementwiseOp(op)) {
-        if (name.empty()) {
-          return true;
-        } else if (IsNthInput(n, op, name, 0)) {
-          return true;
-        }
-      }
-    }
-  }
-  return false;
-}
-
-bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
-  if (n && n->IsVar() && n->Var()) {
-    for (auto* op : n->inputs) {
-      if (IsElementwiseOp(op)) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
-  std::unordered_set<Node*> except_nodes_set;
-  for (size_t i = 0; i < except_nodes.size(); ++i) {
-    except_nodes_set.insert(except_nodes[i]);
-  }
-
-  int num_operations = 0;
-  if (IsElementwiseOp(n)) {
-    subgraph_.Insert(n);
-    num_operations += 1;
-    for (auto* var : n->inputs) {
-      subgraph_.Insert(var);
-      if (except_nodes_set.find(var) == except_nodes_set.end()) {
-        num_operations += Search(var, {n});
-      }
-    }
-    for (auto* var : n->outputs) {
-      subgraph_.Insert(var);
-      if (except_nodes_set.find(var) == except_nodes_set.end()) {
-        num_operations += Search(var, {n});
-      }
-    }
-  } else if (n && n->IsVar() && n->Var()) {
-    for (auto* op : n->inputs) {
-      if (IsElementwiseOp(op) &&
-          except_nodes_set.find(op) == except_nodes_set.end()) {
-        num_operations += Search(op, {n});
-      }
-    }
-    for (auto* op : n->outputs) {
-      if (IsElementwiseOp(op) &&
-          except_nodes_set.find(op) == except_nodes_set.end()) {
-        num_operations += Search(op, {n});
-      }
-    }
-  }
-  return num_operations;
-}
-
-int ElementwiseGroupDetector::operator()(Node* n) {
-  if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
-    name_ = n->Name();
-    subgraph_.Insert(n);
-    num_operations_ = Search(n, n->inputs);
-    VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", "
-            << num_operations_ << " operations, " << GetSubgraph().GetNumNodes()
-            << " nodes";
-  }
-  return num_operations_;
-}
+static bool IsGradOp(const Node* n) {
+  PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
+                    platform::errors::InvalidArgument(
+                        "Expected node %p to be an operator node.", n));
+  std::string suffix = "_grad";
+  std::string op_type = n->Op()->Type();
+  size_t pos = op_type.rfind(suffix);
+  return pos != std::string::npos &&
+         pos == (op_type.length() - suffix.length());
+}
+
+static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
+                               const std::vector<int64_t>& r) {
+  return l.size() != 0U && r.size() != 0U && l == r;
+}
+
+static bool IsBinaryOp(const Node* n) {
+  if (IsSpecifiedOp(GetBinaryOpTypes(), n)) {
+    if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
+      return false;
+    }
+
+    // The shape of all inputs should be the same.
+    std::vector<int64_t> shape_0;
+    for (size_t i = 0; i < n->inputs.size(); ++i) {
+      auto* in_i = n->inputs[i];
+      if (!(in_i && in_i->IsVar() && in_i->Var())) {
+        return false;
+      }
+
+      std::vector<int64_t> shape_i = in_i->Var()->GetShape();
+      if (i == 0U) {
+        shape_0 = shape_i;
+      } else {
+        if (!IsEqualAndNotEmpty(shape_0, shape_i)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+  return false;
+}
+
+static bool IsUnaryOp(const Node* n) {
+  return IsSpecifiedOp(GetUnaryOpTypes(), n);
+}
+
+bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
+  return IsBinaryOp(n) || IsUnaryOp(n);
+}
+
+std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
+    Graph* graph) {
+  auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
+
+  return SubgraphDetector(graph, teller)();
+}
 
 } // namespace fusion_group
......
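To make the new predicates concrete, here is a hedged Python re-statement of the two checks the detector now performs; this is illustrative only (the real implementation is the C++ diff above), and the helper names are made up for the sketch:

    def is_grad_op(op_type):
        # An op is a grad op when its type name ends with "_grad",
        # e.g. "elementwise_add_grad"; grad ops may have more than 2 inputs.
        return op_type.endswith("_grad")

    def inputs_have_same_shape(shapes):
        # All input shapes must be non-empty and identical; broadcast
        # is not supported by the fused kernel.
        if not shapes or any(len(s) == 0 for s in shapes):
            return False
        return all(s == shapes[0] for s in shapes)

    assert is_grad_op("elementwise_add_grad")
    assert not is_grad_op("elementwise_add")
    assert inputs_have_same_shape([[32, 128], [32, 128]])
    assert not inputs_have_same_shape([[32, 128], [32, 1]])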
@@ -14,10 +14,8 @@ limitations under the License. */
 
 #pragma once
 
-#include <string>
-#include <unordered_set>
 #include <vector>
-#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
+#include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
 
 namespace paddle {
@@ -27,21 +25,10 @@ namespace fusion_group {
 
 class ElementwiseGroupDetector {
  public:
-  int operator()(Node* n);
-
-  SubGraph GetSubgraph() const { return subgraph_; }
-
- private:
-  bool IsElementwiseOp(Node* n);
-  bool IsInputOfElementwiseOp(Node* n, std::string name = "");
-  bool IsOutputOfElementwiseOp(Node* n);
-  int Search(Node* n, std::vector<Node*> except_nodes = {});
+  std::vector<std::vector<Node*>> operator()(Graph* graph);
 
  private:
-  std::string name_;
-  int num_operations_{0};
-  SubGraph subgraph_;
+  bool IsElementwiseOp(const Node* n);
 };
 
 } // namespace fusion_group
......
@@ -13,57 +13,88 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h"
+#include <memory>
+#include <utility>
 #include <vector>
+#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
 #include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/platform/device_code.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
 void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE_NOT_NULL(graph);
-  int num_elementwise_groups = DetectFusionGroup(graph, 0);
-  LOG(INFO) << "Detect " << num_elementwise_groups
-            << " elementwise fusion groups.";
+  FusePassBase::Init("fusion_group_pass", graph);
+  if (Get<bool>("use_gpu")) {
+    fusion_group::OperationMap::Init();
+    int num_elementwise_groups = DetectFusionGroup(graph, 0);
+    VLOG(3) << "Detect " << num_elementwise_groups
+            << " elementwise fusion groups.";
+  }
 }
 
 int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
-  std::vector<fusion_group::SubGraph> subgraphs;
-  std::unordered_set<Node*> all_nodes = graph->Nodes();
-  for (Node* n : all_nodes) {
-    bool is_found = false;
-    for (auto& subgraph : subgraphs) {
-      if (subgraph.Has(n)) {
-        is_found = true;
-        break;
-      }
-    }
-    if (is_found) {
-      continue;
-    }
-
-    fusion_group::SubGraph subgraph;
-    if (type == 0) {
-      fusion_group::ElementwiseGroupDetector detector;
-      int num_operations = detector(n);
-      if (num_operations >= 2) {
-        subgraph = detector.GetSubgraph();
-      }
-    }
-
-    if (!subgraph.IsEmpty()) {
-      subgraphs.push_back(subgraph);
-    }
-  }
-
-  // TODO(liuyiqun): check whether there are intersection between subgraphs
-  for (size_t i = 0; i < subgraphs.size(); ++i) {
-    InsertFusionGroupOp(graph, &subgraphs[i]);
-  }
-  return subgraphs.size();
+  // TODO(liuyiqun): supported different places
+  platform::CUDAPlace place = platform::CUDAPlace(0);
+  int index = platform::DeviceCodePool::Init({place}).size(place);
+
+  std::vector<std::vector<Node*>> subgraphs =
+      fusion_group::ElementwiseGroupDetector()(graph);
+
+  int num_subgraphs = 0;
+  size_t min_subgraph_size = 2;
+  bool save_intermediate_out = true;
+  for (auto& vec : subgraphs) {
+    if (vec.size() >= min_subgraph_size) {
+      std::string func_name = "fused_elementwise_" + std::to_string(index++);
+      fusion_group::SubGraph subgraph(
+          type, func_name, save_intermediate_out,
+          std::unordered_set<Node*>(vec.begin(), vec.end()));
+      VLOG(3) << "subgraph: {\n"
+              << DebugString(subgraph.SortedNodes()) << "}\n";
+
+      GenerateCode(&subgraph);
+      InsertFusionGroupOp(graph, &subgraph);
+      num_subgraphs++;
+    }
+  }
+  return num_subgraphs;
+}
+
+void FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
+  fusion_group::CodeGenerator code_generator;
+  std::string code_str = code_generator.Generate(subgraph);
+  VLOG(3) << code_str;
+
+  // TODO(liuyiqun): supported different places
+  platform::CUDAPlace place = platform::CUDAPlace(0);
+  std::unique_ptr<platform::CUDADeviceCode> device_code(
+      new platform::CUDADeviceCode(place, subgraph->GetFuncName(), code_str));
+  device_code->Compile();
+
+  platform::DeviceCodePool& pool = platform::DeviceCodePool::Init({place});
+  pool.Set(std::move(device_code));
+}
+
+static int ExtractOpRole(fusion_group::SubGraph* subgraph) {
+  std::unordered_set<int> op_roles;
+  std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
+  for (auto* n : subgraph->Nodes()) {
+    if (n && n->IsOp() && n->Op()) {
+      if (n->Op()->HasAttr(attr_name)) {
+        op_roles.insert(boost::get<int>(n->Op()->GetAttr(attr_name)));
+      }
+    }
+  }
+  if (op_roles.size() == 1U) {
+    return *(op_roles.begin());
+  } else {
+    return static_cast<int>(OpRole::kNotSpecified);
+  }
 }
 
 void FusionGroupPass::InsertFusionGroupOp(
@@ -90,10 +121,12 @@ void FusionGroupPass::InsertFusionGroupOp(
     external_nodes.insert(n);
   }
   op_desc.SetOutput("Outs", output_names);
-  op_desc.SetAttr("type", subgraph->type);
-  op_desc.SetAttr("func_name", subgraph->func_name);
+  op_desc.SetAttr("type", subgraph->GetType());
+  op_desc.SetAttr("func_name", subgraph->GetFuncName());
+  op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                  ExtractOpRole(subgraph));
 
-  auto fusion_group_node = graph->CreateOpNode(&op_desc);
+  Node* fusion_group_node = graph->CreateOpNode(&op_desc);
   for (auto* in : input_vars_of_subgraph) {
     IR_NODE_LINK_TO(in, fusion_group_node);
   }
@@ -114,4 +147,5 @@ void FusionGroupPass::InsertFusionGroupOp(
 } // namespace ir
 } // namespace framework
 } // namespace paddle
 
-REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass);
+REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass)
+    .RequirePassAttr("use_gpu");
@@ -16,19 +16,20 @@ limitations under the License. */
 #include <string>
 #include <unordered_set>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
-#include "paddle/fluid/framework/ir/pass.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
-class FusionGroupPass : public Pass {
+class FusionGroupPass : public FusePassBase {
  protected:
   void ApplyImpl(Graph* graph) const override;
 
 private:
   int DetectFusionGroup(Graph* graph, int type = 0) const;
+  void GenerateCode(fusion_group::SubGraph* subgraph) const;
   void InsertFusionGroupOp(Graph* graph,
                            fusion_group::SubGraph* subgraph) const;
......
@@ -138,19 +138,15 @@ int TestMain(std::unique_ptr<Graph> graph, std::string prefix) {
 }
 
 TEST(FusionGroupPass, elementwise_list) {
-  fusion_group::OperationMap::Init();
-  std::unique_ptr<Graph> graph = BuildElementwiseListGraph(false);
+  std::unique_ptr<Graph> graph = BuildElementwiseListGraph(true);
   int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_list");
-  EXPECT_EQ(num_fusion_group_ops, 1);
+  EXPECT_EQ(num_fusion_group_ops, 2);
 }
 
 TEST(FusionGroupPass, elementwise_tree) {
-  fusion_group::OperationMap::Init();
-  std::unique_ptr<Graph> graph = BuildElementwiseTreeGraph(false);
+  std::unique_ptr<Graph> graph = BuildElementwiseTreeGraph(true);
   int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_tree");
-  EXPECT_EQ(num_fusion_group_ops, 2);
+  EXPECT_EQ(num_fusion_group_ops, 4);
 }
 
 } // namespace ir
......
@@ -20,48 +20,59 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/ir/fusion_group/operation.h"
 #include "paddle/fluid/framework/ir/node.h"
-#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/framework/ir/subgraph_detector.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 namespace fusion_group {
 
-struct SubGraph {
-  int type{-1};
-  std::string func_name;
-  bool save_intermediate_out{false};
-
+class SubGraph {
+ public:
   SubGraph() = default;
-  SubGraph(int t, std::string f, bool s, const std::unordered_set<Node*>& n)
-      : type(t), func_name(f), save_intermediate_out(s), nodes_set(n) {}
+  explicit SubGraph(int type) : type_(type) {}
+  SubGraph(int type, std::string func_name, bool save_intermediate_out,
+           const std::unordered_set<Node*>& nodes_set)
+      : type_(type),
+        func_name_(func_name),
+        save_intermediate_out_(save_intermediate_out) {
+    for (auto* n : nodes_set) {
+      nodes_set_.insert(n);
+      if (n && n->IsOp() && n->Op()) {
+        // If the node is an op node, then add its input/output var nodes
+        // into the subgraph.
+        for (auto* in : n->inputs) {
+          nodes_set_.insert(in);
+        }
+        for (auto* out : n->outputs) {
+          nodes_set_.insert(out);
+        }
+      }
+    }
+  }
 
-  bool IsEmpty() { return nodes_set.empty(); }
+  bool IsEmpty() { return nodes_set_.empty(); }
 
-  const std::unordered_set<Node*>& Nodes() const { return nodes_set; }
+  int GetType() const { return type_; }
+
+  void SetFuncName(std::string func_name) { func_name_ = func_name; }
+  std::string GetFuncName() const { return func_name_; }
+
+  const std::unordered_set<Node*>& Nodes() const { return nodes_set_; }
 
   const std::vector<Node*>& SortedNodes() {
-    if (!is_sorted) {
-      Sort();
+    if (!is_sorted_) {
+      TopologicalSort();
     }
-    return sorted_nodes;
+    return sorted_nodes_;
   }
 
-  size_t GetNumNodes() { return nodes_set.size(); }
+  size_t GetNumNodes() { return nodes_set_.size(); }
 
-  bool Has(Node* n) { return nodes_set.find(n) != nodes_set.end(); }
+  bool Has(Node* n) { return nodes_set_.find(n) != nodes_set_.end(); }
 
-  void Insert(Node* n) {
-    if (nodes_set.find(n) == nodes_set.end()) {
-      VLOG(5) << "Insert " << n->Name() << " to subgraph " << this;
-      nodes_set.insert(n);
-      is_sorted = false;
-    }
-  }
-
   int GetNumOperations() {
     int num_operations = 0;
-    for (auto* n : nodes_set) {
+    for (auto* n : nodes_set_) {
       if (n && n->IsOp() && n->Op()) {
         num_operations++;
       }
@@ -96,203 +107,108 @@ struct SubGraph {
-  std::vector<Node*> GetOutputVarNodes() {
-    // The order of output nodes should be consistant anywhere..
-    std::vector<Node*> output_vars;
-    for (auto* n : SortedNodes()) {
-      if (n && n->IsVar() && n->Var()) {
-        if (save_intermediate_out) {
-          // If the var_node is the output of some op_node in the subgraph, it
-          // is considered the output var node of the subgraph.
-          bool is_found = false;
-          for (auto* in : n->inputs) {
-            if (Has(in)) {
-              is_found = true;
-            }
-          }
-          if (is_found) {
-            output_vars.push_back(n);
-          }
-        } else {
-          // If one of the var_node's outputs is the input of some operator
-          // outside the subgraph, it is considered the output var node of the
-          // subgraph.
-          bool is_found = true;
-          if (n->outputs.size() == 0U) {
-            is_found = false;
-          }
-          for (auto* out : n->outputs) {
-            if (!Has(out)) {
-              is_found = false;
-            }
-          }
-          if (!is_found) {
-            output_vars.push_back(n);
-          }
-        }
-      }
-    }
-    return output_vars;
-  }
-
- private:
-  int FindIndexInSortedNodes(Node* n) {
-    for (size_t i = 0; i < sorted_nodes.size(); ++i) {
-      if (n == sorted_nodes[i]) {
-        return static_cast<int>(i);
-      }
-    }
-    return -1;
-  }
-
-  void SortVarsBasedOnSortedOps() {
-    // Insert var nodes to sorted_nodes.
-    std::unordered_map<std::string, Node*> sorted_vars;
-    for (auto* n : nodes_set) {
-      if (n && n->IsVar() && n->Var()) {
-        int from = 0;
-        int to = sorted_nodes.size();
-        for (auto* in : n->inputs) {
-          if (in && in->IsOp() && in->Op()) {
-            int index = FindIndexInSortedNodes(in);
-            // Insert after input op node
-            if (index >= 0) {
-              from = index + 1 > from ? index + 1 : from;
-            }
-          }
-        }
-        for (auto* out : n->outputs) {
-          if (out && out->IsOp() && out->Op()) {
-            int index = FindIndexInSortedNodes(out);
-            // Insert before output op node
-            if (index >= 0) {
-              to = index < to ? index : to;
-            }
-          }
-        }
-        if (from > to) {
-          LOG(INFO) << "subgraph: {\n" << DebugString(Nodes()) << "}\n";
-          LOG(INFO) << "sorted nodes: {\n"
-                    << DebugString(sorted_nodes) << "}\n";
-        }
-        PADDLE_ENFORCE_LE(from, to, "Range [%d, %d] is invalid.", from, to);
-        sorted_nodes.insert(sorted_nodes.begin() + to, n);
-        sorted_vars[n->Name()] = n;
-      }
-    }
-  }
-
-  std::vector<Node*> SortedOps() {
-    Node* start_op_n = nullptr;
-    std::unordered_set<Node*> ops;
-    for (auto* op_n : nodes_set) {
-      if (op_n && op_n->IsOp() && op_n->Op()) {
-        // Initialize ops to all ops in the subgraph.
-        ops.insert(op_n);
-        if (!start_op_n) {
-          // Find start op node whose inputs are produced outside the subgraph.
-          bool is_found = false;
-          for (auto* prev_op_n : GetPrevOpNodes(op_n)) {
-            if (Has(prev_op_n)) {
-              is_found = true;
-              break;
-            }
-          }
-          if (!is_found) {
-            start_op_n = op_n;
-          }
-        }
-      }
-    }
-
-    std::vector<Node*> sorted_ops;
-    sorted_ops.push_back(start_op_n);
-    ops.erase(start_op_n);
-    while (ops.size() > 0U) {
-      std::unordered_set<Node*> erased_ops;
-      for (auto* op_n : ops) {
-        bool found_connected_ops = false;
-        int from = 1;
-        int to = sorted_ops.size();
-        std::unordered_set<Node*> prev_op_nodes = GetPrevOpNodes(op_n);
-        std::unordered_set<Node*> next_op_nodes = GetNextOpNodes(op_n);
-        for (int i = sorted_ops.size(); i >= 0; --i) {
-          if (prev_op_nodes.find(sorted_ops[i]) != prev_op_nodes.end()) {
-            // Insert after i (i + 1)
-            found_connected_ops = true;
-            from = (i + 1 > from) ? i + 1 : from;
-          }
-          if (next_op_nodes.find(sorted_ops[i]) != next_op_nodes.end()) {
-            // Insert before i
-            found_connected_ops = true;
-            to = (i < to) ? i : to;
-          }
-        }
-        if (found_connected_ops) {
-          if (from > to) {
-            LOG(INFO) << "subgraph: {\n" << DebugString(Nodes()) << "}\n";
-          }
-          PADDLE_ENFORCE_LE(from, to, "Range [%d, %d] is invalid.", from, to);
-          sorted_ops.insert(sorted_ops.begin() + to, op_n);
-          erased_ops.insert(op_n);
-        }
-      }
-      PADDLE_ENFORCE_GT(erased_ops.size(), 0U);
-      for (auto* op_n : erased_ops) {
-        ops.erase(op_n);
-      }
-    }
-    return sorted_ops;
-  }
-
-  std::unordered_set<Node*> GetPrevOpNodes(Node* op_n) {
-    PADDLE_ENFORCE_EQ(op_n && op_n->IsOp() && op_n->Op(), true,
-                      "Node %p is not a op node.", op_n);
-    std::unordered_set<Node*> prev_op_nodes;
-    for (auto* in_var : op_n->inputs) {
-      if (in_var && in_var->IsVar() && in_var->Var()) {
-        for (auto* prev_op_n : in_var->inputs) {
-          if (prev_op_n && prev_op_n->IsOp() && prev_op_n->Op()) {
-            prev_op_nodes.insert(prev_op_n);
-          }
-        }
-      }
-    }
-    return prev_op_nodes;
-  }
-
-  std::unordered_set<Node*> GetNextOpNodes(Node* op_n) {
-    PADDLE_ENFORCE_EQ(op_n && op_n->IsOp() && op_n->Op(), true,
-                      "Node %p is not a op node.", op_n);
-    std::unordered_set<Node*> next_op_nodes;
-    for (auto* out_var : op_n->outputs) {
-      if (out_var && out_var->IsVar() && out_var->Var()) {
-        for (auto* next_op_n : out_var->outputs) {
-          if (next_op_n && next_op_n->IsOp() && next_op_n->Op()) {
-            next_op_nodes.insert(next_op_n);
-          }
-        }
-      }
-    }
-    return next_op_nodes;
-  }
-
-  void Sort() {
-    if (!is_sorted) {
-      sorted_nodes = SortedOps();
-      SortVarsBasedOnSortedOps();
-    }
-    is_sorted = true;
-  }
-
- private:
-  std::unordered_set<Node*> nodes_set;
-  bool is_sorted{false};
-  std::vector<Node*> sorted_nodes;
-};
+  std::vector<Node*> GetOutputVarNodes() {
+    // The order of output nodes should be consistant anywhere..
+    std::vector<Node*> output_vars_all;
+    for (auto* n : SortedNodes()) {
+      if (n && n->IsVar() && n->Var()) {
+        // If the var_node is the output of some op_node in the subgraph, it
+        // is considered the output var node of the subgraph.
+        bool is_found = false;
+        for (auto* in : n->inputs) {
+          if (Has(in)) {
+            is_found = true;
+          }
+        }
+        if (is_found) {
+          output_vars_all.push_back(n);
+        }
+      }
+    }
+
+    if (save_intermediate_out_) {
+      return output_vars_all;
+    }
+
+    std::vector<Node*> output_vars_outside;
+    for (auto* n : output_vars_all) {
+      // If one of the var_node's outputs is the input of some operator
+      // outside the subgraph, it is considered the output var node of the
+      // subgraph.
+      bool is_found = true;
+      if (n->outputs.size() == 0U) {
+        is_found = false;
+      }
+      for (auto* out : n->outputs) {
+        if (!Has(out)) {
+          is_found = false;
+        }
+      }
+      if (!is_found) {
+        output_vars_outside.push_back(n);
+      }
+    }
+    return output_vars_outside;
+  }
+
+ private:
+  void TopologicalSort() {
+    if (!is_sorted_) {
+      std::unordered_map<Node*, std::vector<Node*>> inputs_map;
+      std::unordered_map<Node*, std::vector<Node*>> outputs_map;
+      for (auto* n : nodes_set_) {
+        inputs_map[n] = n->inputs;
+        outputs_map[n] = n->outputs;
+      }
+      for (auto* n : nodes_set_) {
+        if (n && n->IsVar() && n->Var()) {
+          // Set the input of subgraph's input var node to null.
+          std::vector<Node*> inputs;
+          for (auto* in : n->inputs) {
+            if (Has(in)) {
+              inputs.push_back(in);
+            }
+          }
+          // Set the output of subgraph's output var node to null.
+          std::vector<Node*> outputs;
+          for (auto* out : n->outputs) {
+            if (Has(out)) {
+              outputs.push_back(out);
+            }
+          }
+          n->inputs = inputs;
+          n->outputs = outputs;
+        }
+      }
+      // Collect the start points of the subgraph.
+      std::vector<Node*> start_points;
+      for (auto* n : nodes_set_) {
+        if (n->inputs.empty()) {
+          start_points.push_back(n);
+        }
+      }
+      // Sort the subgraph.
+      NodesTSIterator x(start_points);
+      for (auto& n : iterator_range<NodesTSIterator>(
+               NodesTSIterator(start_points), NodesTSIterator())) {
+        sorted_nodes_.push_back(&n);
+      }
+      // Reset the inputs, outputs.
+      for (auto* n : nodes_set_) {
+        n->inputs = inputs_map[n];
+        n->outputs = outputs_map[n];
+      }
+    }
+    is_sorted_ = true;
+  }
+
+ private:
+  int type_{-1};
+  std::string func_name_;
+  bool save_intermediate_out_{true};
+
+  std::unordered_set<Node*> nodes_set_;
+  bool is_sorted_{false};
+  std::vector<Node*> sorted_nodes_;
+};
 
 } // namespace fusion_group
......
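In outline, the new TopologicalSort replaces the old insertion-sort style ordering: it clips every edge to the subgraph, takes nodes with no remaining inputs as start points, and visits nodes in topological order via NodesTSIterator. As a hedged illustration only (NodesTSIterator's actual traversal may differ), here is the same idea re-stated in Python with Kahn's algorithm; all names in this sketch are hypothetical:

    from collections import deque

    def topological_sort(nodes, inputs, outputs):
        # nodes: the subgraph's node set; inputs/outputs: adjacency maps.
        # Count only predecessors that are inside the subgraph (edge clipping).
        in_deg = {n: sum(1 for p in inputs[n] if p in nodes) for n in nodes}
        queue = deque(n for n in nodes if in_deg[n] == 0)  # start points
        order = []
        while queue:
            n = queue.popleft()
            order.append(n)
            for succ in outputs[n]:
                if succ in nodes:
                    in_deg[succ] -= 1
                    if in_deg[succ] == 0:
                        queue.append(succ)
        return order

    nodes = {"x", "add", "t", "relu", "y"}
    inputs = {"x": [], "add": ["x"], "t": ["add"], "relu": ["t"], "y": ["relu"]}
    outputs = {"x": ["add"], "add": ["t"], "t": ["relu"], "relu": ["y"], "y": []}
    assert topological_sort(nodes, inputs, outputs) == ["x", "add", "t", "relu", "y"]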
@@ -2017,6 +2017,27 @@ All parameter, weight, gradient are variables in Paddle.
                     build_strategy = fluid.BuildStrategy()
                     build_strategy.fuse_bn_act_ops = True
                 )DOC")
+      .def_property(
+          "enable_auto_fusion",
+          [](const BuildStrategy &self) { return self.enable_auto_fusion_; },
+          [](BuildStrategy &self, bool b) {
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              platform::errors::PreconditionNotMet(
+                                  "BuildStrategy is finlaized."));
+            self.enable_auto_fusion_ = b;
+          },
+          R"DOC((bool, optional): Whether to enable fusing subgraph to a
+                fusion_group. Now we only support fusing subgraph that composed
+                of elementwise-like operators, such as elementwise_add/mul
+                without broadcast and activations.
+
+                Examples:
+                    .. code-block:: python
+
+                        import paddle.fluid as fluid
+                        build_strategy = fluid.BuildStrategy()
+                        build_strategy.enable_auto_fusion = True
+                )DOC")
       .def_property(
           "fuse_relu_depthwise_conv",
           [](const BuildStrategy &self) {
......
@@ -200,6 +200,10 @@ if (APPLE OR WIN32)
     list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler)
 endif()
 
+if(NOT WITH_GPU OR WIN32 OR APPLE)
+    list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass)
+endif()
+
 # Some ops need to check results when gc is enabled
 # Currently, only ops that register NoNeedBufferVarsInference need to do this test
 set(TEST_OPS_WITH_GC
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase
class FusionGroupPaddingRNNTest(PaddingRNNTestBase):
def set_customed_config(self):
self.build_strategy.enable_auto_fusion = True
# Use CUDA executor
if core.is_compiled_with_cuda():
self.exe = fluid.Executor(fluid.CUDAPlace(0))
def test_train_enable_fusion_group(self):
rnn_model = "static"
config = RNNConfig("test", rnn_model)
with fluid.scope_guard(fluid.Scope()):
self.train(config, parallel=True, use_program_cache=False)
if __name__ == '__main__':
unittest.main()
@@ -21,7 +21,6 @@ import numpy as np
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.fluid.layers as layers
-import time
 import os
 
 from paddle.fluid import ParamAttr
@@ -118,8 +117,7 @@ def lm_model(hidden_size,
              num_steps=20,
              init_scale=0.1,
              dropout=None,
-             rnn_model='static',
-             use_py_reader=False):
+             rnn_model='static'):
     def padding_rnn(input_embedding, len=3, init_hidden=None, init_cell=None):
         weight_1_arr = []
         weight_2_arr = []
@@ -279,38 +277,9 @@ def lm_model(hidden_size,
             gate_input = layers.elementwise_add(gate_input, bias)
             i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
 
-            try:
-                from paddle.fluid.contrib.layers import fused_elemwise_activation
-                # fluid.contrib.layers.fused_elemwise_activation can do a fused
-                # operation, like:
-                # 1) x + sigmoid(y); x + tanh(y)
-                # 2) tanh(x + y)
-                # Now the unary operation supported in this fused op is limit, and
-                # we will extent this operation to support more unary operations and
-                # do this kind of fusion automitically in future version of paddle.fluid.
-                # layers.sigmoid(i) * layers.tanh(j)
-                tmp0 = fused_elemwise_activation(
-                    x=layers.tanh(j),
-                    y=i,
-                    functor_list=['elementwise_mul', 'sigmoid'],
-                    save_intermediate_out=False)
-                # pre_cell * layers.sigmoid(f)
-                tmp1 = fused_elemwise_activation(
-                    x=pre_cell,
-                    y=f,
-                    functor_list=['elementwise_mul', 'sigmoid'],
-                    save_intermediate_out=False)
-                c = tmp0 + tmp1
-                # layers.tanh(c) * layers.sigmoid(o)
-                m = fused_elemwise_activation(
-                    x=layers.tanh(c),
-                    y=o,
-                    functor_list=['elementwise_mul', 'sigmoid'],
-                    save_intermediate_out=False)
-            except ImportError:
-                c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
-                    i) * layers.tanh(j)
-                m = layers.tanh(c) * layers.sigmoid(o)
+            c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
+                i) * layers.tanh(j)
+            m = layers.tanh(c) * layers.sigmoid(o)
 
             hidden_array[k] = m
             cell_array[k] = c
@@ -342,23 +311,16 @@ def lm_model(hidden_size,
         return real_res, last_hidden, last_cell
 
     batch_size_each = batch_size
-    if use_py_reader:
-        feed_shapes = [[batch_size_each, num_steps, 1],
-                       [batch_size_each * num_steps, 1]]
-        py_reader = fluid.layers.py_reader(
-            capacity=16, shapes=feed_shapes, dtypes=['int64', 'int64'])
-        x, y = fluid.layers.read_file(py_reader)
-    else:
-        x = layers.data(
-            name="x",
-            shape=[batch_size_each, num_steps, 1],
-            dtype='int64',
-            append_batch_size=False)
-        y = layers.data(
-            name="y",
-            shape=[batch_size_each * num_steps, 1],
-            dtype='int64',
-            append_batch_size=False)
+    x = layers.data(
+        name="x",
+        shape=[batch_size_each, num_steps, 1],
+        dtype='int64',
+        append_batch_size=False)
+    y = layers.data(
+        name="y",
+        shape=[batch_size_each * num_steps, 1],
+        dtype='int64',
+        append_batch_size=False)
 
     init_hidden = layers.data(
         name="init_hidden",
@@ -472,10 +434,7 @@ def lm_model(hidden_size,
     layers.assign(input=last_hidden, output=init_hidden)
 
     feeding_list = ['x', 'y', 'init_hidden', 'init_cell']
-    if use_py_reader:
-        return loss, last_hidden, last_cell, feeding_list, py_reader
-    else:
-        return loss, last_hidden, last_cell, feeding_list
+    return loss, last_hidden, last_cell, feeding_list
 
 
 class PaddingRNNTestBase(unittest.TestCase):
@@ -483,7 +442,29 @@ class PaddingRNNTestBase(unittest.TestCase):
         self.reader = Reader()
         self.device_count = 1
 
-    def prepare_program(self, config, parallel=True):
+        # The default exec_strategy used for PaddingRNN.
+        # You can change it in set_customed_config.
+        self.exec_strategy = fluid.ExecutionStrategy()
+        self.exec_strategy.num_threads = self.device_count
+        self.exec_strategy.num_iteration_per_drop_scope = 100
+
+        # The default build_strategy used for PaddingRNN.
+        # You can change it in set_customed_config.
+        self.build_strategy = fluid.BuildStrategy()
+        self.build_strategy.enable_inplace = True
+        self.build_strategy.memory_optimize = False
+        self.build_strategy.fuse_all_optimizer_ops = True
+
+        # CPU executor is used for PaddingRNN default.
+        # You can change to CUDA executor in set_customed_config.
+        self.exe = Executor(fluid.CPUPlace())
+
+    def set_customed_config(self):
+        # This function will be called before training.
+        # You can override the function to set your own config.
+        pass
+
+    def _prepare_program(self, config, parallel=True):
         self.main_program = fluid.Program()
         self.startup_program = fluid.Program()
         self.startup_program.random_seed = config.random_seed
@@ -497,8 +478,7 @@ class PaddingRNNTestBase(unittest.TestCase):
                 num_steps=config.num_steps,
                 init_scale=config.init_scale,
                 dropout=config.dropout,
-                rnn_model=config.rnn_model,
-                use_py_reader=False)
+                rnn_model=config.rnn_model)
             self.loss, self.last_hidden, self.last_cell, self.feed_order = res_vars
 
             fluid.clip.set_gradient_clip(
@@ -515,28 +495,19 @@ class PaddingRNNTestBase(unittest.TestCase):
             optimizer = fluid.optimizer.SGD(
                 learning_rate=self.learning_rate)
             optimizer.minimize(self.loss)
-            self.exe = Executor(fluid.CPUPlace())
             self.exe.run(self.startup_program)
 
             if parallel:
-                exec_strategy = fluid.ExecutionStrategy()
-                exec_strategy.num_threads = self.device_count
-                exec_strategy.num_iteration_per_drop_scope = 100
-
-                build_strategy = fluid.BuildStrategy()
-                build_strategy.enable_inplace = True
-                build_strategy.memory_optimize = False
-                build_strategy.fuse_all_optimizer_ops = True
-
                 self.train_program = fluid.compiler.CompiledProgram(
                     self.main_program).with_data_parallel(
                         loss_name=self.loss.name,
-                        build_strategy=build_strategy,
-                        exec_strategy=exec_strategy)
+                        build_strategy=self.build_strategy,
+                        exec_strategy=self.exec_strategy)
             else:
                 self.train_program = self.main_program
 
-    def generate_init_data(self):
+    def _generate_init_data(self):
         init_hidden = np.zeros(
             (self.config.num_layers, self.config.batch_size,
              self.config.hidden_size),
@@ -547,19 +518,19 @@ class PaddingRNNTestBase(unittest.TestCase):
             dtype='float32')
         return init_hidden, init_cell
 
-    def generate_new_lr(self, epoch_id=0, device_count=1):
+    def _generate_new_lr(self, epoch_id=0, device_count=1):
        new_lr = self.config.base_learning_rate * (self.config.lr_decay**max(
            epoch_id + 1 - self.config.epoch_start_decay, 0.0))
        lr = np.ones((self.device_count), dtype='float32') * new_lr
        return lr
 
-    def prepare_input(self,
-                      batch,
-                      init_hidden=None,
-                      init_cell=None,
-                      epoch_id=0,
-                      with_lr=True,
-                      device_count=1):
+    def _prepare_input(self,
+                       batch,
+                       init_hidden=None,
+                       init_cell=None,
+                       epoch_id=0,
+                       with_lr=True,
+                       device_count=1):
         x, y = batch
         x = x.reshape((-1, self.config.num_steps, 1))
         y = y.reshape((-1, 1))
@@ -572,19 +543,19 @@ class PaddingRNNTestBase(unittest.TestCase):
         if init_cell is not None:
             res['init_cell'] = init_cell
         if with_lr:
-            res['learning_rate'] = self.generate_new_lr(epoch_id, device_count)
+            res['learning_rate'] = self._generate_new_lr(epoch_id, device_count)
 
         return res
 
-    def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
+    def _train_an_epoch(self, epoch_id, use_program_cache=True):
         train_data_iter = self.reader.get_data_iter(self.config)
 
         total_loss = 0
         iters = 0
-        init_hidden, init_cell = self.generate_init_data()
+        init_hidden, init_cell = self._generate_init_data()
         ppl = np.zeros(shape=(0))
         for batch_id, batch in enumerate(train_data_iter):
-            input_data_feed = self.prepare_input(
+            input_data_feed = self._prepare_input(
                 batch,
                 init_hidden=init_hidden,
                 init_cell=init_cell,
@@ -592,7 +563,6 @@ class PaddingRNNTestBase(unittest.TestCase):
                 with_lr=True,
                 device_count=self.device_count)
 
-            batch_start_time = time.time()
             fetch_outs = self.exe.run(self.train_program,
                                       feed=input_data_feed,
                                       fetch_list=[
@@ -601,8 +571,6 @@ class PaddingRNNTestBase(unittest.TestCase):
                                           self.last_cell.name
                                       ],
                                       use_program_cache=use_program_cache)
-            batch_time = time.time() - batch_start_time
-            batch_times.append(batch_time)
 
             cost_train = np.array(fetch_outs[0])
             lr = np.array(fetch_outs[1])
@@ -617,17 +585,13 @@ class PaddingRNNTestBase(unittest.TestCase):
         return ppl
 
     def train(self, config, parallel=True, use_program_cache=True):
+        self.set_customed_config()
+
         self.config = config
-        self.prepare_program(config, parallel)
-        total_time = 0.0
+        self._prepare_program(config, parallel)
         ppl = np.zeros(shape=(0, config.batch_size))
         for epoch_id in range(config.max_epoch):
-            batch_times = []
-            epoch_start_time = time.time()
-            train_ppl = self.train_an_epoch(epoch_id, batch_times,
-                                            use_program_cache)
-            epoch_time = time.time() - epoch_start_time
-            total_time += epoch_time
+            train_ppl = self._train_an_epoch(epoch_id, use_program_cache)
             ppl = np.append(ppl, train_ppl)
         return ppl
......