Commit 011f0644 authored by sneaxiy

merge develop to solve conflict, test=develop

......@@ -64,7 +64,15 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass set_reader_device_count_pass)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
sequential_execution_pass
modify_op_lock_and_record_event_pass
all_reduce_deps_pass
reference_count_pass
eager_deletion_pass
buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass
set_reader_device_count_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......@@ -91,23 +99,22 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass
multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass)
if(WITH_GPU)
set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
endif()
if(WITH_NGRAPH)
set(NGRAPH_BS_DEPS ngraph)
else()
set(NGRAPH_BS_DEPS)
set(IR_PASS_DEPS ${IR_PASS_DEPS} ngraph)
endif()
cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass
pass_builder
${NGRAPH_BS_DEPS})
cc_library(build_strategy SRCS build_strategy.cc DEPS pass_builder ${IR_PASS_DEPS})
if (WITH_MKLDNN)
target_link_libraries(build_strategy mkldnn_placement_pass)
......
......@@ -166,9 +166,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendOpFusePasses() {
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
#ifdef PADDLE_WITH_CUDA
AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
#endif
AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
"fuse_elewise_add_act_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
// For single-card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should run before MultiDevPass.
AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
......@@ -375,6 +378,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "fusion_group_pass") {
pass->Set<bool>("use_gpu", new bool(use_cuda));
if (!use_cuda) {
LOG(WARNING) << "fusion_group_pass is only supported on GPU, skipped.";
continue;
}
} else if (pass->Type() == "fuse_bn_act_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_bn_act_pass is only supported on "
......@@ -435,3 +444,6 @@ USE_PASS(mkldnn_placement_pass);
#ifdef PADDLE_WITH_NGRAPH
USE_PASS(ngraph_subgraph_pass);
#endif
#ifdef PADDLE_WITH_CUDA
USE_PASS(fusion_group_pass);
#endif
......@@ -86,8 +86,9 @@ struct BuildStrategy {
// Operator fusion
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle.
bool fuse_elewise_add_act_ops_{false};
bool fuse_bn_act_ops_{false};
bool fuse_elewise_add_act_ops_{false};
bool enable_auto_fusion_{false};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// should not be sparse types
boost::optional<bool> fuse_all_optimizer_ops_{false};
......
......@@ -6,7 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
if(NOT APPLE AND NOT WIN32)
if(NOT APPLE AND NOT WIN32 AND WITH_GPU)
add_subdirectory(fusion_group)
endif()
......
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
cc_library(code_generator
SRCS operation.cc code_generator.cc code_generator_helper.cc
DEPS graph subgraph_detector)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
endif()
cc_library(fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS graph_pattern_detector pass code_generator)
DEPS subgraph_detector fuse_pass_base code_generator device_code)
cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass)
......@@ -33,7 +33,7 @@ CodeGenerator::CodeGenerator() {
std::string CodeGenerator::Generate(SubGraph* subgraph) {
std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
return Generate(subgraph->func_name, expressions);
return Generate(subgraph->GetFuncName(), expressions);
}
static bool HasInput(Node* n, std::string name) {
......
......@@ -227,7 +227,7 @@ std::vector<fusion_group::OperationExpression> TestMain(
std::string code_str = code_generator.Generate(subgraph);
VLOG(3) << code_str;
TestMainImpl(subgraph->func_name, code_str, cpu_tensors, n, input_ids,
TestMainImpl(subgraph->GetFuncName(), code_str, cpu_tensors, n, input_ids,
output_ids);
// Need to check the accuracy according to expressions.
......
......@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
namespace paddle {
namespace framework {
......@@ -26,20 +29,22 @@ static std::unordered_set<std::string> unary_op_types;
static std::unordered_set<std::string>& GetBinaryOpTypes() {
if (binary_op_types.empty()) {
binary_op_types = OperationMap::Instance().Find(0, 2);
binary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
}
return binary_op_types;
}
static std::unordered_set<std::string>& GetUnaryOpTypes() {
if (unary_op_types.empty()) {
unary_op_types = OperationMap::Instance().Find(0, 1);
unary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
}
return unary_op_types;
}
static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
Node* n) {
const Node* n) {
if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) {
auto iter = op_types.find(n->Op()->Type());
if (iter != op_types.end()) {
......@@ -49,114 +54,63 @@ static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
return false;
}
static bool IsBinaryOp(Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n) && n->inputs.size() == 2U) {
auto* x = n->inputs[0];
auto* y = n->inputs[1];
static bool IsGradOp(const Node* n) {
PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
platform::errors::InvalidArgument(
"Expected node %p to be an operator node.", n));
std::string suffix = "_grad";
std::string op_type = n->Op()->Type();
size_t pos = op_type.rfind(suffix);
return pos != std::string::npos &&
pos == (op_type.length() - suffix.length());
}
std::vector<int64_t> x_shape;
std::vector<int64_t> y_shape;
if (x && x->IsVar() && x->Var()) {
x_shape = x->Var()->GetShape();
}
if (y && y->IsVar() && y->Var()) {
y_shape = y->Var()->GetShape();
}
if (x_shape.size() == 0U || x_shape.size() != y_shape.size()) {
static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
const std::vector<int64_t>& r) {
return l.size() != 0U && r.size() != 0U && l == r;
}
static bool IsBinaryOp(const Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n)) {
if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
return false;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] != y_shape[i]) {
// The shape of all inputs should be the same.
std::vector<int64_t> shape_0;
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto* in_i = n->inputs[i];
if (!(in_i && in_i->IsVar() && in_i->Var())) {
return false;
}
}
return true;
}
return false;
}
static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(GetUnaryOpTypes(), n); }
bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}
bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n,
std::string name) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->outputs) {
if (IsElementwiseOp(op)) {
if (name.empty()) {
return true;
} else if (IsNthInput(n, op, name, 0)) {
return true;
std::vector<int64_t> shape_i = in_i->Var()->GetShape();
if (i == 0U) {
shape_0 = shape_i;
} else {
if (!IsEqualAndNotEmpty(shape_0, shape_i)) {
return false;
}
}
}
return true;
}
return false;
}
bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op)) {
return true;
}
}
}
return false;
static bool IsUnaryOp(const Node* n) {
return IsSpecifiedOp(GetUnaryOpTypes(), n);
}
int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
std::unordered_set<Node*> except_nodes_set;
for (size_t i = 0; i < except_nodes.size(); ++i) {
except_nodes_set.insert(except_nodes[i]);
}
int num_operations = 0;
if (IsElementwiseOp(n)) {
subgraph_.Insert(n);
num_operations += 1;
for (auto* var : n->inputs) {
subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
for (auto* var : n->outputs) {
subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
} else if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
for (auto* op : n->outputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
}
return num_operations;
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}
int ElementwiseGroupDetector::operator()(Node* n) {
if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
name_ = n->Name();
subgraph_.Insert(n);
num_operations_ = Search(n, n->inputs);
VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", "
<< num_operations_ << " operations, " << GetSubgraph().GetNumNodes()
<< " nodes";
}
return num_operations_;
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
Graph* graph) {
auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
return SubgraphDetector(graph, teller)();
}
} // namespace fusion_group
......
......@@ -14,10 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
......@@ -27,21 +25,10 @@ namespace fusion_group {
class ElementwiseGroupDetector {
public:
int operator()(Node* n);
SubGraph GetSubgraph() const { return subgraph_; }
private:
bool IsElementwiseOp(Node* n);
bool IsInputOfElementwiseOp(Node* n, std::string name = "");
bool IsOutputOfElementwiseOp(Node* n);
int Search(Node* n, std::vector<Node*> except_nodes = {});
std::vector<std::vector<Node*>> operator()(Graph* graph);
private:
std::string name_;
int num_operations_{0};
SubGraph subgraph_;
bool IsElementwiseOp(const Node* n);
};
} // namespace fusion_group
......
......@@ -13,57 +13,88 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h"
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/platform/device_code.h"
namespace paddle {
namespace framework {
namespace ir {
void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
int num_elementwise_groups = DetectFusionGroup(graph, 0);
LOG(INFO) << "Detect " << num_elementwise_groups
FusePassBase::Init("fusion_group_pass", graph);
if (Get<bool>("use_gpu")) {
fusion_group::OperationMap::Init();
int num_elementwise_groups = DetectFusionGroup(graph, 0);
VLOG(3) << "Detect " << num_elementwise_groups
<< " elementwise fusion groups.";
}
}
int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
std::vector<fusion_group::SubGraph> subgraphs;
std::unordered_set<Node*> all_nodes = graph->Nodes();
for (Node* n : all_nodes) {
bool is_found = false;
for (auto& subgraph : subgraphs) {
if (subgraph.Has(n)) {
is_found = true;
break;
}
}
if (is_found) {
continue;
// TODO(liuyiqun): support different places
platform::CUDAPlace place = platform::CUDAPlace(0);
int index = platform::DeviceCodePool::Init({place}).size(place);
std::vector<std::vector<Node*>> subgraphs =
fusion_group::ElementwiseGroupDetector()(graph);
int num_subgraphs = 0;
size_t min_subgraph_size = 2;
bool save_intermediate_out = true;
for (auto& vec : subgraphs) {
if (vec.size() >= min_subgraph_size) {
std::string func_name = "fused_elementwise_" + std::to_string(index++);
fusion_group::SubGraph subgraph(
type, func_name, save_intermediate_out,
std::unordered_set<Node*>(vec.begin(), vec.end()));
VLOG(3) << "subgraph: {\n"
<< DebugString(subgraph.SortedNodes()) << "}\n";
GenerateCode(&subgraph);
InsertFusionGroupOp(graph, &subgraph);
num_subgraphs++;
}
}
return num_subgraphs;
}
fusion_group::SubGraph subgraph;
if (type == 0) {
fusion_group::ElementwiseGroupDetector detector;
int num_operations = detector(n);
if (num_operations >= 2) {
subgraph = detector.GetSubgraph();
}
}
void FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
fusion_group::CodeGenerator code_generator;
std::string code_str = code_generator.Generate(subgraph);
VLOG(3) << code_str;
// TODO(liuyiqun): support different places
platform::CUDAPlace place = platform::CUDAPlace(0);
std::unique_ptr<platform::CUDADeviceCode> device_code(
new platform::CUDADeviceCode(place, subgraph->GetFuncName(), code_str));
device_code->Compile();
platform::DeviceCodePool& pool = platform::DeviceCodePool::Init({place});
pool.Set(std::move(device_code));
}
if (!subgraph.IsEmpty()) {
subgraphs.push_back(subgraph);
static int ExtractOpRole(fusion_group::SubGraph* subgraph) {
std::unordered_set<int> op_roles;
std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
for (auto* n : subgraph->Nodes()) {
if (n && n->IsOp() && n->Op()) {
if (n->Op()->HasAttr(attr_name)) {
op_roles.insert(boost::get<int>(n->Op()->GetAttr(attr_name)));
}
}
}
// TODO(liuyiqun): check whether there are intersection between subgraphs
for (size_t i = 0; i < subgraphs.size(); ++i) {
InsertFusionGroupOp(graph, &subgraphs[i]);
if (op_roles.size() == 1U) {
return *(op_roles.begin());
} else {
return static_cast<int>(OpRole::kNotSpecified);
}
return subgraphs.size();
}
void FusionGroupPass::InsertFusionGroupOp(
......@@ -90,10 +121,12 @@ void FusionGroupPass::InsertFusionGroupOp(
external_nodes.insert(n);
}
op_desc.SetOutput("Outs", output_names);
op_desc.SetAttr("type", subgraph->type);
op_desc.SetAttr("func_name", subgraph->func_name);
op_desc.SetAttr("type", subgraph->GetType());
op_desc.SetAttr("func_name", subgraph->GetFuncName());
op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(subgraph));
auto fusion_group_node = graph->CreateOpNode(&op_desc);
Node* fusion_group_node = graph->CreateOpNode(&op_desc);
for (auto* in : input_vars_of_subgraph) {
IR_NODE_LINK_TO(in, fusion_group_node);
}
......@@ -114,4 +147,5 @@ void FusionGroupPass::InsertFusionGroupOp(
} // namespace framework
} // namespace paddle
REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass);
REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass)
.RequirePassAttr("use_gpu");
......@@ -16,19 +16,20 @@ limitations under the License. */
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class FusionGroupPass : public Pass {
class FusionGroupPass : public FusePassBase {
protected:
void ApplyImpl(Graph* graph) const override;
private:
int DetectFusionGroup(Graph* graph, int type = 0) const;
void GenerateCode(fusion_group::SubGraph* subgraph) const;
void InsertFusionGroupOp(Graph* graph,
fusion_group::SubGraph* subgraph) const;
......
......@@ -138,19 +138,15 @@ int TestMain(std::unique_ptr<Graph> graph, std::string prefix) {
}
TEST(FusionGroupPass, elementwise_list) {
fusion_group::OperationMap::Init();
std::unique_ptr<Graph> graph = BuildElementwiseListGraph(false);
std::unique_ptr<Graph> graph = BuildElementwiseListGraph(true);
int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_list");
EXPECT_EQ(num_fusion_group_ops, 1);
EXPECT_EQ(num_fusion_group_ops, 2);
}
TEST(FusionGroupPass, elementwise_tree) {
fusion_group::OperationMap::Init();
std::unique_ptr<Graph> graph = BuildElementwiseTreeGraph(false);
std::unique_ptr<Graph> graph = BuildElementwiseTreeGraph(true);
int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_tree");
EXPECT_EQ(num_fusion_group_ops, 2);
EXPECT_EQ(num_fusion_group_ops, 4);
}
} // namespace ir
......
......@@ -20,48 +20,59 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
struct SubGraph {
int type{-1};
std::string func_name;
bool save_intermediate_out{false};
class SubGraph {
public:
SubGraph() = default;
SubGraph(int t, std::string f, bool s, const std::unordered_set<Node*>& n)
: type(t), func_name(f), save_intermediate_out(s), nodes_set(n) {}
explicit SubGraph(int type) : type_(type) {}
SubGraph(int type, std::string func_name, bool save_intermediate_out,
const std::unordered_set<Node*>& nodes_set)
: type_(type),
func_name_(func_name),
save_intermediate_out_(save_intermediate_out) {
for (auto* n : nodes_set) {
nodes_set_.insert(n);
if (n && n->IsOp() && n->Op()) {
// If the node is an op node, then add its input/output var nodes
// into the subgraph.
for (auto* in : n->inputs) {
nodes_set_.insert(in);
}
for (auto* out : n->outputs) {
nodes_set_.insert(out);
}
}
}
}
bool IsEmpty() { return nodes_set.empty(); }
bool IsEmpty() { return nodes_set_.empty(); }
const std::unordered_set<Node*>& Nodes() const { return nodes_set; }
int GetType() const { return type_; }
void SetFuncName(std::string func_name) { func_name_ = func_name; }
std::string GetFuncName() const { return func_name_; }
const std::unordered_set<Node*>& Nodes() const { return nodes_set_; }
const std::vector<Node*>& SortedNodes() {
if (!is_sorted) {
Sort();
if (!is_sorted_) {
TopologicalSort();
}
return sorted_nodes;
return sorted_nodes_;
}
size_t GetNumNodes() { return nodes_set.size(); }
size_t GetNumNodes() { return nodes_set_.size(); }
bool Has(Node* n) { return nodes_set.find(n) != nodes_set.end(); }
void Insert(Node* n) {
if (nodes_set.find(n) == nodes_set.end()) {
VLOG(5) << "Insert " << n->Name() << " to subgraph " << this;
nodes_set.insert(n);
is_sorted = false;
}
}
bool Has(Node* n) { return nodes_set_.find(n) != nodes_set_.end(); }
int GetNumOperations() {
int num_operations = 0;
for (auto* n : nodes_set) {
for (auto* n : nodes_set_) {
if (n && n->IsOp() && n->Op()) {
num_operations++;
}
......@@ -96,203 +107,108 @@ struct SubGraph {
std::vector<Node*> GetOutputVarNodes() {
// The order of output nodes should be consistent everywhere.
std::vector<Node*> output_vars;
std::vector<Node*> output_vars_all;
for (auto* n : SortedNodes()) {
if (n && n->IsVar() && n->Var()) {
if (save_intermediate_out) {
// If the var_node is the output of some op_node in the subgraph, it
// is considered the output var node of the subgraph.
bool is_found = false;
for (auto* in : n->inputs) {
if (Has(in)) {
is_found = true;
}
}
if (is_found) {
output_vars.push_back(n);
}
} else {
// If one of the var_node's outputs is the input of some operator
// outside the subgraph, it is considered the output var node of the
// subgraph.
bool is_found = true;
if (n->outputs.size() == 0U) {
is_found = false;
}
for (auto* out : n->outputs) {
if (!Has(out)) {
is_found = false;
}
}
if (!is_found) {
output_vars.push_back(n);
// If the var_node is the output of some op_node in the subgraph, it
// is considered the output var node of the subgraph.
bool is_found = false;
for (auto* in : n->inputs) {
if (Has(in)) {
is_found = true;
}
}
if (is_found) {
output_vars_all.push_back(n);
}
}
}
return output_vars;
}
private:
int FindIndexInSortedNodes(Node* n) {
for (size_t i = 0; i < sorted_nodes.size(); ++i) {
if (n == sorted_nodes[i]) {
return static_cast<int>(i);
}
if (save_intermediate_out_) {
return output_vars_all;
}
return -1;
}
void SortVarsBasedOnSortedOps() {
// Insert var nodes to sorted_nodes.
std::unordered_map<std::string, Node*> sorted_vars;
for (auto* n : nodes_set) {
if (n && n->IsVar() && n->Var()) {
int from = 0;
int to = sorted_nodes.size();
for (auto* in : n->inputs) {
if (in && in->IsOp() && in->Op()) {
int index = FindIndexInSortedNodes(in);
// Insert after input op node
if (index >= 0) {
from = index + 1 > from ? index + 1 : from;
}
}
}
for (auto* out : n->outputs) {
if (out && out->IsOp() && out->Op()) {
int index = FindIndexInSortedNodes(out);
// Insert before output op node
if (index >= 0) {
to = index < to ? index : to;
}
}
}
if (from > to) {
LOG(INFO) << "subgraph: {\n" << DebugString(Nodes()) << "}\n";
LOG(INFO) << "sorted nodes: {\n"
<< DebugString(sorted_nodes) << "}\n";
std::vector<Node*> output_vars_outside;
for (auto* n : output_vars_all) {
// If one of the var_node's outputs is the input of some operator
// outside the subgraph, it is considered the output var node of the
// subgraph.
bool is_found = true;
if (n->outputs.size() == 0U) {
is_found = false;
}
for (auto* out : n->outputs) {
if (!Has(out)) {
is_found = false;
}
PADDLE_ENFORCE_LE(from, to, "Range [%d, %d] is invalid.", from, to);
sorted_nodes.insert(sorted_nodes.begin() + to, n);
sorted_vars[n->Name()] = n;
}
if (!is_found) {
output_vars_outside.push_back(n);
}
}
return output_vars_outside;
}
std::vector<Node*> SortedOps() {
Node* start_op_n = nullptr;
std::unordered_set<Node*> ops;
for (auto* op_n : nodes_set) {
if (op_n && op_n->IsOp() && op_n->Op()) {
// Initialize ops to all ops in the subgraph.
ops.insert(op_n);
private:
void TopologicalSort() {
if (!is_sorted_) {
std::unordered_map<Node*, std::vector<Node*>> inputs_map;
std::unordered_map<Node*, std::vector<Node*>> outputs_map;
for (auto* n : nodes_set_) {
inputs_map[n] = n->inputs;
outputs_map[n] = n->outputs;
}
if (!start_op_n) {
// Find start op node whose inputs are produced outside the subgraph.
bool is_found = false;
for (auto* prev_op_n : GetPrevOpNodes(op_n)) {
if (Has(prev_op_n)) {
is_found = true;
break;
for (auto* n : nodes_set_) {
if (n && n->IsVar() && n->Var()) {
// Keep only inputs that are inside the subgraph, so the subgraph's input var nodes end up with no inputs.
std::vector<Node*> inputs;
for (auto* in : n->inputs) {
if (Has(in)) {
inputs.push_back(in);
}
}
if (!is_found) {
start_op_n = op_n;
// Keep only outputs that are inside the subgraph, so the subgraph's output var nodes end up with no outputs.
std::vector<Node*> outputs;
for (auto* out : n->outputs) {
if (Has(out)) {
outputs.push_back(out);
}
}
n->inputs = inputs;
n->outputs = outputs;
}
}
}
std::vector<Node*> sorted_ops;
sorted_ops.push_back(start_op_n);
ops.erase(start_op_n);
while (ops.size() > 0U) {
std::unordered_set<Node*> erased_ops;
for (auto* op_n : ops) {
bool found_connected_ops = false;
int from = 1;
int to = sorted_ops.size();
std::unordered_set<Node*> prev_op_nodes = GetPrevOpNodes(op_n);
std::unordered_set<Node*> next_op_nodes = GetNextOpNodes(op_n);
for (int i = sorted_ops.size(); i >= 0; --i) {
if (prev_op_nodes.find(sorted_ops[i]) != prev_op_nodes.end()) {
// Insert after i (i + 1)
found_connected_ops = true;
from = (i + 1 > from) ? i + 1 : from;
}
if (next_op_nodes.find(sorted_ops[i]) != next_op_nodes.end()) {
// Insert before i
found_connected_ops = true;
to = (i < to) ? i : to;
}
}
if (found_connected_ops) {
if (from > to) {
LOG(INFO) << "subgraph: {\n" << DebugString(Nodes()) << "}\n";
}
PADDLE_ENFORCE_LE(from, to, "Range [%d, %d] is invalid.", from, to);
sorted_ops.insert(sorted_ops.begin() + to, op_n);
erased_ops.insert(op_n);
// Collect the start points of the subgraph.
std::vector<Node*> start_points;
for (auto* n : nodes_set_) {
if (n->inputs.empty()) {
start_points.push_back(n);
}
}
PADDLE_ENFORCE_GT(erased_ops.size(), 0U);
for (auto* op_n : erased_ops) {
ops.erase(op_n);
// Sort the subgraph.
NodesTSIterator x(start_points);
for (auto& n : iterator_range<NodesTSIterator>(
NodesTSIterator(start_points), NodesTSIterator())) {
sorted_nodes_.push_back(&n);
}
}
return sorted_ops;
}
std::unordered_set<Node*> GetPrevOpNodes(Node* op_n) {
PADDLE_ENFORCE_EQ(op_n && op_n->IsOp() && op_n->Op(), true,
"Node %p is not a op node.", op_n);
std::unordered_set<Node*> prev_op_nodes;
for (auto* in_var : op_n->inputs) {
if (in_var && in_var->IsVar() && in_var->Var()) {
for (auto* prev_op_n : in_var->inputs) {
if (prev_op_n && prev_op_n->IsOp() && prev_op_n->Op()) {
prev_op_nodes.insert(prev_op_n);
}
}
// Restore the original inputs and outputs.
for (auto* n : nodes_set_) {
n->inputs = inputs_map[n];
n->outputs = outputs_map[n];
}
}
return prev_op_nodes;
}
std::unordered_set<Node*> GetNextOpNodes(Node* op_n) {
PADDLE_ENFORCE_EQ(op_n && op_n->IsOp() && op_n->Op(), true,
"Node %p is not a op node.", op_n);
std::unordered_set<Node*> next_op_nodes;
for (auto* out_var : op_n->outputs) {
if (out_var && out_var->IsVar() && out_var->Var()) {
for (auto* next_op_n : out_var->outputs) {
if (next_op_n && next_op_n->IsOp() && next_op_n->Op()) {
next_op_nodes.insert(next_op_n);
}
}
}
}
return next_op_nodes;
}
void Sort() {
if (!is_sorted) {
sorted_nodes = SortedOps();
SortVarsBasedOnSortedOps();
}
is_sorted = true;
is_sorted_ = true;
}
private:
std::unordered_set<Node*> nodes_set;
bool is_sorted{false};
std::vector<Node*> sorted_nodes;
int type_{-1};
std::string func_name_;
bool save_intermediate_out_{true};
std::unordered_set<Node*> nodes_set_;
bool is_sorted_{false};
std::vector<Node*> sorted_nodes_;
};
} // namespace fusion_group
......
......@@ -9,9 +9,11 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
......@@ -586,17 +588,18 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
int out_chw = c * out_hw;
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeNearestNeighborInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeBilinearInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
}
......@@ -696,12 +699,13 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
int out_cdhw = c * out_dhw;
int pixelNum = n * out_cdhw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("trilinear" == interp_method) {
KeTrilinearInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeTrilinearInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout);
......@@ -787,17 +791,18 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
int out_chw = c * out_hw;
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeNearestNeighborInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeBilinearInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
data_layout);
......@@ -892,12 +897,13 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
int out_cdhw = c * out_dhw;
int pixelNum = n * out_cdhw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("trilinear" == interp_method) {
KeTrilinearInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeTrilinearInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout);
......
......@@ -50,11 +50,11 @@ class LoDTensor2BatchFunctor {
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
//
struct SeqInfo {
SeqInfo(int start, int length, int seq_idx)
SeqInfo(size_t start, size_t length, size_t seq_idx)
: start(start), length(length), seq_idx(seq_idx) {}
int start;
int length;
int seq_idx;
size_t start;
size_t length;
size_t seq_idx;
};
public:
......@@ -82,7 +82,7 @@ class LoDTensor2BatchFunctor {
std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
size_t length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
......@@ -118,8 +118,8 @@ class LoDTensor2BatchFunctor {
batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
int max_seqlen = seq_info[0].length;
batch_lods[0].resize(static_cast<size_t>(max_seqlen + 1));
size_t max_seqlen = seq_info[0].length;
batch_lods[0].resize(max_seqlen + 1);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
// batch_lods[2] is the sort order for the input LoDTensor.
......@@ -128,11 +128,11 @@ class LoDTensor2BatchFunctor {
size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (int n = 0; n < max_seqlen; n++) {
auto batch_id = static_cast<int>(batch_starts[n]);
for (size_t n = 0; n < max_seqlen; n++) {
size_t batch_id = batch_starts[n];
for (size_t i = 0; i < seq_info.size(); ++i) {
int seq_len = seq_info[i].length;
int start = seq_info[i].start;
size_t seq_len = seq_info[i].length;
size_t start = seq_info[i].start;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
......@@ -141,7 +141,7 @@ class LoDTensor2BatchFunctor {
break;
}
}
batch_starts[n + 1] = static_cast<size_t>(batch_id);
batch_starts[n + 1] = batch_id;
}
size_t* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
......
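For readers following the indexing logic, the hunks above compute a sequence-to-batch mapping. Below is a minimal Python sketch of that computation (the helper name and example LoD are illustrative, not part of the patch):

# Sketch of LoDTensor2BatchFunctor's seq->batch index computation (illustrative).
def to_batch_index(lod, is_reverse=False):
    # seq_info entries are (start, length, seq_idx), sorted by length, descending.
    seq_info = sorted(
        [(lod[i], lod[i + 1] - lod[i], i) for i in range(len(lod) - 1)],
        key=lambda s: s[1], reverse=True)
    max_seqlen = seq_info[0][1]
    batch_starts = [0] * (max_seqlen + 1)
    seq2batch_idx = [0] * lod[-1]
    for n in range(max_seqlen):
        batch_id = batch_starts[n]
        for start, seq_len, _ in seq_info:
            if n < seq_len:
                seq2batch_idx[batch_id] = (
                    start + seq_len - 1 - n if is_reverse else start + n)
                batch_id += 1
            else:
                break
        batch_starts[n + 1] = batch_id
    return batch_starts, seq2batch_idx

# Example matching the comment above: lod = [0, 4, 9, 12]
# gives seq_info = [(4, 5, 1), (0, 4, 0), (9, 3, 2)].
print(to_batch_index([0, 4, 9, 12]))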
......@@ -29,10 +29,11 @@ inline std::vector<int> get_new_shape(
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
"ShapeError: If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor->dims());
platform::errors::InvalidArgument(
"If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor->dims()));
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
......@@ -64,10 +65,11 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto ShapeTensor = ctx->Inputs("ShapeTensor");
PADDLE_ENFORCE_GT(
ShapeTensor.size(), 0,
"ShapeError: When `shape` in ReshapeOp is a list or tuple "
"which contains Tensor, the shape's size can't be zero. "
"But received shape's size is %d.",
ShapeTensor.size());
platform::errors::InvalidArgument(
"When `shape` in ReshapeOp is a list or tuple "
"which contains Tensor, the shape's size can't be zero. "
"But received shape's size is %d.",
ShapeTensor.size()));
auto infer_shape = ctx->Attrs().Get<std::vector<int>>("shape");
const int64_t copy_dim_val = 0;
auto in_dims = ctx->GetInputDim("X");
......@@ -75,10 +77,11 @@ class ReshapeOp : public framework::OperatorWithKernel {
if (infer_shape[i] == copy_dim_val) {
PADDLE_ENFORCE_LT(
static_cast<int>(i), in_dims.size(),
"ShapeError: The index of 0 in `shape` must be less than "
"the input tensor X's dimensions. But received shape[%d] "
"= 0, X's dimensions = %d, X's shape = [%s].",
i, in_dims.size(), in_dims);
platform::errors::InvalidArgument(
"The index of 0 in `shape` must be less than "
"the input tensor X's dimensions. But received shape[%d] "
"= 0, X's dimensions = %d, X's shape = [%s].",
i, in_dims.size(), in_dims));
infer_shape[i] = in_dims[i];
}
}
......@@ -108,10 +111,10 @@ class ReshapeOp : public framework::OperatorWithKernel {
return;
}
PADDLE_ENFORCE_EQ(
!shape.empty(), true,
"ShapeError: The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty.");
PADDLE_ENFORCE_EQ(!shape.empty(), true,
platform::errors::InvalidArgument(
"The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty."));
auto x_dims = ctx->GetInputDim("X");
auto out_dims = ValidateShape(shape, x_dims);
ctx->SetOutputDim("Out", out_dims);
......@@ -140,25 +143,28 @@ class ReshapeOp : public framework::OperatorWithKernel {
if (shape[i] == unk_dim_val) {
PADDLE_ENFORCE_EQ(
unk_dim_idx, -1,
"ShapeError: Only one dimension value of 'shape' in ReshapeOp can "
"be -1. But received shape = [%s], shape[%d] is also -1.",
framework::make_ddim(shape), i);
platform::errors::InvalidArgument(
"Only one dimension value of 'shape' in ReshapeOp can "
"be -1. But received shape = [%s], shape[%d] is also -1.",
framework::make_ddim(shape), i));
unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) {
PADDLE_ENFORCE_LT(
static_cast<int>(i), in_dims.size(),
"ShapeError: The index of 0 in `shape` must be less than "
"the input tensor X's dimensions. "
"But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
"X's dimensions = %d.",
framework::make_ddim(shape), i, in_dims, in_dims.size());
platform::errors::InvalidArgument(
"The index of 0 in `shape` must be less than "
"the input tensor X's dimensions. "
"But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
"X's dimensions = %d.",
framework::make_ddim(shape), i, in_dims, in_dims.size()));
} else {
PADDLE_ENFORCE_GT(
shape[i], 0,
"ShapeError: Each dimension value of 'shape' in ReshapeOp must not "
"be negtive except one unknown dimension. "
"But received shape = [%s], shape[%d] = %d.",
framework::make_ddim(shape), i, shape[i]);
platform::errors::InvalidArgument(
"Each dimension value of 'shape' in ReshapeOp must not "
"be negtive except one unknown dimension. "
"But received shape = [%s], shape[%d] = %d.",
framework::make_ddim(shape), i, shape[i]));
}
capacity *= (shape[i] ? shape[i] : in_dims[i]);
......@@ -180,8 +186,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
"The input tensor X'size must be divisible by known "
"capacity of 'shape'. "
"But received X's shape = [%s], X's size = %d, "
"'shape' is [%s], known "
"capacity of 'shape' is %d.",
"'shape' is [%s], known capacity of 'shape' is %d.",
in_dims, in_size, framework::make_ddim(shape), capacity));
} else {
output_shape[unk_dim_idx] = -1;
......@@ -190,12 +195,13 @@ class ReshapeOp : public framework::OperatorWithKernel {
if (all_positive) {
PADDLE_ENFORCE_EQ(
capacity, in_size,
"ShapeError: The 'shape' in ReshapeOp is invalid. "
"The input tensor X'size must be equal to the capacity of 'shape'. "
"But received X's shape = [%s], X's size = %d, 'shape' is [%s], "
"the "
"capacity of 'shape' is %d.",
in_dims, in_size, framework::make_ddim(shape), capacity);
platform::errors::InvalidArgument(
"The 'shape' in ReshapeOp is invalid. "
"The input tensor X'size must be equal to the capacity of "
"'shape'. "
"But received X's shape = [%s], X's size = %d, 'shape' is "
"[%s], the capacity of 'shape' is %d.",
in_dims, in_size, framework::make_ddim(shape), capacity));
}
}
return framework::make_ddim(output_shape);
......
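The messages above describe the shape rules that ReshapeOp enforces. A minimal Python sketch of those rules (illustrative only; the real logic is ValidateShape in reshape_op.cc):

# Rules: at most one -1 (inferred), 0 copies the input dimension at the same
# index, every other value must be positive, and the capacity must match the
# input size. Illustrative sketch, not the actual implementation.
def validate_shape(shape, in_dims):
    in_size = 1
    for d in in_dims:
        in_size *= d
    out, unk_idx, capacity = [], -1, 1
    for i, s in enumerate(shape):
        if s == -1:
            assert unk_idx == -1, "only one dimension of 'shape' can be -1"
            unk_idx = i
            out.append(-1)
        elif s == 0:
            assert i < len(in_dims), "the index of 0 must be less than X's rank"
            out.append(in_dims[i])
            capacity *= in_dims[i]
        else:
            assert s > 0, "dimensions must be positive except the single -1"
            out.append(s)
            capacity *= s
    if unk_idx >= 0:
        assert in_size % capacity == 0, "X's size must be divisible by the known capacity"
        out[unk_idx] = in_size // capacity
    else:
        assert capacity == in_size, "the capacity of 'shape' must equal X's size"
    return out

print(validate_shape([0, -1, 4], [2, 3, 8]))  # [2, 6, 4]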
......@@ -90,4 +90,5 @@ REGISTER_OP_CPU_KERNEL(
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace platform {
struct GpuLaunchConfig {
// Number of threads per block.
int threads;
// Number of blocks for GPU kernel launch.
int blocks;
GpuLaunchConfig(int threads, int blocks) : threads(threads), blocks(blocks) {}
};
inline GpuLaunchConfig getGpuLaunchConfig(
const int N, const framework::ExecutionContext& ctx) {
int threads =
std::min(1024, ctx.cuda_device_context().GetMaxThreadsPerBlock());
int physical_thread_count =
std::min(ctx.cuda_device_context().GetMaxPhysicalThreadCount(), N);
int blocks = std::min((physical_thread_count + threads - 1) / threads,
ctx.cuda_device_context().GetSMCount());
GpuLaunchConfig config(threads, blocks);
return config;
}
} // namespace platform
} // namespace paddle
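For intuition, the grid/block computation in getGpuLaunchConfig reduces to the arithmetic below (a Python sketch with assumed device limits; the real values come from the CUDADeviceContext):

# Sketch of getGpuLaunchConfig's arithmetic; the device limits are assumptions.
def gpu_launch_config(n, max_threads_per_block=1024,
                      max_physical_threads=2048 * 80, sm_count=80):
    threads = min(1024, max_threads_per_block)
    physical_thread_count = min(max_physical_threads, n)
    blocks = min((physical_thread_count + threads - 1) // threads, sm_count)
    return threads, blocks

# e.g. an NCHW output of shape [8, 3, 224, 224] gives n = 8 * 3 * 224 * 224.
print(gpu_launch_config(8 * 3 * 224 * 224))  # (1024, 80) on the assumed device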
......@@ -1984,6 +1984,27 @@ All parameter, weight, gradient are variables in Paddle.
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_bn_act_ops = True
)DOC")
.def_property(
"enable_auto_fusion",
[](const BuildStrategy &self) { return self.enable_auto_fusion_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
platform::errors::PreconditionNotMet(
"BuildStrategy is finlaized."));
self.enable_auto_fusion_ = b;
},
R"DOC((bool, optional): Whether to enable fusing subgraph to a
fusion_group. Now we only support fusing subgraph that composed
of elementwise-like operators, such as elementwise_add/mul
without broadcast and activations.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.enable_auto_fusion = True
)DOC")
.def_property(
"fuse_relu_depthwise_conv",
[](const BuildStrategy &self) {
......
......@@ -1110,6 +1110,26 @@ def _get_son_parent_block_idx_dict(program, current_block_idx):
return son_parent_block_idx_dict
def _get_no_grad_set_name(no_grad_set):
no_grad_set_name = set()
if no_grad_set is not None:
if isinstance(no_grad_set, (set, list, tuple)):
for i, no_grad_var in enumerate(no_grad_set):
if isinstance(no_grad_var, framework.Variable):
no_grad_set_name.add(no_grad_var.name)
elif isinstance(no_grad_var, six.string_types):
no_grad_set_name.add(no_grad_var)
else:
raise TypeError(
"The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s."
% (type(no_grad_var)))
else:
raise TypeError(
"The type of no_grad_set should be set or list or tuple, but received {}".
format(type(no_grad_set)))
return no_grad_set_name
def append_backward(loss,
parameter_list=None,
no_grad_set=None,
......@@ -1133,11 +1153,11 @@ def append_backward(loss,
If it is None, all parameters
will be updated.
Default: None.
no_grad_set(set[str], optional): Variable names in the :ref:`api_guide_Block_en` 0 whose gradients
no_grad_set(set[Variable|str], optional): Set of Variables or Variable.names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All variables with
`stop_gradient=True` from all blocks will
be automatically added into this set.
If this parameter is not None, the names in this set will be added to the default set.
If this parameter is not None, the Variables or Variable.names in this set will be added to the default set.
Default: None.
callbacks(list[callable object], optional): List of callback functions.
The callbacks are used for
......@@ -1174,18 +1194,40 @@ def append_backward(loss,
.. code-block:: python
import paddle.fluid as fluid
x = fluid.data(name='x', shape=[None, 13], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
x = fluid.data(name='x', shape=[None, 13], dtype='int64')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
x_emb = fluid.embedding(x, size=[100, 256])
y_predict = fluid.layers.fc(input=x_emb, size=1, act=None, name='my_fc')
loss = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(loss)
param_grad_list = fluid.backward.append_backward(loss=avg_loss)
p_g_list1 = fluid.backward.append_backward(loss=avg_loss) # len(p_g_list1) == 2
p_g_list2 = fluid.backward.append_backward(loss=avg_loss, parameter_list=[p_g_list1[0][0].name]) # len(p_g_list1) == 1
p_g_list3 = fluid.backward.append_backward(loss=avg_loss, no_grad_set=set([p_g_list1[0][0].name])) # len(p_g_list1) == 1
p_g_list4 = fluid.backward.append_backward(loss=avg_loss, parameter_list=[p_g_list1[0][0].name], no_grad_set=set([p_g_list1[0][0].name])) # len(p_g_list1) == 0
# Get all weights in main_program, excluding biases.
all_weights = [param for param in fluid.default_main_program().block(0).all_parameters() if 'w_' in param.name]
all_weights_name = [w.name for w in all_weights]
# Return all param_grads that need to be updated when parameter_list is left as the default None.
p_g_list1 = fluid.backward.append_backward(loss=avg_loss)
# output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)]
# Return the param_grads corresponding to parameter_list, which can be a list of parameters (Variable).
p_g_list2 = fluid.backward.append_backward(loss=avg_loss, parameter_list=all_weights)
# output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]
# parameter_list can also be a list of parameter names (str).
p_g_list3 = fluid.backward.append_backward(loss=avg_loss, parameter_list=all_weights_name)
# output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]
# no_grad_set can be a set of Variables, meaning gradients are cut off from these Variables.
p_g_list4 = fluid.backward.append_backward(loss=avg_loss, no_grad_set=set([x_emb]))
# output: [(my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)]
# no_grad_set can be a set of Variable names when the Variable is created inside a layer and can't be referenced explicitly.
p_g_list5 = fluid.backward.append_backward(loss=avg_loss, no_grad_set=set(['my_fc.b_0']))
# output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]
# Returns [] because all param_grads are filtered out by no_grad_set.
p_g_list6 = fluid.backward.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))
"""
assert isinstance(loss, framework.Variable)
......@@ -1215,7 +1257,8 @@ def append_backward(loss,
if no_grad_set is None:
no_grad_set = set()
no_grad_set = copy.copy(no_grad_set)
else:
no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set))
no_grad_dict = _get_stop_gradients_(program)
# no_grad_set only contains vars in block 0
# Todo(liym27): support vars in sub block
......@@ -1501,12 +1544,15 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
Args:
targets(Variable|list[Variable]): The target variables
inputs(Variable|list[Variable]): The input variables
target_gradients (Variable|list[Variable]|None): The gradient variables
target_gradients (Variable|list[Variable], optional): The gradient variables
of targets which has the same shape with targets, If None, ones will
be created for them.
no_grad_set(set[string]): The names of variables that have no gradients
in Block 0. All variables with `stop_gradient=True` from all blocks
will be automatically added.
no_grad_set(set[Variable|str], optional): Set of Variables or Variable.names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All variables with
`stop_gradient=True` from all blocks will
be automatically added into this set.
If this parameter is not None, the Variables or Variable.names in this set will be added to the default set.
Default: None.
Return:
(list[Variable]): A list of gradients for inputs
......@@ -1532,7 +1578,8 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
if no_grad_set is None:
no_grad_set = set()
no_grad_set = copy.copy(no_grad_set)
else:
no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set))
no_grad_dict = _get_stop_gradients_(prog)
no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set)))
......@@ -1623,12 +1670,13 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
Args:
targets (Variable|list[Variable]): The target variables.
inputs (Variable|list[Variable]): The input variables.
target_gradients (Variable|list[Variable]|None): The gradient variables
target_gradients (Variable|list[Variable], optional): The gradient variables
of targets which has the same shape with targets, If None, ones will
be created for them.
no_grad_set (set[string]): The names of variables that have no gradients
in Block 0. All variables with `stop_gradient=True` from all blocks
will be automatically added.
no_grad_set (set[Variable|str], optional): Set of Variables or Variable.names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All variables with `stop_gradient=True` from all blocks will
be automatically added into this set. If this parameter is not None, the Variables or Variable.names
in this set will be added to the default set. Default: None.
Return:
(list[Variable]): A list of gradients for inputs
......@@ -1640,7 +1688,7 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[2,8,8], dtype='float32')
x = fluid.data(name='x', shape=[None,2,8,8], dtype='float32')
x.stop_gradient=False
y = fluid.layers.conv2d(x, 4, 1, bias_attr=False)
y = fluid.layers.relu(y)
......
......@@ -16,10 +16,10 @@ import os
import re
import logging
import numpy as np
from ....executor import global_scope
from .... import io
from .... import core
from .... import framework
from ....executor import global_scope, Executor
from ....framework import IrGraph
from ....log_helper import get_logger
from .quantization_pass import QuantizationTransformPass
......@@ -27,12 +27,31 @@ from .quantization_pass import QuantizationFreezePass
from .quantization_pass import AddQuantDequantPass
from .quantization_pass import _op_real_in_out_name
__all__ = ['PostTrainingQuantization']
__all__ = ['PostTrainingQuantization', 'WeightQuantization']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
def _load_variable_data(scope, var_name):
'''
Load variable value from scope
'''
return np.array(scope.find_var(var_name).get_tensor())
def _set_variable_data(scope, place, var_name, np_value):
'''
Set the value of a variable in the scope by name, if the variable exists.
'''
assert isinstance(np_value, np.ndarray), \
'The type of value should be numpy array.'
var_node = scope.find_var(var_name)
if var_node != None:
tensor = var_node.get_tensor()
tensor.set(np_value, place)
class PostTrainingQuantization(object):
def __init__(self,
executor,
......@@ -297,12 +316,12 @@ class PostTrainingQuantization(object):
'''
for var_name in self._quantized_weight_var_name:
if var_name not in self._sampling_data:
var_tensor = self._load_var_value(var_name)
var_tensor = _load_variable_data(self._scope, var_name)
self._sampling_data[var_name] = var_tensor
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
var_tensor = self._load_var_value(var_name)
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.ravel()
save_path = os.path.join(self._cache_dir,
var_name + "_" + str(iter) + ".npy")
......@@ -311,7 +330,7 @@ class PostTrainingQuantization(object):
for var_name in self._quantized_act_var_name:
if var_name not in self._sampling_data:
self._sampling_data[var_name] = []
var_tensor = self._load_var_value(var_name)
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.ravel()
self._sampling_data[var_name].append(var_tensor)
......@@ -397,11 +416,17 @@ class PostTrainingQuantization(object):
# save scale factor to scale var node
for key, val in self._quantized_var_scale_factor.items():
self._set_var_node_value(
key + ".scale", np.array(
_set_variable_data(
self._scope,
self._place,
key + ".scale",
np.array(
[val], dtype=np.float32))
self._set_var_node_value(
key + ".quant_dequant.scale", np.array(
_set_variable_data(
self._scope,
self._place,
key + ".quant_dequant.scale",
np.array(
[val], dtype=np.float32))
# apply QuantizationFreezePass, and obtain the final quant model
......@@ -430,23 +455,6 @@ class PostTrainingQuantization(object):
self._quantized_var_scale_factor[
output_var_name])
def _load_var_value(self, var_name):
'''
Load variable value from scope
'''
return np.array(self._scope.find_var(var_name).get_tensor())
def _set_var_node_value(self, var_node_name, np_value):
'''
Set the value of var node by name, if the node exits,
'''
assert isinstance(np_value, np.ndarray), \
'The type of value should be numpy array.'
var_node = self._scope.find_var(var_node_name)
if var_node != None:
tensor = var_node.get_tensor()
tensor.set(np_value, self._place)
def _is_input_all_not_persistable(self, op, persistable_var_names):
'''
Analyze whether the real inputs of the op are all non-persistable.
......@@ -566,3 +574,132 @@ class PostTrainingQuantization(object):
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
class WeightQuantization(object):
_supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
def __init__(self, model_dir, model_filename=None, params_filename=None):
'''
This class quantizes the weights of some ops to reduce the size of the model
or improve performance.
Args:
model_dir(str): The path of the fp32 model that will be quantized,
and the model and params files are under the path.
model_filename(str, optional): The name of file to load the inference
program. If it is None, the default filename '__model__' will
be used. Default is 'None'.
params_filename(str, optional): The name of file to load all parameters.
When all parameters were saved in a single binary file, set it
as the real filename. If parameters were saved in separate files,
set it as 'None'. Default is 'None'.
'''
self._model_dir = model_dir
self._model_filename = model_filename
self._params_filename = params_filename
def quantize_weight_to_int(self,
save_model_dir,
save_model_filename=None,
save_params_filename=None,
quantizable_op_type=["conv2d", "mul"],
quantize_weight_bits=8,
threshold_rate=0.0):
'''
To reduce the size of the model, this API quantizes the weights
of some ops from float32 to int8/16. In the inference stage, the
quantized weights will be dequantized back to float32.
Args:
save_model_dir(str): The path to save the quantized model.
save_model_filename(str, optional): The name of file to
save the inference program. If it is None, the default
filename '__model__' will be used. Default is 'None'.
save_params_filename(str, optional): The name of file to
save all parameters. If it is None, parameters will be
saved in separate files. If it is not None, all
parameters will be saved in a single binary file.
quantizable_op_type(list[str], optional): The list of ops
that will be quantized, and the quantized ops should be
contained in ["conv2d", "depthwise_conv2d", "mul"].
Default is ["conv2d","mul"].
quantize_weight_bits(int, optional): The bits for the quantized
weight, and it should be 8 or 16. Default is 8.
threshold_rate(float, optional): This API uses the abs_max method to
quantize the weights from float32 to int8/16, and the abs_max
value is important for the quantization error. When the abs_max
value is far away from the center of the numerical distribution,
setting threshold_rate between 1e-6 and 1e-8 clips the abs_max
value to a better threshold. Default is 0.0.
'''
for op_type in quantizable_op_type:
assert op_type in self._supported_quantizable_op_type, \
"input error:" + op_type + \
" is not supported for weight quantization."
assert quantize_weight_bits in [8, 16], \
"input error: quantize_weight_bits should be 8 or 16."
quantize_range = (1 << (quantize_weight_bits - 1)) - 1
save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16
place = core.CPUPlace()
exe = Executor(place)
scope = global_scope()
[program, feed_list, fetch_list] = \
io.load_inference_model(dirname=self._model_dir,
executor=exe,
model_filename=self._model_filename,
params_filename=self._params_filename)
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
for op in program.global_block().ops:
if op.type in quantizable_op_type:
for var_name in op.input_arg_names:
if var_name in persistable_var_names:
var_tensor_data = _load_variable_data(scope, var_name)
if abs(threshold_rate) < 1e-10:
threshold_value = np.max(np.abs(var_tensor_data))
else:
threshold_value = self._calculate_threshold(\
var_tensor_data, threshold_rate)
var_tensor_data[var_tensor_data >
threshold_value] = threshold_value
var_tensor_data[var_tensor_data <
-threshold_value] = -threshold_value
scale = threshold_value / quantize_range
quantized_var_tensor_data = \
np.around(var_tensor_data / scale)
quantized_var_tensor_data = \
quantized_var_tensor_data.astype(save_weight_dtype)
_set_variable_data(scope, place, var_name,
quantized_var_tensor_data)
op._set_attr(var_name + "_quant_scale", [scale])
op._set_attr('quantize_weight_bits',
quantize_weight_bits)
io.save_inference_model(
dirname=save_model_dir,
feeded_var_names=feed_list,
target_vars=fetch_list,
executor=exe,
main_program=program,
model_filename=save_model_filename,
params_filename=save_params_filename)
def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000):
input_abs = np.abs(input)
hist, hist_edeges = np.histogram(
input_abs, bins=histogram_bins, range=(0, np.max(input_abs)))
hist = hist / float(sum(hist))
hist_sum = 0
hist_index = 0
for i in range(len(hist)):
hist_sum += hist[i]
if hist_sum >= 1.0 - threshold_rate:
hist_index = i + 1
break
bin_width = hist_edeges[1] - hist_edeges[0]
return hist_index * bin_width
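Taken together, the code above clips each quantizable weight tensor to [-threshold, threshold], stores it as round(w / scale) with scale = threshold / ((1 << (bits - 1)) - 1), and records the scale on the op via the "_quant_scale" attribute so inference can dequantize. A minimal usage sketch of this API (the directory names are hypothetical placeholders for a saved fp32 inference model and the output path):

from paddle.fluid.contrib.slim.quantization import WeightQuantization

# 'mobilenet_fp32_model' and 'mobilenet_w8_model' are placeholder directories.
weight_quant = WeightQuantization(model_dir='mobilenet_fp32_model')
weight_quant.quantize_weight_to_int(
    save_model_dir='mobilenet_w8_model',
    quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
    quantize_weight_bits=8,
    threshold_rate=0.0)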
......@@ -63,6 +63,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
endif()
# int8 image classification python api test
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import time
from paddle.dataset.common import download, DATA_HOME
from paddle.fluid.contrib.slim.quantization import WeightQuantization
class TestWeightQuantization(unittest.TestCase):
def setUp(self):
self.weight_quantization_dir = 'weight_quantization'
self.cache_folder = os.path.join(DATA_HOME,
self.weight_quantization_dir)
def download_model(self, model_name, data_url, data_md5):
download(data_url, self.weight_quantization_dir, data_md5)
file_name = data_url.split('/')[-1]
file_path = os.path.join(self.cache_folder, file_name)
print(model_name + ' is downloaded at ' + file_path)
unzipped_path = os.path.join(self.cache_folder, model_name)
self.cache_unzipping(unzipped_path, file_path)
print(model_name + ' is unzipped at ' + unzipped_path)
return unzipped_path
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder,
zip_path)
os.system(cmd)
def run_test(self, model_name, model_data_url, model_data_md5,
quantize_weight_bits, quantizable_op_type, threshold_rate):
model_dir = self.download_model(model_name, model_data_url,
model_data_md5)
timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
save_model_dir = os.path.join(
os.getcwd(),
model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)
weight_quant = WeightQuantization(model_dir=model_dir + "/model")
weight_quant.quantize_weight_to_int(
save_model_dir=save_model_dir,
quantize_weight_bits=quantize_weight_bits,
quantizable_op_type=quantizable_op_type,
threshold_rate=threshold_rate)
print("finish weight quantization for " + model_name + "\n")
try:
os.system("rm -rf {}".format(save_model_dir))
except Exception as e:
print("Failed to delete {} due to {}".format(save_model_dir, str(
e)))
class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
model_name = "mobilenetv1"
model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
def test_weight_quantization_mobilenetv1_8bit(self):
quantize_weight_bits = 8
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
threshold_rate = 0.0
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
quantize_weight_bits, quantizable_op_type, threshold_rate)
def test_weight_quantization_mobilenetv1_16bit(self):
quantize_weight_bits = 16
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
threshold_rate = 1e-9
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
quantize_weight_bits, quantizable_op_type, threshold_rate)
if __name__ == '__main__':
unittest.main()
......@@ -23,7 +23,7 @@ from paddle.fluid.framework import Program, Variable, name_scope, default_main_p
from . import framework
from . import layers
from . import unique_name
from .backward import append_backward, _some_in_set_, _append_grad_suffix_
from .backward import append_backward, _some_in_set_, _append_grad_suffix_, _get_no_grad_set_name
from .clip import append_gradient_clip_ops, error_clip_callback
from .framework import program_guard
from .initializer import Constant
......@@ -592,7 +592,7 @@ class Optimizer(object):
parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
to minimize ``loss``. The default value is None; in this case all parameters
will be updated.
no_grad_set (set, optional): Set of ``Variable`` objects that don't need
no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need
to be updated. The default value is None.
callbacks (list, optional): list of callable objects to run when appending backward
operator for one parameter. The default value is None.
......@@ -705,14 +705,7 @@ class Optimizer(object):
return optimize_ops
def _get_no_grad_set(self, loss, no_grad_set=None):
if no_grad_set is None:
no_grad_set = set()
elif isinstance(no_grad_set, set) or isinstance(
no_grad_set, list) or isinstance(no_grad_set, tuple):
no_grad_set = set(no_grad_set)
else:
assert "no_grad_set should be a set, but the passed type is {}".format(
type(no_grad_set))
no_grad_set = _get_no_grad_set_name(no_grad_set)
parameters = loss.block.program.global_block().all_parameters()
param_no_trainable = set(
[param.name for param in parameters if param.trainable is False])
......@@ -770,7 +763,7 @@ class Optimizer(object):
parameter_list (list, optional): List of ``Variable`` or ``Variable.name`` to update
to minimize ``loss``. The default value is None; in this case all parameters
will be updated.
no_grad_set (set, optional): Set of ``Variable`` objects that don't need
no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need
to be updated. The default value is None.
grad_clip (GradClipBase, optional) : Gradient clipping strategy, static
graph mode does not need to use this argument. Currently, this argument
......@@ -3843,8 +3836,8 @@ class RecomputeOptimizer(Optimizer):
loss (Variable): loss variable to run optimizations.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored.
parameter_list (list): list of Variables or Variable.names to update.
no_grad_set (set|None): set of Variables or Variable.names that should be ignored.
callbacks (list|None): list of callables to run when appending backward
operator for one parameter.
checkpoints (list): list of Variables as checkpoints
......
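The docstring updates above reflect that no_grad_set may now contain either Variable objects or their names, which _get_no_grad_set_name normalizes to names. A small sketch of the two equivalent call styles (the toy network is illustrative only and not part of this change):

import paddle.fluid as fluid

x = fluid.data(name='x', shape=[None, 8], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
hidden = fluid.layers.fc(input=x, size=16, act='relu')
pred = fluid.layers.fc(input=hidden, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

sgd = fluid.optimizer.SGD(learning_rate=0.01)
# Both forms are accepted after this change: Variable objects or their names.
sgd.minimize(loss, no_grad_set=set([hidden]))
# sgd.minimize(loss, no_grad_set=set([hidden.name]))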
......@@ -200,6 +200,10 @@ if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler)
endif()
if(NOT WITH_GPU OR WIN32 OR APPLE)
list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass)
endif()
# Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test
set(TEST_OPS_WITH_GC
......
......@@ -142,6 +142,21 @@ class TestBackward(unittest.TestCase):
exe.run(startup)
exe.run(feed=net.init_data())
def _check_error_no_grad_set(self, net, no_grad_set):
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = net.build_model()
optimizer = fluid.optimizer.SGD(learning_rate=0.1)
optimizer.minimize(loss, no_grad_set=no_grad_set)
exe.run(startup)
exe.run(feed=net.init_data())
class SimpleNet(BackwardNet):
def __init__(self):
......@@ -233,12 +248,25 @@ class TestSimpleNetWithErrorParamList(TestBackward):
# The type of parameter_list argument must be list or tuple
with self.assertRaises(TypeError):
self._check_error_param_list(self.net, "test")
# The type of parameter_list's member must be varable or str
# The type of parameter_list's member must be Variable or str
test = fluid.data(name='test', shape=[None, 90], dtype='float32')
with self.assertRaises(TypeError):
self._check_error_param_list(self.net, [test, "test", 3])
class TestSimpleNetWithErrorNoGradSet(TestBackward):
def test_no_grad_set_type_error(self):
self.global_block_idx = 0
self.net = SimpleNet()
# The type of no_grad_set argument must be set or list or tuple
with self.assertRaises(TypeError):
self._check_error_no_grad_set(self.net, "test")
# The type of no_grad_set's member must be Variable or str
test = fluid.data(name='test', shape=[None, 90], dtype='float32')
with self.assertRaises(TypeError):
self._check_error_no_grad_set(self.net, [test, "test", 3])
# TODO(Aurelius84): add conditional network test
class ConditionalNet(BackwardNet):
def __init__(self):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from test_eager_deletion_padding_rnn import RNNConfig, PaddingRNNTestBase
class FusionGroupPaddingRNNTest(PaddingRNNTestBase):
def set_customed_config(self):
self.build_strategy.enable_auto_fusion = True
# Use CUDA executor
if core.is_compiled_with_cuda():
self.exe = fluid.Executor(fluid.CUDAPlace(0))
def test_train_enable_fusion_group(self):
rnn_model = "static"
config = RNNConfig("test", rnn_model)
with fluid.scope_guard(fluid.Scope()):
self.train(config, parallel=True, use_program_cache=False)
if __name__ == '__main__':
unittest.main()
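For reference, the same switch works outside of this unit test; a minimal sketch of enabling the fusion_group pass on a toy program, assuming a CUDA-enabled build (the network here is only a placeholder):

import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    x = fluid.data(name='x', shape=[None, 16], dtype='float32')
    hidden = fluid.layers.fc(input=x, size=32, act='relu')
    loss = fluid.layers.mean(hidden)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

build_strategy = fluid.BuildStrategy()
build_strategy.enable_auto_fusion = True  # requests the fusion_group_pass (GPU only)

compiled_prog = fluid.CompiledProgram(main_program).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)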
......@@ -21,7 +21,6 @@ import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
import time
import os
from paddle.fluid import ParamAttr
......@@ -118,8 +117,7 @@ def lm_model(hidden_size,
num_steps=20,
init_scale=0.1,
dropout=None,
rnn_model='static',
use_py_reader=False):
rnn_model='static'):
def padding_rnn(input_embedding, len=3, init_hidden=None, init_cell=None):
weight_1_arr = []
weight_2_arr = []
......@@ -279,38 +277,9 @@ def lm_model(hidden_size,
gate_input = layers.elementwise_add(gate_input, bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
try:
from paddle.fluid.contrib.layers import fused_elemwise_activation
# fluid.contrib.layers.fused_elemwise_activation can do a fused
# operation, like:
# 1) x + sigmoid(y); x + tanh(y)
# 2) tanh(x + y)
# Now the unary operations supported in this fused op are limited, and
# we will extend this operation to support more unary operations and
# do this kind of fusion automatically in a future version of paddle.fluid.
# layers.sigmoid(i) * layers.tanh(j)
tmp0 = fused_elemwise_activation(
x=layers.tanh(j),
y=i,
functor_list=['elementwise_mul', 'sigmoid'],
save_intermediate_out=False)
# pre_cell * layers.sigmoid(f)
tmp1 = fused_elemwise_activation(
x=pre_cell,
y=f,
functor_list=['elementwise_mul', 'sigmoid'],
save_intermediate_out=False)
c = tmp0 + tmp1
# layers.tanh(c) * layers.sigmoid(o)
m = fused_elemwise_activation(
x=layers.tanh(c),
y=o,
functor_list=['elementwise_mul', 'sigmoid'],
save_intermediate_out=False)
except ImportError:
c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
i) * layers.tanh(j)
m = layers.tanh(c) * layers.sigmoid(o)
c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
i) * layers.tanh(j)
m = layers.tanh(c) * layers.sigmoid(o)
hidden_array[k] = m
cell_array[k] = c
......@@ -342,23 +311,16 @@ def lm_model(hidden_size,
return real_res, last_hidden, last_cell
batch_size_each = batch_size
if use_py_reader:
feed_shapes = [[batch_size_each, num_steps, 1],
[batch_size_each * num_steps, 1]]
py_reader = fluid.layers.py_reader(
capacity=16, shapes=feed_shapes, dtypes=['int64', 'int64'])
x, y = fluid.layers.read_file(py_reader)
else:
x = layers.data(
name="x",
shape=[batch_size_each, num_steps, 1],
dtype='int64',
append_batch_size=False)
y = layers.data(
name="y",
shape=[batch_size_each * num_steps, 1],
dtype='int64',
append_batch_size=False)
x = layers.data(
name="x",
shape=[batch_size_each, num_steps, 1],
dtype='int64',
append_batch_size=False)
y = layers.data(
name="y",
shape=[batch_size_each * num_steps, 1],
dtype='int64',
append_batch_size=False)
init_hidden = layers.data(
name="init_hidden",
......@@ -472,10 +434,7 @@ def lm_model(hidden_size,
layers.assign(input=last_hidden, output=init_hidden)
feeding_list = ['x', 'y', 'init_hidden', 'init_cell']
if use_py_reader:
return loss, last_hidden, last_cell, feeding_list, py_reader
else:
return loss, last_hidden, last_cell, feeding_list
return loss, last_hidden, last_cell, feeding_list
class PaddingRNNTestBase(unittest.TestCase):
......@@ -483,7 +442,29 @@ class PaddingRNNTestBase(unittest.TestCase):
self.reader = Reader()
self.device_count = 1
def prepare_program(self, config, parallel=True):
# The default exec_strategy used for PaddingRNN.
# You can change it in set_customed_config.
self.exec_strategy = fluid.ExecutionStrategy()
self.exec_strategy.num_threads = self.device_count
self.exec_strategy.num_iteration_per_drop_scope = 100
# The default build_strategy used for PaddingRNN.
# You can change it in set_customed_config.
self.build_strategy = fluid.BuildStrategy()
self.build_strategy.enable_inplace = True
self.build_strategy.memory_optimize = False
self.build_strategy.fuse_all_optimizer_ops = True
# The CPU executor is used for PaddingRNN by default.
# You can change to the CUDA executor in set_customed_config.
self.exe = Executor(fluid.CPUPlace())
def set_customed_config(self):
# This function will be called before training.
# You can override the function to set your own config.
pass
def _prepare_program(self, config, parallel=True):
self.main_program = fluid.Program()
self.startup_program = fluid.Program()
self.startup_program.random_seed = config.random_seed
......@@ -497,8 +478,7 @@ class PaddingRNNTestBase(unittest.TestCase):
num_steps=config.num_steps,
init_scale=config.init_scale,
dropout=config.dropout,
rnn_model=config.rnn_model,
use_py_reader=False)
rnn_model=config.rnn_model)
self.loss, self.last_hidden, self.last_cell, self.feed_order = res_vars
fluid.clip.set_gradient_clip(
......@@ -515,28 +495,19 @@ class PaddingRNNTestBase(unittest.TestCase):
optimizer = fluid.optimizer.SGD(
learning_rate=self.learning_rate)
optimizer.minimize(self.loss)
self.exe = Executor(fluid.CPUPlace())
self.exe.run(self.startup_program)
if parallel:
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = self.device_count
exec_strategy.num_iteration_per_drop_scope = 100
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = False
build_strategy.fuse_all_optimizer_ops = True
self.train_program = fluid.compiler.CompiledProgram(
self.main_program).with_data_parallel(
loss_name=self.loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
build_strategy=self.build_strategy,
exec_strategy=self.exec_strategy)
else:
self.train_program = self.main_program
def generate_init_data(self):
def _generate_init_data(self):
init_hidden = np.zeros(
(self.config.num_layers, self.config.batch_size,
self.config.hidden_size),
......@@ -547,19 +518,19 @@ class PaddingRNNTestBase(unittest.TestCase):
dtype='float32')
return init_hidden, init_cell
def generate_new_lr(self, epoch_id=0, device_count=1):
def _generate_new_lr(self, epoch_id=0, device_count=1):
new_lr = self.config.base_learning_rate * (self.config.lr_decay**max(
epoch_id + 1 - self.config.epoch_start_decay, 0.0))
lr = np.ones((self.device_count), dtype='float32') * new_lr
return lr
def prepare_input(self,
batch,
init_hidden=None,
init_cell=None,
epoch_id=0,
with_lr=True,
device_count=1):
def _prepare_input(self,
batch,
init_hidden=None,
init_cell=None,
epoch_id=0,
with_lr=True,
device_count=1):
x, y = batch
x = x.reshape((-1, self.config.num_steps, 1))
y = y.reshape((-1, 1))
......@@ -572,19 +543,19 @@ class PaddingRNNTestBase(unittest.TestCase):
if init_cell is not None:
res['init_cell'] = init_cell
if with_lr:
res['learning_rate'] = self.generate_new_lr(epoch_id, device_count)
res['learning_rate'] = self._generate_new_lr(epoch_id, device_count)
return res
def train_an_epoch(self, epoch_id, batch_times, use_program_cache=True):
def _train_an_epoch(self, epoch_id, use_program_cache=True):
train_data_iter = self.reader.get_data_iter(self.config)
total_loss = 0
iters = 0
init_hidden, init_cell = self.generate_init_data()
init_hidden, init_cell = self._generate_init_data()
ppl = np.zeros(shape=(0))
for batch_id, batch in enumerate(train_data_iter):
input_data_feed = self.prepare_input(
input_data_feed = self._prepare_input(
batch,
init_hidden=init_hidden,
init_cell=init_cell,
......@@ -592,7 +563,6 @@ class PaddingRNNTestBase(unittest.TestCase):
with_lr=True,
device_count=self.device_count)
batch_start_time = time.time()
fetch_outs = self.exe.run(self.train_program,
feed=input_data_feed,
fetch_list=[
......@@ -601,8 +571,6 @@ class PaddingRNNTestBase(unittest.TestCase):
self.last_cell.name
],
use_program_cache=use_program_cache)
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1])
......@@ -617,17 +585,13 @@ class PaddingRNNTestBase(unittest.TestCase):
return ppl
def train(self, config, parallel=True, use_program_cache=True):
self.set_customed_config()
self.config = config
self.prepare_program(config, parallel)
total_time = 0.0
self._prepare_program(config, parallel)
ppl = np.zeros(shape=(0, config.batch_size))
for epoch_id in range(config.max_epoch):
batch_times = []
epoch_start_time = time.time()
train_ppl = self.train_an_epoch(epoch_id, batch_times,
use_program_cache)
epoch_time = time.time() - epoch_start_time
total_time += epoch_time
train_ppl = self._train_an_epoch(epoch_id, use_program_cache)
ppl = np.append(ppl, train_ppl)
return ppl
......
......@@ -55,7 +55,7 @@ class TestFusedEmbeddingSeqPoolOp(OpTest):
if ver.mkl() == "ON" and 'Linux' in platform.platform():
self.attrs = {'is_sparse': False}
self.check_grad(
['W'], 'Out', no_grad_set=('Ids'), check_dygraph=False)
['W'], 'Out', no_grad_set=['Ids'], check_dygraph=False)
class TestLookupTableOpWithPadding(TestFusedEmbeddingSeqPoolOp):
......@@ -89,7 +89,7 @@ class TestLookupTableOpWithPadding(TestFusedEmbeddingSeqPoolOp):
self.attrs = {'padding_idx': int(padding_idx), 'is_sparse': False}
# TODO(wangzhongpu): support lod in dygraph mode
self.check_grad(
['W'], 'Out', no_grad_set=('Ids'), check_dygraph=False)
['W'], 'Out', no_grad_set=['Ids'], check_dygraph=False)
class TestFusedEmbeddingSeqPoolApi(unittest.TestCase):
......
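The last hunk replaces no_grad_set=('Ids') with no_grad_set=['Ids'] because a single parenthesized string is not a tuple; iterating over it yields characters instead of the variable name. A quick plain-Python illustration:

# ('Ids') is just the string 'Ids'; the parentheses do not create a tuple.
assert ('Ids') == 'Ids'
assert list(('Ids')) == ['I', 'd', 's']

# A one-element tuple needs a trailing comma; a list works as written.
assert ('Ids',) == tuple(['Ids'])
assert list(['Ids']) == ['Ids']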