提交 062556f9 编写于 作者: L Luo Tao

Merge branch 'develop' into unify

...@@ -46,6 +46,7 @@ ...@@ -46,6 +46,7 @@
| tianbingsz | Tian-Bing Xu | | tianbingsz | Tian-Bing Xu |
| tpatejko | Tomasz Patejko | | tpatejko | Tomasz Patejko |
| typhoonzero | Yi Wu | | typhoonzero | Yi Wu |
| velconia | Qi-Yang Min |
| wanghaoshuang | Hao-Shuang Wang | | wanghaoshuang | Hao-Shuang Wang |
| wangyang59 | Yang Wang | | wangyang59 | Yang Wang |
| wangzhen-nlp | Zhen Wang | | wangzhen-nlp | Zhen Wang |
......
## Motivation ## Motivation
There is a ```gap``` between the ```Program``` defined by There is a `gap` between the `Program` defined by
user and the ```Executable``` that can be scheduled user and the `Executable` that can be scheduled
efficiently on heterogeneous hardware, either locally efficiently on heterogeneous hardware, either locally
or distributedly. or distributedly.
Usually, the ```gap``` is bridged by Usually, the `gap` is bridged by
* A serious transformations with defined order. * A serious transformations with defined order.
* These transformations usually involve * These transformations usually involve
```insert, delete, clustering, split, dependency analysis```. `insert, delete, clustering, split, dependency analysis`.
* Has a simple way to verify and debug each transformation. * Has a simple way to verify and debug each transformation.
...@@ -38,44 +38,44 @@ design below. ...@@ -38,44 +38,44 @@ design below.
#### Node #### Node
```Node``` represents an operation that performs some computation or `Node` represents an operation that performs some computation or
a variable that is input or output of operation. a variable that is input or output of operation.
```Node```s are connected to other ```Node```s via inputs and outputs. `Node`s are connected to other `Node`s via inputs and outputs.
Other properties (maybe device placement information) can be added Other properties (maybe device placement information) can be added
to ```Node``` in the future if it's a to `Node` in the future if it's a
common requirement of many other ```Pass```es. Otherwise, it should live common requirement of many other `Pass`es. Otherwise, it should live
in a ```Node``` wrapper class that is private to some ```Pass``` or be in a `Node` wrapper class that is private to some `Pass` or be
a local member of a ```Pass```. a local member of a `Pass`.
#### Graph #### Graph
```Graph``` contains a list of ```Node```s, which are connected to `Graph` contains a list of `Node`s, which are connected to
each other via inputs and outputs. each other via inputs and outputs.
TODO: Better definitions for the graph. TODO: Better definitions for the graph.
```Graph``` can also contain ```Attribute```s. ```Attribute```s `Graph` can also contain `Attribute`s. `Attribute`s
can be ``any`` thing. For example, it can be a list of "wraper" can be `any` thing. For example, it can be a list of "wraper"
nodes. The ```wrapper``` nodes compose ```Node```s and provide nodes. The `wrapper` nodes compose `Node`s and provide
helper method for execution or transformation. ```Attribute``` helper method for execution or transformation. `Attribute`
can also contain other things that describe some properties of can also contain other things that describe some properties of
the ```Graph``` or ```Graph``` nodes. ```Attribute``` can be passed the `Graph` or `Graph` nodes. `Attribute` can be passed
across ```Pass```. However, it should be used with care. across `Pass`. However, it should be used with care.
#### Pass #### Pass
```Pass``` represents a transformation of ```Graph```. Its input `Pass` represents a transformation of `Graph`. Its input
is a ```Graph``` and its output is also a ```Graph```. For example, is a `Graph` and its output is also a `Graph`. For example,
a ```Pass``` can simply print out the ```Graph```. A ```Pass``` a `Pass` can simply print out the `Graph`. A `Pass`
can also fuse some ```Graph```'s ```Node```s. can also fuse some `Graph`'s `Node`s.
#### Optimize #### Optimize
```Optimize``` contains a series of ```Pass``` with defined order. `Optimize` contains a series of `Pass` with defined order.
```Optimize``` transforms a ```Graph``` that only contains raw `Optimize` transforms a `Graph` that only contains raw
modeling logic to a ```Graph``` that can be run efficiently while modeling logic to a `Graph` that can be run efficiently while
maintaining the original modeling logic. maintaining the original modeling logic.
......
...@@ -22,7 +22,12 @@ endif() ...@@ -22,7 +22,12 @@ endif()
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place memory device_context tensor) if(WITH_GPU)
nv_test(mixed_vector_test SRCS mixed_vector_test.cc mixed_vector_test.cu DEPS place memory device_context tensor)
else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
......
...@@ -88,9 +88,8 @@ class BlockDesc { ...@@ -88,9 +88,8 @@ class BlockDesc {
OpDesc *InsertOp(size_t index); OpDesc *InsertOp(size_t index);
/* /*
* Remove Op and its input/output variables. * Only remove op itself,
* Note that for either input or output variable, if it is also an input or * do nothing to its input and output variables
* output variable of other ops, we should remain it.
*/ */
void RemoveOp(size_t s, size_t e); void RemoveOp(size_t s, size_t e);
......
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto) cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph graph_helper)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -67,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -67,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
} }
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
...@@ -92,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node, ...@@ -92,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
} }
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const { const std::vector<ir::Node *> &nodes) const {
std::vector<std::string> send_vars; std::vector<std::string> send_vars;
// since parameters are all in block 0, // since parameters are all in block 0,
// it's enough to only scan send ops in block 0 // it's enough to only scan send ops in block 0
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
OpDesc *op = node->Op(); OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find send op, // TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string // instead of the the hard code string
...@@ -112,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( ...@@ -112,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
} }
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars( std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const { const std::vector<ir::Node *> &nodes) const {
std::vector<std::string> recv_vars; std::vector<std::string> recv_vars;
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
OpDesc *op = node->Op(); OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find recv op, // TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string // instead of the hard code string
...@@ -170,6 +170,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -170,6 +170,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const { const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0; int64_t numel_sum = 0;
for (auto var_name : var_names) { for (auto var_name : var_names) {
if (all_vars_.find(var_name) == all_vars_.end()) continue;
auto var_desc = all_vars_.at(var_name); auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc); PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape()); auto dim = framework::make_ddim(var_desc->GetShape());
...@@ -186,19 +187,70 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -186,19 +187,70 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
return dev_id; return dev_id;
} }
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( // Topology sort the graph nodes from inputs to outputs.
std::unique_ptr<Graph> graph) const { // Since SSAGraphBuilder depends on forward/backward nodes to assign devices
// Rebuild the graph structure. // to parameter/gradients before optimizer ops, topo sort is insufficient. (
auto nodes = std::move(graph->nodes); // some optimizer ops might not depend on any nodes), we manually move all
graph->nodes.clear(); // optimizer nodes after last backward nodes.
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
size_t last_backward = 0;
for (size_t i = 0; i < ret.size(); ++i) {
if (boost::get<int>(
ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kBackward)) {
last_backward = i;
}
}
std::vector<ir::Node *> optimize_ops;
std::vector<ir::Node *> sorted_ret;
for (size_t i = 0; i < ret.size(); ++i) {
if (i < last_backward) {
if (boost::get<int>(ret[i]->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kOptimize)) {
optimize_ops.push_back(ret[i]);
} else {
sorted_ret.push_back(ret[i]);
}
} else if (i == last_backward) {
sorted_ret.push_back(ret[i]);
// Verify that no operations before optimize ops depends on optimize ops.
std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
optimize_ops.end());
for (ir::Node *n : sorted_ret) {
for (ir::Node *in : n->inputs) {
for (ir::Node *pre_n : in->inputs) {
PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
"optimize operations cannot be depended by forward "
"or backward node %s -> %s",
pre_n->Name(), n->Name());
}
}
}
sorted_ret.insert(sorted_ret.end(), optimize_ops.begin(),
optimize_ops.end());
} else {
sorted_ret.push_back(ret[i]);
}
}
return sorted_ret;
}
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<ir::Graph> graph) const {
// Give the topology sort order and rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->NodeType() == ir::Node::Type::kVariable) { if (node->NodeType() == ir::Node::Type::kVariable && node->Var()) {
all_vars_.emplace(node->Name(), node->Var()); all_vars_.emplace(node->Name(), node->Var());
} }
} }
Graph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize. It is a bug of GCC 4.8 // We cannot invoke resize. It is a bug of GCC 4.8
...@@ -207,9 +259,9 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -207,9 +259,9 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
result.Set("ops", new GraphOps); result.Set("ops", new GraphOps);
// find send/recv vars so that we can place the distributed training // find send/recv vars so that we can place the distributed training
// realted op in the place 0 // related op in the place 0
auto send_vars = FindDistTrainSendVars(nodes); auto send_vars = FindDistTrainSendVars(sorted_ops);
auto recv_vars = FindDistTrainRecvVars(nodes); auto recv_vars = FindDistTrainRecvVars(sorted_ops);
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
...@@ -217,22 +269,18 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -217,22 +269,18 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
size_t cur_device_id = 0; size_t cur_device_id = 0;
bool is_forwarding = true; bool is_forwarding = true;
// NOTE: Currently, passes before SSAGraphBuilder cannot reorder for (ir::Node *node : sorted_ops) {
// forward, backward nodes. E.g. you can't append an forward node
// at the end of the node list.
// TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
if (boost::get<int>( if (boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
CreateRPCOp(&result, node.get()); CreateRPCOp(&result, node);
} else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) { } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
CreateDistTrainOp(&result, node.get()); CreateDistTrainOp(&result, node);
} else if (IsScaleLossOp(node.get())) { } else if (IsScaleLossOp(node)) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != if (strategy_.gradient_scale_ !=
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle?
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
} }
// This assumes the backward generating code will ensure IsScaleLossOp // This assumes the backward generating code will ensure IsScaleLossOp
...@@ -241,24 +289,23 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -241,24 +289,23 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
// the block. // the block.
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(node.get()); int op_dev_id = GetOpDeviceID(node);
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node.get(), op_dev_id); CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) { for (ir::Node *n : node->outputs) {
var_name_on_devices_.emplace(n->Name(), op_dev_id); var_name_on_devices_.emplace(n->Name(), op_dev_id);
} }
} else { } else {
// This op runs on all devices, and its output may have parameter's // This op runs on all devices, and its output may have parameter's
// gradients. // gradients.
// TODO(paddle-dev): Why is so special about "read" op?
if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
node->Op()->SetAttr("throw_eof_exp", false); node->Op()->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, node.get(), places_.size()); CreateComputationalOps(&result, node, places_.size());
// TODO(paddle-dev): builder shouldn't depend on the out logic of
// a specific op.
const auto &data_var_names = node->Op()->Output("Out"); const auto &data_var_names = node->Op()->Output("Out");
InsertDataBalanceOp(&result, data_var_names); InsertDataBalanceOp(&result, data_var_names);
} else { } else {
CreateComputationalOps(&result, node.get(), places_.size()); CreateComputationalOps(&result, node, places_.size());
} }
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
...@@ -322,17 +369,17 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -322,17 +369,17 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
} }
} }
} }
/* /*
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
hazards need to be handled. hazards need to be handled.
*/ */
PolishGraphToSupportDataHazards(&result); PolishGraphToSupportDataHazards(&result);
/* /*
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
PADDLE_ENFORCE(!ir::HasCircle(result));
return graph; return graph;
} }
...@@ -357,7 +404,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( ...@@ -357,7 +404,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif #endif
} }
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
const std::string &p_name, const std::string &p_name,
size_t src_dev_id) const { size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -387,7 +434,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, ...@@ -387,7 +434,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node, ir::Node *node,
int dev_id) const { int dev_id) const {
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>("ops").emplace_back(
...@@ -396,7 +443,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, ...@@ -396,7 +443,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
CreateOpHandleIOs(result, node, dev_id); CreateOpHandleIOs(result, node, dev_id);
} }
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
const std::string &og) const { const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
...@@ -426,7 +473,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, ...@@ -426,7 +473,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
} }
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
Graph *result, const std::vector<std::string> &datas) const { ir::Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
...@@ -479,8 +526,8 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { ...@@ -479,8 +526,8 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(param_grad[1]); int dev_id = GetVarDeviceID(param_grad[1]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0]); node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id; return dev_id;
} }
...@@ -489,7 +536,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { ...@@ -489,7 +536,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
return got == var_name_on_devices_.end() ? -1 : got->second; return got == var_name_on_devices_.end() ? -1 : got->second;
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -519,7 +566,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { ...@@ -519,7 +566,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
ir::Node *node, ir::Node *node,
size_t num_places) const { size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
...@@ -531,7 +578,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, ...@@ -531,7 +578,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
} }
} }
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const std::string &og, const std::string &og,
int dst_dev_id) const { int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -564,12 +611,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, ...@@ -564,12 +611,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
// Find the first occurence of `prev_op_name` and make current `op` depend // Find the first occurence of `prev_op_name` and make current `op` depend
// on it. // on it.
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) { for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) { if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle( auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
prev_op->AddOutput(dep_var); prev_op->AddOutput(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var); result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
op->AddInput(dep_var); op->AddInput(dep_var);
...@@ -577,7 +623,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, ...@@ -577,7 +623,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
} }
} }
void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
ir::Node *node) const { ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
std::vector<std::string> input_var_names; std::vector<std::string> input_var_names;
...@@ -591,6 +637,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, ...@@ -591,6 +637,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
if (node->Op()->Type() == "split_byref" || if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") { node->Op()->Type() == "split_selected_rows") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(input_var_names[0]); op_dev_id = GetVarDeviceID(input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
...@@ -624,10 +671,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, ...@@ -624,10 +671,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
} }
// Create RPC related op handles that connects its in ops and out ops. // Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by // the variable name which contains .block means it was splited by
// split_byref op // split_byref op
// so that we can balance the variable blocks to all the pserver // so that we can balance the variable blocks to all the pserver
......
...@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy); const BuildStrategy &strategy);
#endif #endif
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override; std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override; int GetVarDeviceID(const std::string &varname) const override;
private: private:
void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const; void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
...@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(Graph *result, ir::Node *node) const; void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
void CreateDistTrainOp(Graph *result, ir::Node *node) const; void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
/** /**
* Is this operator as the end-point operator before/after send operator. * Is this operator as the end-point operator before/after send operator.
...@@ -74,21 +76,22 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -74,21 +76,22 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<std::string> &recv_vars) const; const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars( std::vector<std::string> FindDistTrainSendVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const; const std::vector<ir::Node *> &nodes) const;
std::vector<std::string> FindDistTrainRecvVars( std::vector<std::string> FindDistTrainRecvVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const; const std::vector<ir::Node *> &nodes) const;
void ConnectOp(Graph *result, OpHandleBase *op, void ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const; const std::string &prev_op_name) const;
void CreateComputationalOps(Graph *result, ir::Node *node, void CreateComputationalOps(ir::Graph *result, ir::Node *node,
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(Graph *result) const; void CreateScaleLossGradOp(ir::Graph *result) const;
VarHandle *CreateReduceOp(Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const; void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const;
bool IsParameterGradientOnce( bool IsParameterGradientOnce(
const std::string &og, const std::string &og,
...@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
int GetOpDeviceID(ir::Node *node) const; int GetOpDeviceID(ir::Node *node) const;
void InsertAllReduceOp(Graph *result, const std::string &og) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
void InsertDataBalanceOp(Graph *result, void InsertDataBalanceOp(ir::Graph *result,
const std::vector<std::string> &datas) const; const std::vector<std::string> &datas) const;
void CreateBroadcastOp(Graph *result, const std::string &p_name, void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -33,7 +34,7 @@ void RPCOpHandle::RunImpl() { ...@@ -33,7 +34,7 @@ void RPCOpHandle::RunImpl() {
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of use DebugString() // FIXME(Yancey1989): need a better solution instead of use DebugString()
if (in->DebugString() == "dummy") { // HACK if (ir::IsControlDepVar(*in->Node())) { // HACK
continue; continue;
} }
if (in->GeneratedOp()) { if (in->GeneratedOp()) {
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) { for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) { if (name_pair.second.size() <= 1) {
...@@ -36,9 +36,18 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { ...@@ -36,9 +36,18 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
// Read Write is the same op. // Read Write is the same op.
continue; continue;
} }
bool has_dep = false;
for (auto *r_out : read_op->Outputs()) {
for (auto *w_in : write_op->Inputs()) {
if (r_out->Node() == w_in->Node()) {
has_dep = true;
break;
}
}
}
if (has_dep) continue;
auto *dep_var = new DummyVarHandle( auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
read_op->AddOutput(dep_var); read_op->AddOutput(dep_var);
write_op->AddInput(dep_var); write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var); graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
...@@ -49,7 +58,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { ...@@ -49,7 +58,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
} }
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
Graph *graph, ir::Node *node, const platform::Place &place, ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset]; auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holder = var_holders[node->Name()]; auto &var_holder = var_holders[node->Name()];
...@@ -70,7 +79,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ...@@ -70,7 +79,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return var; return var;
} }
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, ir::Node *new_node,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
...@@ -82,13 +91,12 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, ...@@ -82,13 +91,12 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) { for (auto &op : graph->Get<GraphOps>("ops")) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
} }
auto *dummy_leaf = new DummyVarHandle( auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf); graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
......
...@@ -57,26 +57,23 @@ class SSAGraphBuilder : public ir::Pass { ...@@ -57,26 +57,23 @@ class SSAGraphBuilder : public ir::Pass {
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected: protected:
/** /*
* We only handle write after read(WAR), since it should not have a write Dependency graph has been constructed. However, there are still data
* after write in program. If there are write after write operators, we need hazards need to be handled.
* prune them. */
* static void PolishGraphToSupportDataHazards(ir::Graph *graph);
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/ static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
static void PolishGraphToSupportDataHazards(Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle, // Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph // which belongs to graph
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, static void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, const platform::Place &place, ir::Node *new_node, const platform::Place &place,
size_t place_offset); size_t place_offset);
static void AddOutputToLeafOps(Graph *graph); static void AddOutputToLeafOps(ir::Graph *graph);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars; std::unordered_set<VarHandleBase *> ready_vars;
......
...@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder) std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {} : builder_(std::move(builder)) {}
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override { std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph)); auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get())); PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph; return new_graph;
...@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
return builder_->GetVarDeviceID(var_name); return builder_->GetVarDeviceID(var_name);
} }
bool IsValidGraph(const Graph* graph) const; bool IsValidGraph(const ir::Graph* graph) const;
private: private:
std::unique_ptr<SSAGraphBuilder> builder_; std::unique_ptr<SSAGraphBuilder> builder_;
......
...@@ -21,7 +21,7 @@ namespace framework { ...@@ -21,7 +21,7 @@ namespace framework {
namespace details { namespace details {
template <typename Callback> template <typename Callback>
static inline void IterAllVar(const Graph &graph, Callback callback) { static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) { for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
...@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) { ...@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) {
} }
} }
void GraphvizSSAGraphPrinter::Print(const Graph &graph, void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
std::ostream &sout) const { std::ostream &sout) const {
size_t var_id = 0; size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars; std::unordered_map<const VarHandleBase *, size_t> vars;
......
...@@ -25,12 +25,12 @@ namespace details { ...@@ -25,12 +25,12 @@ namespace details {
class SSAGraphPrinter { class SSAGraphPrinter {
public: public:
virtual ~SSAGraphPrinter() {} virtual ~SSAGraphPrinter() {}
virtual void Print(const Graph& graph, std::ostream& sout) const = 0; virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
}; };
class GraphvizSSAGraphPrinter : public SSAGraphPrinter { class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public: public:
void Print(const Graph& graph, std::ostream& sout) const override; void Print(const ir::Graph& graph, std::ostream& sout) const override;
}; };
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
...@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { ...@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)), stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {} stream_ref_(*stream_ptr_) {}
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override { std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph)); auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_); printer_->Print(*new_graph, stream_ref_);
return new_graph; return new_graph;
......
...@@ -21,7 +21,8 @@ namespace framework { ...@@ -21,7 +21,8 @@ namespace framework {
namespace details { namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph) const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph)
: graph_(std::move(graph)), : graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr), : nullptr),
......
...@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<Graph> &&graph); std::unique_ptr<ir::Graph> &&graph);
// Run a SSAGraph by a thread pool // Run a SSAGraph by a thread pool
// Use topological sort algorithm // Use topological sort algorithm
...@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
std::unique_ptr<Graph> graph_; std::unique_ptr<ir::Graph> graph_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
......
...@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const { ...@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const {
return ss.str(); return ss.str();
} }
std::string DummyVarHandle::DebugString() const { return "dummy"; } std::string DummyVarHandle::DebugString() const { return node_->Name(); }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
cc_library(node SRCS node.cc DEPS proto_desc) cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node) cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node) cc_library(pass SRCS pass.cc DEPS graph node)
cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry)
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry)
...@@ -12,14 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,14 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir {
// NOTE(paddle-dev): This graph contains circle.
Graph::Graph(const ProgramDesc &program) : program_(program) { Graph::Graph(const ProgramDesc &program) : program_(program) {
VLOG(3) << "block in program:" << program_.Size(); VLOG(3) << "block in program:" << program_.Size();
std::unordered_map<std::string, VarDesc *> all_vars; std::unordered_map<std::string, VarDesc *> all_vars;
...@@ -27,40 +31,87 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -27,40 +31,87 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
all_vars.emplace(var->Name(), var); all_vars.emplace(var->Name(), var);
} }
std::map<std::string, ir::Node *> var_nodes; std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op); ir::Node *node = CreateOpNode(op);
// For input args, reuse the same var name if it was created before.
// Otherwise, create a new one.
for (auto &each_var_name : op->InputArgumentNames()) { for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr; ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) { if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name); var = var_nodes.at(each_var_name).back();
} else if (all_vars.count(each_var_name) != 0) { } else if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name)); var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var; var_nodes[each_var_name].push_back(var);
} else { } else {
// TODO(paddle-dev): Seems some assumption doesn't hold? // Operation input var can be optional (dispensable). Which means
VLOG(3) << op->Type() // the operation doesn't really need the var at runtime. In this
<< " input var not in all_var list: " << each_var_name; // case, the no-existed var is ready at the beginning.
var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable); var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
var_nodes[each_var_name] = var; var_nodes[each_var_name].push_back(var);
} }
node->inputs.push_back(var); node->inputs.push_back(var);
var->outputs.push_back(node); var->outputs.push_back(node);
} }
// For output args, always create a new var.
for (auto &each_var_name : op->OutputArgumentNames()) { for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = nullptr; ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
if (var_nodes.find(each_var_name) != var_nodes.end()) { var_nodes[each_var_name].push_back(var);
var = var_nodes.at(each_var_name);
} else {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name] = var;
}
node->outputs.push_back(var); node->outputs.push_back(var);
var->inputs.push_back(node); var->inputs.push_back(node);
} }
} }
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
for (auto &var : var_nodes) {
auto &versions = var.second;
if (versions.size() <= 1) continue;
auto it_new = versions.rbegin();
auto it_old = versions.rbegin();
++it_old;
for (; it_old != versions.rend(); it_new = it_old, ++it_old) {
ir::Node *write_op =
(*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
const auto &read_ops = (*it_old)->outputs;
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
// Read Write is the same op.
continue;
}
// 2 ops might have been connected via other vars.
bool has_dep = false;
for (ir::Node *r_out : read_op->outputs) {
for (ir::Node *w_in : write_op->inputs) {
if (r_out == w_in) {
has_dep = true;
break;
}
}
}
if (has_dep) continue;
ir::Node *dep_var = CreateControlDepVar();
read_op->outputs.push_back(dep_var);
dep_var->inputs.push_back(read_op);
write_op->inputs.push_back(dep_var);
dep_var->outputs.push_back(write_op);
}
}
}
}
bool IsControlDepVar(const ir::Node &var) {
return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos;
} }
} // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,13 +26,14 @@ limitations under the License. */ ...@@ -26,13 +26,14 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir {
class Graph { class Graph {
public: public:
explicit Graph(const ProgramDesc& program); explicit Graph(const ProgramDesc &program);
virtual ~Graph() { virtual ~Graph() {
for (auto& attr : attrs_) { for (auto &attr : attrs_) {
attr_dels_[attr.first](); attr_dels_[attr.first]();
} }
attrs_.clear(); attrs_.clear();
...@@ -40,12 +41,12 @@ class Graph { ...@@ -40,12 +41,12 @@ class Graph {
} }
template <typename AttrType> template <typename AttrType>
AttrType& Get(const std::string& attr_name) const { AttrType &Get(const std::string &attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name)); return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} }
template <typename AttrType> template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) { void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0); PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr; attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() { attr_dels_[attr_name] = [attr, attr_name]() {
...@@ -54,29 +55,70 @@ class Graph { ...@@ -54,29 +55,70 @@ class Graph {
}; };
} }
ir::Node* CreateVarNode(VarDesc* var_desc) { const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get(); // Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) {
return AddNode(new ir::Node(var_desc));
}
// Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) {
return AddNode(new ir::Node(op_desc));
} }
ir::Node* CreateOpNode(OpDesc* op_desc) { // Create a control dependency var that connects 2 operations. The
nodes.emplace_back(new ir::Node(op_desc)); // var doesn't hold any data. Other than that, it's no different from
return nodes.back().get(); // other var, considering dependency analysis.
ir::Node *CreateControlDepVar() {
// TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
} }
ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) { // A more free style way of creating a graph node. Mostly use for test
nodes.emplace_back(new ir::Node(name, type)); // or "copy" from another node. Avoid using it if possible.
return nodes.back().get(); ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type));
} }
std::vector<std::unique_ptr<ir::Node>> nodes; // Clear all node information of the graph and return the ownership of the
// nodes.
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
std::vector<std::unique_ptr<ir::Node>> ret;
for (auto &n : nodes_) {
ret.emplace_back(n.second.release());
}
nodes_.clear();
node_set_.clear();
return ret;
}
private: private:
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
nodes_[node].reset(node);
node_set_.insert(node);
return node;
}
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
// NOTE: program_ shouldn't be exposed to user. // NOTE: program_ shouldn't be exposed to user.
const ProgramDesc& program_; const ProgramDesc &program_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_;
}; };
bool IsControlDepVar(const ir::Node &var);
} // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
void SortHelper(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
ir::Node *node, std::unordered_set<ir::Node *> *visited,
std::vector<ir::Node *> *ret) {
visited->insert(node);
for (auto adj : adj_list.at(node)) {
if (visited->find(adj) == visited->end()) {
SortHelper(adj_list, adj, visited, ret);
}
}
VLOG(3) << "topology sort insert: " << node->Name()
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size();
ret->push_back(node);
}
bool HasCircleHelper(
ir::Node *node,
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
std::unordered_set<ir::Node *> *visited,
std::unordered_set<ir::Node *> *in_trace) {
if (visited->find(node) == visited->end()) {
visited->insert(node);
in_trace->insert(node);
for (ir::Node *in : adj_list.at(node)) {
if (visited->find(in) == visited->end() &&
HasCircleHelper(in, adj_list, visited, in_trace)) {
return true;
} else if (in_trace->find(in) != in_trace->end()) {
return true;
}
}
}
in_trace->erase(node);
return false;
}
bool HasCircleInternal(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace;
for (auto &adj : adj_list) {
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
return true;
}
}
return false;
}
} // namespace
bool HasCircle(const Graph &graph) {
return HasCircleInternal(BuildOperationAdjList(graph));
}
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list));
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
for (auto adj : adj_list) {
if (visited.find(adj.first) == visited.end()) {
SortHelper(adj_list, adj.first, &visited, &ret);
}
}
return ret;
}
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
for (auto &n : graph.Nodes()) {
if (n->NodeType() != ir::Node::Type::kOperation) continue;
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::unordered_set<ir::Node *>();
}
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
adj_list[n].insert(adj_n);
VLOG(3) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
}
}
}
return adj_list;
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {
namespace ir {
// Test if the graph contains circle.
bool HasCircle(const Graph &graph);
// Topology Sort the operations in the graph from inputs to outputs.
// `graph` cannot contain circle.
std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
// Build an adjacency list of operations for the `graph`.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph);
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
void BuildCircleGraph(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
o1->outputs.push_back(v1);
o1->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o1);
}
void BuildCircleGraph2(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
o2->outputs.push_back(v2);
o1->inputs.push_back(v2);
v2->inputs.push_back(o2);
v2->outputs.push_back(o1);
}
void BuildNoCircleGraph(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
// o1->v1->o2
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
// o2->v2->o3
// o2->v2->o4
o2->outputs.push_back(v2);
o3->inputs.push_back(v2);
o4->inputs.push_back(v2);
v2->inputs.push_back(o2);
v2->outputs.push_back(o3);
v2->outputs.push_back(o4);
// o2->v3->o5
o2->outputs.push_back(v3);
o5->inputs.push_back(v3);
v3->inputs.push_back(o2);
v3->outputs.push_back(o5);
// o3-v4->o5
o3->outputs.push_back(v4);
o5->inputs.push_back(v4);
v4->inputs.push_back(o3);
v4->outputs.push_back(o5);
}
TEST(GraphHelperTest, Basic) {
ProgramDesc prog;
Graph g(prog);
BuildCircleGraph(&g);
ASSERT_TRUE(HasCircle(g));
Graph g2(prog);
BuildCircleGraph2(&g2);
ASSERT_TRUE(HasCircle(g2));
auto adj_list = BuildOperationAdjList(g2);
for (auto& adj : adj_list) {
auto& adj_set = adj.second;
if (adj.first->Name() == "op1") {
ASSERT_EQ((*adj_set.begin())->Name(), "op2");
} else if (adj.first->Name() == "op2") {
ASSERT_EQ((*adj_set.begin())->Name(), "op1");
} else {
ASSERT_TRUE(false);
}
}
Graph g3(prog);
BuildNoCircleGraph(&g3);
ASSERT_FALSE(HasCircle(g3));
auto sorted = TopologySortOperations(g3);
std::map<std::string, size_t> node_map;
for (size_t i = 0; i < sorted.size(); ++i) {
node_map[sorted[i]->Name()] = i;
}
ASSERT_EQ(node_map.at("op1"), 0);
ASSERT_EQ(node_map.at("op2"), 1);
ASSERT_TRUE(node_map.at("op3") < node_map.at("op5"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -76,6 +76,7 @@ TEST(GraphTest, Basic) { ...@@ -76,6 +76,7 @@ TEST(GraphTest, Basic) {
op->SetType("sum"); op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
...@@ -92,21 +93,22 @@ TEST(GraphTest, Basic) { ...@@ -92,21 +93,22 @@ TEST(GraphTest, Basic) {
ASSERT_EQ(proto::VarType::LOD_TENSOR, ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType()); prog.MutableBlock(0)->Var("test_out")->GetType());
std::unique_ptr<Graph> g(new Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ASSERT_EQ(g->nodes[0]->Name(), "sum"); std::vector<ir::Node *> nodes(g->Nodes().begin(), g->Nodes().end());
ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a"); for (ir::Node *n : nodes) {
ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b"); if (n->Name() == "sum") {
ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c"); ASSERT_EQ(n->inputs.size(), 3);
ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out"); ASSERT_EQ(n->outputs.size(), 1);
ASSERT_EQ(g->nodes[1]->Name(), "test_a"); } else if (n->Name() == "test_a" || n->Name() == "test_b" ||
ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum"); n->Name() == "test_c") {
ASSERT_EQ(g->nodes[2]->Name(), "test_b"); ASSERT_EQ(n->inputs.size(), 0);
ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum"); ASSERT_EQ(n->outputs.size(), 1);
ASSERT_EQ(g->nodes[3]->Name(), "test_c"); } else if (n->Name() == "test_out") {
ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum"); ASSERT_EQ(n->inputs.size(), 1);
ASSERT_EQ(g->nodes[4]->Name(), "test_out"); ASSERT_EQ(n->outputs.size(), 0);
ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum"); }
ASSERT_EQ(g->nodes.size(), 5); }
ASSERT_EQ(nodes.size(), 5);
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,5 +15,9 @@ limitations under the License. */ ...@@ -15,5 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
namespace paddle { namespace paddle {
namespace framework {} // namespace framework namespace framework {
namespace ir {
const char Node::kControlDepVarName[] = "__control_var";
} // namespace ir
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -27,6 +27,8 @@ namespace ir { ...@@ -27,6 +27,8 @@ namespace ir {
class Node { class Node {
public: public:
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static const char kControlDepVarName[];
explicit Node(const std::string& name, Type type) explicit Node(const std::string& name, Type type)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
...@@ -50,6 +52,7 @@ class Node { ...@@ -50,6 +52,7 @@ class Node {
PADDLE_ENFORCE(type_ == Type::kVariable); PADDLE_ENFORCE(type_ == Type::kVariable);
return var_desc_; return var_desc_;
} }
OpDesc* Op() { OpDesc* Op() {
PADDLE_ENFORCE(type_ == Type::kOperation); PADDLE_ENFORCE(type_ == Type::kOperation);
return op_desc_; return op_desc_;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <algorithm> #include <algorithm>
#include <initializer_list> #include <initializer_list>
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -386,13 +387,14 @@ template <typename T> ...@@ -386,13 +387,14 @@ template <typename T>
class CPUVector : public std::vector<T, std::allocator<T>> { class CPUVector : public std::vector<T, std::allocator<T>> {
public: public:
CPUVector() : std::vector<T>() {} CPUVector() : std::vector<T>() {}
CPUVector(size_t count, const T &value = T()) CPUVector(size_t count, const T &value = T()) // NOLINT
: std::vector<T>(count, value) {} : std::vector<T>(count, value) {}
CPUVector(std::initializer_list<T> init) : std::vector<T>(init) {} CPUVector(std::initializer_list<T> init) : std::vector<T>(init) {}
CPUVector(const std::vector<T> &other) : std::vector<T>(other) {} CPUVector(const std::vector<T> &other) : std::vector<T>(other) {} // NOLINT
explicit CPUVector(const CPUVector<T> &other) : std::vector<T>(other) {} CPUVector(const CPUVector<T> &other) : std::vector<T>(other) {}
CPUVector(CPUVector<T> &&other) : std::vector<T>(std::move(other)) {} CPUVector(CPUVector<T> &&other) : std::vector<T>(std::move(other)) {}
CPUVector(std::vector<T> &&other) : std::vector<T>(std::move(other)) {} CPUVector(std::vector<T> &&other) // NOLINT
: std::vector<T>(std::move(other)) {}
CPUVector &operator=(const CPUVector &other) { CPUVector &operator=(const CPUVector &other) {
this->assign(other.begin(), other.end()); this->assign(other.begin(), other.end());
return *this; return *this;
...@@ -410,8 +412,6 @@ class CPUVector : public std::vector<T, std::allocator<T>> { ...@@ -410,8 +412,6 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
return os; return os;
} }
void resize(size_t size) { this->resize(size); }
T &operator[](size_t id) { return this->at(id); } T &operator[](size_t id) { return this->at(id); }
const T &operator[](size_t id) const { return this->at(id); } const T &operator[](size_t id) const { return this->at(id); }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/mixed_vector.h"
template <typename T>
using vec = paddle::framework::Vector<T>;
TEST(mixed_vector, CPU_VECTOR) {
vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10UL);
vec<int> tmp2;
tmp2 = tmp;
ASSERT_EQ(tmp2.size(), 10UL);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp2[i], i);
ASSERT_EQ(tmp2[i], tmp[i]);
}
int cnt = 0;
for (auto& t : tmp2) {
ASSERT_EQ(t, cnt);
++cnt;
}
}
TEST(mixed_vector, InitWithCount) {
paddle::framework::Vector<int> vec(10, 10);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(vec[i], 10);
}
}
TEST(mixed_vector, ForEach) {
vec<int> tmp;
for (auto& v : tmp) {
VLOG(3) << v;
}
}
TEST(mixed_vector, Reserve) {
paddle::framework::Vector<int> vec;
vec.reserve(1);
vec.push_back(0);
vec.push_back(0);
vec.push_back(0);
}
TEST(mixed_vector, Resize) {
paddle::framework::Vector<int> vec;
vec.resize(1);
vec.push_back(0);
vec.push_back(0);
vec.push_back(0);
}
...@@ -11,7 +11,9 @@ ...@@ -11,7 +11,9 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <memory>
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -21,26 +23,6 @@ ...@@ -21,26 +23,6 @@
template <typename T> template <typename T>
using vec = paddle::framework::Vector<T>; using vec = paddle::framework::Vector<T>;
TEST(mixed_vector, CPU_VECTOR) {
vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10UL);
vec<int> tmp2;
tmp2 = tmp;
ASSERT_EQ(tmp2.size(), 10UL);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp2[i], i);
ASSERT_EQ(tmp2[i], tmp[i]);
}
int cnt = 0;
for (auto& t : tmp2) {
ASSERT_EQ(t, cnt);
++cnt;
}
}
static __global__ void multiply_10(int* ptr) { static __global__ void multiply_10(int* ptr) {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
ptr[i] *= 10; ptr[i] *= 10;
...@@ -91,24 +73,3 @@ TEST(mixed_vector, MultiGPU) { ...@@ -91,24 +73,3 @@ TEST(mixed_vector, MultiGPU) {
ASSERT_EQ(tmp[i], i * 100); ASSERT_EQ(tmp[i], i * 100);
} }
} }
TEST(mixed_vector, InitWithCount) {
paddle::framework::Vector<int> vec(10, 10);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(vec[i], 10);
}
}
TEST(mixed_vector, ForEach) {
vec<int> tmp;
for (auto& v : tmp) {
}
}
TEST(mixed_vector, Reserve) {
paddle::framework::Vector<int> vec;
vec.reserve(1);
vec.push_back(0);
vec.push_back(0);
vec.push_back(0);
}
...@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor(
#endif #endif
} }
builder_ = builder_factory.Create(); builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph(new Graph(main_program)); std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
graph = builder_->Apply(std::move(graph)); graph = builder_->Apply(std::move(graph));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
set -x
cd `dirname $0`
rm -rf build/ data/
set +x
# Add TRT tests # Add TRT tests
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry) DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
...@@ -13,3 +13,6 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc ...@@ -13,3 +13,6 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL)
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Pool2dOp, IPoolingLayer in TRT. This Layer doesn't has weights.
*/
class Pool2dOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4)
<< "convert a fluid pool2d op to tensorrt pool2d layer without bias";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
std::string pool_type =
boost::get<std::string>(op_desc.GetAttr("pooling_type"));
std::vector<int> ksize =
boost::get<std::vector<int>>(op_desc.GetAttr("ksize"));
std::vector<int> strides =
boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
const nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
const nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL);
nvinfer1::PoolingType nv_pool_type = nvinfer1::PoolingType::kMAX;
if (pool_type == "max") {
nv_pool_type = nvinfer1::PoolingType::kMAX;
} else if (pool_type == "avg") {
nv_pool_type = nvinfer1::PoolingType::kAVERAGE;
} else {
PADDLE_THROW("TensorRT unsupported pooling type!");
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling,
*const_cast<nvinfer1::ITensor*>(input1),
nv_pool_type, nv_ksize);
PADDLE_ENFORCE_NOT_NULL(layer, "pool layer could not be created.");
layer->setStride(nv_strides);
layer->setPadding(nv_paddings);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) {
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(pool2d);
REGISTER_TRT_OP_CONVERTER(pool2d, Pool2dOpConverter);
...@@ -37,7 +37,7 @@ TEST(ReluOpConverter, main) { ...@@ -37,7 +37,7 @@ TEST(ReluOpConverter, main) {
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
LOG(INFO) << "execute"; LOG(INFO) << "execute";
validator.Execute(1); validator.Execute(5);
} }
} // namespace tensorrt } // namespace tensorrt
......
...@@ -24,9 +24,8 @@ TEST(fc_op, test) { ...@@ -24,9 +24,8 @@ TEST(fc_op, test) {
std::unordered_set<std::string> parameters({"mul-Y"}); std::unordered_set<std::string> parameters({"mul-Y"});
framework::Scope scope; framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000); TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("mul-X", nvinfer1::Dims4(1, 10, 1, 1)); validator.DeclInputVar("mul-X", nvinfer1::Dims3(10, 1, 1));
validator.DeclParamVar("mul-Y", nvinfer1::Dims2(10, 2)); validator.DeclParamVar("mul-Y", nvinfer1::Dims2(10, 2));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(1, 2)); validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(1, 2));
// Prepare Op description // Prepare Op description
...@@ -38,7 +37,7 @@ TEST(fc_op, test) { ...@@ -38,7 +37,7 @@ TEST(fc_op, test) {
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
validator.Execute(1); validator.Execute(10);
} }
} // namespace tensorrt } // namespace tensorrt
......
...@@ -23,7 +23,7 @@ namespace tensorrt { ...@@ -23,7 +23,7 @@ namespace tensorrt {
TEST(MulOpConverter, main) { TEST(MulOpConverter, main) {
framework::Scope scope; framework::Scope scope;
std::unordered_set<std::string> parameters; std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000); TRTConvertValidation validator(10, parameters, scope, 1000, false);
validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6)); validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10)); validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10)); validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));
...@@ -39,7 +39,7 @@ TEST(MulOpConverter, main) { ...@@ -39,7 +39,7 @@ TEST(MulOpConverter, main) {
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
LOG(INFO) << "execute"; LOG(INFO) << "execute";
validator.Execute(1); validator.Execute(2);
} }
} // namespace tensorrt } // namespace tensorrt
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <fstream>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(Pool2dOpConverter, main) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4));
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("pool2d");
desc.SetInput("X", {"pool2d-X"});
desc.SetOutput("Out", {"pool2d-Out"});
std::vector<int> ksize({2, 2});
std::vector<int> strides({2, 2});
std::vector<int> paddings({0, 0});
std::string pooling_t = "max";
desc.SetAttr("pooling_type", pooling_t);
desc.SetAttr("ksize", ksize);
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
validator.Execute(3);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(pool2d);
...@@ -63,13 +63,16 @@ class TRTConvertValidation { ...@@ -63,13 +63,16 @@ class TRTConvertValidation {
public: public:
TRTConvertValidation() = delete; TRTConvertValidation() = delete;
TRTConvertValidation(int batch_size, TRTConvertValidation(int max_batch_size,
const std::unordered_set<std::string>& parameters, const std::unordered_set<std::string>& parameters,
framework::Scope& scope, // NOLINT framework::Scope& scope, // NOLINT
int workspace_size = 1 << 10) int workspace_size = 1 << 10, bool if_add_batch = true)
: parameters_(parameters), scope_(scope) { : parameters_(parameters),
scope_(scope),
if_add_batch_(if_add_batch),
max_batch_size_(max_batch_size) {
// create engine. // create engine.
engine_.reset(new TensorRTEngine(batch_size, workspace_size, &stream_)); engine_.reset(new TensorRTEngine(max_batch_size, workspace_size, &stream_));
engine_->InitNetwork(); engine_->InitNetwork();
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0); PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
...@@ -84,7 +87,7 @@ class TRTConvertValidation { ...@@ -84,7 +87,7 @@ class TRTConvertValidation {
// Declare a parameter varaible in the scope. // Declare a parameter varaible in the scope.
void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) { void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims); DeclVar(name, dims, true);
} }
void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) { void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
...@@ -92,12 +95,18 @@ class TRTConvertValidation { ...@@ -92,12 +95,18 @@ class TRTConvertValidation {
} }
// Declare a variable in a fluid Scope. // Declare a variable in a fluid Scope.
void DeclVar(const std::string& name, const nvinfer1::Dims& dims) { void DeclVar(const std::string& name, const nvinfer1::Dims& dims,
bool is_param = false) {
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
// Init Fluid tensor. // Init Fluid tensor.
std::vector<int> dim_vec(dims.d, dims.d + dims.nbDims); std::vector<int> dim_vec(dims.d, dims.d + dims.nbDims);
// There is no batchsize in ITensor's shape, but We should add it to
// tensor's shape of fluid. If the variable is not parameter and the
// if_add_batch_ flag is true, add the max batchsize to dim_vec.
if (is_param != true && if_add_batch_ == true)
dim_vec.insert(dim_vec.begin(), max_batch_size_);
auto* x = scope_.Var(name); auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>(); auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec)); x_tensor->Resize(framework::make_ddim(dim_vec));
...@@ -131,6 +140,7 @@ class TRTConvertValidation { ...@@ -131,6 +140,7 @@ class TRTConvertValidation {
void Execute(int batch_size) { void Execute(int batch_size) {
// Execute Fluid Op // Execute Fluid Op
PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
op_->Run(scope_, place); op_->Run(scope_, place);
...@@ -149,9 +159,15 @@ class TRTConvertValidation { ...@@ -149,9 +159,15 @@ class TRTConvertValidation {
auto* var = scope_.FindVar(output); auto* var = scope_.FindVar(output);
auto tensor = var->GetMutable<framework::LoDTensor>(); auto tensor = var->GetMutable<framework::LoDTensor>();
framework::TensorToVector(*tensor, ctx, &fluid_out); framework::TensorToVector(*tensor, ctx, &fluid_out);
size_t fluid_out_size = fluid_out.size();
if (if_add_batch_ == true) {
fluid_out_size =
batch_size * (framework::product(tensor->dims()) / max_batch_size_);
}
// Compare two output // Compare two output
ASSERT_FALSE(fluid_out.empty()); ASSERT_FALSE(fluid_out.empty());
for (size_t i = 0; i < fluid_out.size(); i++) { for (size_t i = 0; i < fluid_out_size; i++) {
// Loose the threshold for CI in different machine model. // Loose the threshold for CI in different machine model.
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 2e-5); EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 2e-5);
} }
...@@ -167,6 +183,12 @@ class TRTConvertValidation { ...@@ -167,6 +183,12 @@ class TRTConvertValidation {
std::unique_ptr<framework::OpDesc> op_desc_; std::unique_ptr<framework::OpDesc> op_desc_;
const std::unordered_set<std::string>& parameters_; const std::unordered_set<std::string>& parameters_;
framework::Scope& scope_; framework::Scope& scope_;
// The ITensor of trt does not cotain the batch size,
// bug, in most cases, we need to set batch size for
// fluid's tensor shape. This variable indicates
// whether to add batch size to tensor shape of fluid.
bool if_add_batch_;
int max_batch_size_;
}; };
} // namespace tensorrt } // namespace tensorrt
......
...@@ -113,7 +113,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) { ...@@ -113,7 +113,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
ASSERT_EQ(y_cpu[1], 14.5); ASSERT_EQ(y_cpu[1], 14.5);
} }
TEST_F(TensorRTEngineTest, test_conv2d_temp) { TEST_F(TensorRTEngineTest, test_conv2d) {
// Weight in CPU memory. // Weight in CPU memory.
float raw_weight[9] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; float raw_weight[9] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
float raw_bias[1] = {0}; float raw_bias[1] = {0};
...@@ -146,6 +146,37 @@ TEST_F(TensorRTEngineTest, test_conv2d_temp) { ...@@ -146,6 +146,37 @@ TEST_F(TensorRTEngineTest, test_conv2d_temp) {
ASSERT_EQ(y_cpu[1], 6.0); ASSERT_EQ(y_cpu[1], 6.0);
} }
TEST_F(TensorRTEngineTest, test_pool2d) {
// Weight in CPU memory.
auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
nvinfer1::Dims3{1, 2, 2});
nvinfer1::PoolingType pool_t = nvinfer1::PoolingType::kAVERAGE;
auto* pool_layer =
TRT_ENGINE_ADD_LAYER(engine_, Pooling, *const_cast<nvinfer1::ITensor*>(x),
pool_t, nvinfer1::DimsHW{2, 2});
PADDLE_ENFORCE(pool_layer != nullptr);
pool_layer->setStride(nvinfer1::DimsHW{1, 1});
pool_layer->setPadding(nvinfer1::DimsHW{0, 0});
engine_->DeclareOutput(pool_layer, 0, "y");
engine_->FreezeNetwork();
ASSERT_EQ(engine_->engine()->getNbBindings(), 2);
float x_v[8] = {1.0, 2.0, 5.0, 0.0, 2.0, 3.0, 5.0, 10.0};
engine_->SetInputFromCPU("x", reinterpret_cast<void*>(&x_v),
8 * sizeof(float));
engine_->Execute(2);
LOG(INFO) << "to get output";
float* y_cpu = new float[2];
engine_->GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float));
ASSERT_EQ(y_cpu[0], 2.0);
ASSERT_EQ(y_cpu[1], 5.0);
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -20,9 +20,6 @@ limitations under the License. */ ...@@ -20,9 +20,6 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
#endif
DEFINE_string(model_path, "", "Directory of the inference model."); DEFINE_string(model_path, "", "Directory of the inference model.");
DEFINE_string(data_file, "", "File of input index data."); DEFINE_string(data_file, "", "File of input index data.");
...@@ -30,6 +27,7 @@ DEFINE_int32(repeat, 100, "Running the inference program repeat times"); ...@@ -30,6 +27,7 @@ DEFINE_int32(repeat, 100, "Running the inference program repeat times");
DEFINE_bool(prepare_vars, true, "Prepare variables before executor"); DEFINE_bool(prepare_vars, true, "Prepare variables before executor");
DEFINE_int32(num_threads, 1, "Number of threads should be used"); DEFINE_int32(num_threads, 1, "Number of threads should be used");
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
DECLARE_int32(paddle_num_threads);
inline double GetCurrentMs() { inline double GetCurrentMs() {
struct timeval time; struct timeval time;
...@@ -160,12 +158,7 @@ TEST(inference, nlp) { ...@@ -160,12 +158,7 @@ TEST(inference, nlp) {
std::unique_ptr<paddle::framework::Scope> scope( std::unique_ptr<paddle::framework::Scope> scope(
new paddle::framework::Scope()); new paddle::framework::Scope());
#ifdef PADDLE_WITH_MKLML paddle::platform::SetNumThreads(FLAGS_paddle_num_threads);
// only use 1 thread number per std::thread
omp_set_dynamic(0);
omp_set_num_threads(1);
paddle::platform::SetNumThreads(1);
#endif
double start_ms = 0, stop_ms = 0; double start_ms = 0, stop_ms = 0;
if (FLAGS_num_threads > 1) { if (FLAGS_num_threads > 1) {
......
...@@ -15,6 +15,10 @@ limitations under the License. */ ...@@ -15,6 +15,10 @@ limitations under the License. */
#include "paddle/fluid/memory/detail/buddy_allocator.h" #include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "glog/logging.h" #include "glog/logging.h"
DEFINE_bool(free_idle_memory, false,
"If it is true, Paddle will try to free idle memory trunks during "
"running time.");
namespace paddle { namespace paddle {
namespace memory { namespace memory {
namespace detail { namespace detail {
...@@ -152,13 +156,14 @@ void BuddyAllocator::Free(void* p) { ...@@ -152,13 +156,14 @@ void BuddyAllocator::Free(void* p) {
pool_.insert( pool_.insert(
IndexSizeAddress(block->index(cache_), block->total_size(cache_), block)); IndexSizeAddress(block->index(cache_), block->total_size(cache_), block));
// Clean up if existing too much free memory if (FLAGS_free_idle_memory) {
// Clean up if existing too much free memory
// Prefer freeing fallback allocation first // Prefer freeing fallback allocation first
CleanIdleFallBackAlloc(); CleanIdleFallBackAlloc();
// Free normal allocation // Free normal allocation
CleanIdleNormalAlloc(); CleanIdleNormalAlloc();
}
} }
size_t BuddyAllocator::Used() { return total_used_; } size_t BuddyAllocator::Used() { return total_used_; }
......
...@@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE) ...@@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else() else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib) set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA) if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs) find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
...@@ -270,6 +270,7 @@ op_library(cos_sim_op DEPS cos_sim_functor) ...@@ -270,6 +270,7 @@ op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op) op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
...@@ -77,7 +77,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -77,7 +77,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// cudnn 7 can support groups, no need to do it mannually // cudnn 7 can support groups, no need to do it mannually
// FIXME(typhoonzero): find a better way to disable groups // FIXME(typhoonzero): find a better way to disable groups
// rather than setting it to 1. // rather than setting it to 1.
PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount( CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
cudnn_conv_desc, groups)); cudnn_conv_desc, groups));
groups = 1; groups = 1;
#endif #endif
...@@ -129,7 +129,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -129,7 +129,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo)); workspace_size_limit, &algo));
...@@ -140,18 +140,18 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -140,18 +140,18 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
if (dev_ctx.GetComputeCapability() >= 70 && if (dev_ctx.GetComputeCapability() >= 70 &&
std::type_index(typeid(T)) == std::type_index(typeid(T)) ==
std::type_index(typeid(platform::float16))) { std::type_index(typeid(platform::float16))) {
PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
// Currently tensor core is only enabled using this algo // Currently tensor core is only enabled using this algo
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
} else { } else {
PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH)); cudnn_conv_desc, CUDNN_DEFAULT_MATH));
} }
#endif #endif
// get workspace size able to allocate // get workspace size able to allocate
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes)); cudnn_output_desc, algo, &workspace_size_in_bytes));
// It is possible for float16 on Volta GPU to allocate more memory than // It is possible for float16 on Volta GPU to allocate more memory than
...@@ -165,7 +165,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -165,7 +165,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv forward --------------------- // ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_filter_desc, filter_data + i * group_offset_filter, cudnn_filter_desc, filter_data + i * group_offset_filter,
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
...@@ -218,7 +218,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -218,7 +218,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// cudnn 7 can support groups, no need to do it mannually // cudnn 7 can support groups, no need to do it mannually
// FIXME(typhoonzero): find a better way to disable groups // FIXME(typhoonzero): find a better way to disable groups
// rather than setting it to 1. // rather than setting it to 1.
PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount( CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
cudnn_conv_desc, groups)); cudnn_conv_desc, groups));
groups = 1; groups = 1;
#endif #endif
...@@ -273,7 +273,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -273,7 +273,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
if (input_grad) { if (input_grad) {
if (FLAGS_cudnn_deterministic) { if (FLAGS_cudnn_deterministic) {
PADDLE_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, handle, cudnn_filter_desc,
// dyDesc: Handle to the previously initialized input // dyDesc: Handle to the previously initialized input
...@@ -289,7 +289,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -289,7 +289,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} }
PADDLE_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, cudnn_filter_desc, cudnn_output_grad_desc, handle, cudnn_filter_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size)); cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
...@@ -298,7 +298,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -298,7 +298,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
if (filter_grad) { if (filter_grad) {
if (FLAGS_cudnn_deterministic) { if (FLAGS_cudnn_deterministic) {
PADDLE_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_input_desc, cudnn_output_grad_desc, handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc, cudnn_conv_desc, cudnn_filter_desc,
...@@ -308,7 +308,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -308,7 +308,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} }
PADDLE_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
cudnn_filter_desc, filter_algo, &tmp_size)); cudnn_filter_desc, filter_algo, &tmp_size));
...@@ -326,7 +326,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -326,7 +326,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, handle, &alpha, cudnn_filter_desc,
filter_data + i * group_offset_filter, cudnn_output_grad_desc, filter_data + i * group_offset_filter, cudnn_output_grad_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
...@@ -339,7 +339,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -339,7 +339,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace()); T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad. // Because beta is zero, it is unnecessary to reset filter_grad.
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_output_grad_desc, output_grad_data + i * group_offset_out, cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
cudnn_conv_desc, filter_algo, cudnn_workspace, cudnn_conv_desc, filter_algo, cudnn_workspace,
......
...@@ -87,7 +87,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -87,7 +87,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
// Get the algorithm // Get the algorithm
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
// dxDesc: Handle to the previously initialized output tensor // dxDesc: Handle to the previously initialized output tensor
// descriptor. // descriptor.
...@@ -95,7 +95,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -95,7 +95,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
workspace_size_limit, &algo)); workspace_size_limit, &algo));
// get workspace size able to allocate // get workspace size able to allocate
PADDLE_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes)); cudnn_output_desc, algo, &workspace_size_in_bytes));
...@@ -110,7 +110,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -110,7 +110,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int filter_offset = filter->numel() / groups; int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f; T alpha = 1.0f, beta = 0.0f;
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta, algo, cudnn_workspace, workspace_size_in_bytes, &beta,
...@@ -178,11 +178,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -178,11 +178,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
if (input_grad) { if (input_grad) {
// choose backward algorithm for data // choose backward algorithm for data
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo)); workspace_size_limit, &data_algo));
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_input_desc, data_algo, &fwd_ws_size)); cudnn_input_desc, data_algo, &fwd_ws_size));
workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size); workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size);
...@@ -190,7 +190,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -190,7 +190,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
if (filter_grad) { if (filter_grad) {
// choose backward algorithm for filter // choose backward algorithm for filter
PADDLE_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_filter_desc, cudnn_filter_desc,
...@@ -198,7 +198,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -198,7 +198,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
workspace_size_limit, &filter_algo)); workspace_size_limit, &filter_algo));
// get workspace for backwards filter algorithm // get workspace for backwards filter algorithm
PADDLE_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_filter_desc, filter_algo, &bwd_filter_ws_size)); cudnn_filter_desc, filter_algo, &bwd_filter_ws_size));
...@@ -222,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -222,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_filter_desc, output_grad_data + output_grad_offset * g, cudnn_filter_desc,
filter_data + filter_offset * g, cudnn_conv_desc, data_algo, filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
...@@ -237,7 +237,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -237,7 +237,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset filter_grad. // Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter // Gradient with respect to the filter
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc, output_grad_data + output_grad_offset * g, cudnn_input_desc,
input_data + input_offset * g, cudnn_conv_desc, filter_algo, input_data + input_offset * g, cudnn_conv_desc, filter_algo,
......
...@@ -17,9 +17,9 @@ if(WITH_GRPC) ...@@ -17,9 +17,9 @@ if(WITH_GRPC)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(grpc_serde_test SRCS grpc_serde_test.cc cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL) DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
return() return()
endif() endif()
......
...@@ -30,7 +30,7 @@ namespace framework = paddle::framework; ...@@ -30,7 +30,7 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed; namespace distributed = paddle::operators::distributed;
USE_OP(lookup_table); USE_NO_KERNEL_OP(lookup_sparse_table);
std::unique_ptr<distributed::RPCServer> g_rpc_service; std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler; std::unique_ptr<distributed::RequestHandler> g_req_handler;
...@@ -42,13 +42,13 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { ...@@ -42,13 +42,13 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}}); framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}}); framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp(); auto op = block->AppendOp();
op->SetType("lookup_table"); op->SetType("lookup_sparse_table");
op->SetInput("W", {"w"}); op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"}); op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"}); op->SetOutput("Out", {"out"});
auto& out = *root_block->Var("out"); auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::SELECTED_ROWS); out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({10, 10}); out.SetShape({10, 10});
return block; return block;
...@@ -59,20 +59,19 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { ...@@ -59,20 +59,19 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
w_var->GetMutable<framework::SelectedRows>(); w_var->GetMutable<framework::SelectedRows>();
auto out_var = scope->Var("out"); auto out_var = scope->Var("out");
out_var->GetMutable<framework::SelectedRows>(); out_var->GetMutable<framework::LoDTensor>();
auto ids_var = scope->Var("ids"); auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::SelectedRows>(); ids_var->GetMutable<framework::LoDTensor>();
} }
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) { int64_t rows_numel) {
CreateVarsOnScope(scope, place); CreateVarsOnScope(scope, place);
auto ids_var = scope->Var("ids")->GetMutable<framework::SelectedRows>(); auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
auto rows = ids_var->mutable_rows(); int64_t* ids_ptr =
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2); ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
ids_var->mutable_value()->Resize({rows_numel, 1}); for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
ids_var->mutable_value()->mutable_data<float>(*place);
} }
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
...@@ -148,11 +147,11 @@ TEST(PREFETCH, CPU) { ...@@ -148,11 +147,11 @@ TEST(PREFETCH, CPU) {
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name); client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
client->Wait(); client->Wait();
auto var = scope.Var(out_var_name); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value(); auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value.mutable_data<float>(place); auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) { for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2)); EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
} }
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class ExtractRowsOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ExtractRowsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ExtractRowsOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X")[0],
framework::proto::VarType::SELECTED_ROWS,
"The type of input(X) must be SelectedRows.");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(
"Out", framework::make_ddim(std::vector<int64_t>{in_dims[0], 1}));
}
};
class ExtractRowsOp : public framework::OperatorBase {
public:
ExtractRowsOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &in = scope.FindVar(Input("X"))->Get<framework::SelectedRows>();
auto out = scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto in_rows = in.rows();
auto out_dim = framework::make_ddim(
std::vector<int64_t>{static_cast<int64_t>(in_rows.size()), 1});
auto dst_ptr = out->mutable_data<int64_t>(out_dim, in.place());
if (paddle::platform::is_gpu_place(in.place())) {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(in.place());
auto src_ptr = in_rows.Data(in.place());
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(*dev_ctx)
.stream();
memory::Copy(boost::get<platform::CUDAPlace>(out->place()), dst_ptr,
boost::get<platform::CUDAPlace>(in.place()), src_ptr,
in_rows.size() * sizeof(int64_t), stream);
#else
PADDLE_THROW("Not compiled with CUDA.");
#endif
} else {
memory::Copy(platform::CPUPlace(), dst_ptr, platform::CPUPlace(),
in_rows.data(), in_rows.size() * sizeof(int64_t));
}
}
};
class ExtractRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(SelectedRows). The input tensor of extract_rows operator,"
" and its type is SelectedRows.");
AddOutput("Out", "(Tensor). The the rows of input(X).");
AddComment(R"DOC(
ExtractRows Operator.
The function of extract_rows_op is extracting the rows from the input(X)
whose type is SelectedRows.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(extract_rows, ops::ExtractRowsOp, ops::ExtractRowsOpMaker,
ops::ExtractRowsOpInferShape);
...@@ -33,19 +33,15 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -33,19 +33,15 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W"); auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputDim("Ids");
auto ids_var_type = ctx->GetInputsVarType("Ids").front(); PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type PADDLE_ENFORCE_EQ(ids_dims[1], 1);
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
if (ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("Ids", /*->*/ "Out");
}
} }
protected: protected:
...@@ -62,17 +58,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -62,17 +58,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("W", AddInput("W",
"(Tensor) The input represents embedding tensors, " "(Tensor) The input represents embedding tensors, "
"which is a learnable parameter."); "which is a learnable parameter.");
AddInput( AddInput("Ids",
"Ids", "An input with type int32 or int64 "
"(Tensor or SelectedRows) Ids's type can be Tensor or " "contains the ids to be looked up in W. "
"SelectedRows, when Ids's type is Tensor, this tensor contains " "Ids must be a column vector with rank = 2. "
"the ids to be looked up in W and it must be a column vector with " "The 2nd dimension size must be 1.");
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is " AddOutput("Out", "The lookup results, which have the same type as W.");
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W.");
AddOutput("Out",
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W.");
AddAttr<bool>("is_sparse", AddAttr<bool>("is_sparse",
"(boolean, default false) " "(boolean, default false) "
"Sparse update.") "Sparse update.")
...@@ -90,15 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -90,15 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
Lookup Table Operator. Lookup Table Operator.
This operator is used to perform lookups on the parameter W, This operator is used to perform lookups on the parameter W,
then concatenated into a dense or sparse tensor. then concatenated into a dense tensor.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC"); )DOC");
} }
......
...@@ -23,7 +23,7 @@ namespace operators { ...@@ -23,7 +23,7 @@ namespace operators {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX, template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
bool PaddingFlag> bool PaddingFlag>
__global__ void LookupTable(T* output, const T* table, const int64_t* ids, __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
const int64_t N, const int64_t K, const int64_t D, const int64_t N, const int64_t K, const int64_t D,
const int64_t padding_idx) { const int64_t padding_idx) {
int idx = threadIdx.x; int idx = threadIdx.x;
...@@ -33,8 +33,8 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids, ...@@ -33,8 +33,8 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
int64_t id = ids[idy]; int64_t id = ids[idy];
PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N); PADDLE_ASSERT(id < N);
T* out = output + idy * D; T *out = output + idy * D;
const T* tab = table + id * D; const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) { for (int i = idx; i < D; i += BlockDimX) {
if (PaddingFlag) { if (PaddingFlag) {
if (id == padding_idx) if (id == padding_idx)
...@@ -50,7 +50,7 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids, ...@@ -50,7 +50,7 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
} }
template <typename T, int BlockDimX, int BlockDimY, int GridDimX> template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids, __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
const int64_t N, const int64_t K, const int64_t N, const int64_t K,
const int64_t D) { const int64_t D) {
int idx = threadIdx.x; int idx = threadIdx.x;
...@@ -60,8 +60,8 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids, ...@@ -60,8 +60,8 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
int id = ids[idy]; int id = ids[idy];
PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N); PADDLE_ASSERT(id < N);
const T* out = output + idy * D; const T *out = output + idy * D;
T* tab = table + id * D; T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) { for (int i = idx; i < D; i += BlockDimX) {
paddle::platform::CudaAtomicAdd(&tab[i], out[i]); paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
} }
...@@ -72,36 +72,19 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids, ...@@ -72,36 +72,19 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
template <typename T> template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> { class LookupTableCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* table_t = context.Input<LoDTensor>("W"); auto *table_t = context.Input<LoDTensor>("W");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx"); int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto* ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out");
int64_t* ids;
int64_t K;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<framework::LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
K = ids_t->numel();
} else if (ids_var->IsType<framework::SelectedRows>()) {
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
K = ids_t->rows().size();
output_t->Resize({K, table_t->dims()[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
auto* table = table_t->data<T>(); size_t K = ids_t->numel();
auto* output = output_t->mutable_data<T>(context.GetPlace());
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(8, 1); dim3 grids(8, 1);
...@@ -122,19 +105,19 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> { ...@@ -122,19 +105,19 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
template <typename T> template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> { class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto& dev_ctx = auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>(); context.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of // Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward. // paddings makes no sense and we don't deal with it in backward.
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W"); auto *table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>(); auto *ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims(); auto ids_dim = ids->dims();
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
...@@ -150,12 +133,12 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -150,12 +133,12 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
d_table->set_rows(new_rows); d_table->set_rows(new_rows);
auto* d_table_value = d_table->mutable_value(); auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_dim[0], table->dims()[1]}); d_table_value->Resize({ids_dim[0], table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace()); d_table_value->mutable_data<T>(context.GetPlace());
auto* d_table_data = d_table_value->data<T>(); auto *d_table_data = d_table_value->data<T>();
auto* d_output_data = d_output->data<T>(); auto *d_output_data = d_output->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims()); PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data, memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel() * sizeof(T), stream); d_output->numel() * sizeof(T), stream);
...@@ -168,9 +151,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -168,9 +151,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
int N = d_table_t->dims()[0]; int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1]; int D = d_table_t->dims()[1];
int K = ids_t->numel(); int K = ids_t->numel();
const int64_t* ids = ids_t->data<int64_t>(); const int64_t *ids = ids_t->data<int64_t>();
const T* d_output = d_output_t->data<T>(); const T *d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace()); T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t); auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0)); t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
......
...@@ -36,43 +36,13 @@ template <typename T> ...@@ -36,43 +36,13 @@ template <typename T>
class LookupTableKernel : public framework::OpKernel<T> { class LookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W"); auto *table_var = context.InputVar("W");
auto *ids_var = context.InputVar("Ids");
Tensor *output_t = context.Output<Tensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
DDim table_dim;
if (table_var->IsType<LoDTensor>()) { int64_t padding_idx = context.Attr<int64_t>("padding_idx");
table_dim = context.Input<LoDTensor>("W")->dims(); int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
} else if (table_var->IsType<SelectedRows>()) { int64_t ids_numel = ids_t->numel();
auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows");
}
int64_t *ids;
int64_t ids_numel;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) {
auto *ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t *>(ids_t->data<int64_t>());
ids_numel = ids_t->numel();
} else if (ids_var->IsType<SelectedRows>()) {
auto *ids_t = context.Input<SelectedRows>("Ids");
ids = const_cast<int64_t *>(ids_t->rows().data());
ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_dim[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
if (table_var->IsType<LoDTensor>()) { if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W"); auto *table_t = context.Input<LoDTensor>("W");
......
...@@ -40,22 +40,47 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -40,22 +40,47 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int im_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col->dims()[1]; int filter_height = col->dims()[1];
int filter_width = col->dims()[2]; int filter_width = col->dims()[2];
int col_height = col->dims()[3]; int output_height = col->dims()[3];
int col_width = col->dims()[4]; int output_width = col->dims()[4];
int channels_col = im_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col->data<T>(); T* col_data = col->data<T>();
// TODO(TJ): change me to template
// further optimaze:
// 1. padding != 1
// 2. could also support stride_h != 1
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) {
int col_matrix_width = output_width * output_height;
size_t copy_size = sizeof(T) * output_width;
for (int oh = 0; oh < output_height; ++oh) {
const T* im_data_start = im_data + oh * im_width;
T* dst_data = col_data + oh * output_width;
for (int ic = 0; ic < im_channels; ++ic) {
const T* src_data = im_data_start + ic * im_height * im_width;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
std::memcpy(dst_data, src_data + kw, copy_size);
dst_data = dst_data + col_matrix_width;
}
src_data = src_data + im_width;
}
}
}
return;
}
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height); int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) { for (int h = 0; h < output_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) { for (int w = 0; w < output_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w; int col_idx = (c * output_height + h) * output_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
......
...@@ -160,8 +160,80 @@ void testIm2col() { ...@@ -160,8 +160,80 @@ void testIm2col() {
delete context; delete context;
} }
void testIm2colCPU(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
paddle::framework::Tensor input;
paddle::framework::Tensor output;
paddle::framework::Tensor ref_output;
std::vector<int> padding({ph, pw});
std::vector<int> stride({1, 1}); // stride_y, stride_x
std::vector<int> dilation({1, 1}); // dilation_y, dilation_x
int output_height = (ih - fh + padding[0] * 2) / stride[0] + 1;
int output_width = (iw - fw + padding[1] * 2) / stride[1] + 1;
float* input_ptr =
input.mutable_data<float>({ic, ih, iw}, paddle::platform::CPUPlace());
for (int i = 0; i < input.numel(); ++i) {
input_ptr[i] = static_cast<float>(i + 1);
}
paddle::platform::CPUPlace place;
paddle::platform::CPUDeviceContext context(place);
output.mutable_data<float>({ic, fh, fw, output_height, output_width}, place);
ref_output.mutable_data<float>({ic, fh, fw, output_height, output_width},
place);
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO,
paddle::platform::CPUDeviceContext, float>
im2col;
im2col(context, input, dilation, stride, padding, &output);
auto ref_im2col = [&](
const paddle::framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride, const std::vector<int>& padding,
paddle::framework::Tensor* col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
int output_width = col->dims()[4];
int channels_col = im_channels * filter_height * filter_width;
const float* im_data = im.data<float>();
float* col_data = col->data<float>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < output_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < output_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * output_height + h) * output_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? 0.f
: im_data[im_idx];
}
}
}
};
ref_im2col(input, dilation, stride, padding, &ref_output);
float* out_cfo_ptr = output.data<float>();
float* out_ref_ptr = ref_output.data<float>();
for (int i = 0; i < output.numel(); ++i) {
EXPECT_EQ(out_cfo_ptr[i], out_ref_ptr[i]);
}
}
TEST(math, im2col) { TEST(math, im2col) {
testIm2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>(); testIm2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
testIm2colCPU(/*ic*/ 3, /*ih*/ 5, /*iw*/ 5, /*fh*/ 3, /*fw*/ 2, /*ph*/ 0,
/*pw*/ 0);
testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 3, /*fw*/ 3, /*ph*/ 1,
/*pw*/ 1);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
testIm2col<paddle::platform::CUDADeviceContext, testIm2col<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>(); paddle::platform::CUDAPlace>();
......
...@@ -52,7 +52,7 @@ void SoftmaxCUDNNFunctor<T>::operator()( ...@@ -52,7 +52,7 @@ void SoftmaxCUDNNFunctor<T>::operator()(
xDesc.descriptor<T>(layout, cudnn_tensor_dims); xDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_y_desc = cudnnTensorDescriptor_t cudnn_y_desc =
xDesc.descriptor<T>(layout, cudnn_tensor_dims); xDesc.descriptor<T>(layout, cudnn_tensor_dims);
PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxForward( CUDNN_ENFORCE(platform::dynload::cudnnSoftmaxForward(
context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE, context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_x_desc, CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_x_desc,
X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc, X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
...@@ -83,7 +83,7 @@ void SoftmaxGradCUDNNFunctor<T>::operator()( ...@@ -83,7 +83,7 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
dxDesc.descriptor<T>(layout, cudnn_tensor_dims); dxDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_ygrad_desc = cudnnTensorDescriptor_t cudnn_ygrad_desc =
dyDesc.descriptor<T>(layout, cudnn_tensor_dims); dyDesc.descriptor<T>(layout, cudnn_tensor_dims);
PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxBackward( CUDNN_ENFORCE(platform::dynload::cudnnSoftmaxBackward(
context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE, context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_y_desc, CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_y_desc,
Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(), Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
......
...@@ -81,7 +81,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> { ...@@ -81,7 +81,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn pool algorithm --------------------- // ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle(); auto handle = ctx.cuda_device_context().cudnn_handle();
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward( CUDNN_ENFORCE(platform::dynload::cudnnPoolingForward(
handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta, handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta,
cudnn_output_desc, output_data)); cudnn_output_desc, output_data));
} }
...@@ -154,7 +154,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> { ...@@ -154,7 +154,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward( CUDNN_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data, handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data, cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
&beta, cudnn_input_desc, input_grad_data)); &beta, cudnn_input_desc, input_grad_data));
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/ir/node.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -22,7 +23,10 @@ inline bool NeedSend(const framework::Scope& scope, ...@@ -22,7 +23,10 @@ inline bool NeedSend(const framework::Scope& scope,
const std::string& varname) { const std::string& varname) {
// dummy variable is only used in parallel executor to represent // dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it. // some dependency relationship, we don't need to send/recv it.
if (varname == "dummy") return false; // TODO(paddle-dev): Why would parallel executor logic leaked into here?
if (varname.find(framework::ir::Node::kControlDepVarName) !=
std::string::npos)
return false;
auto* var = scope.FindVar(varname); auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname); varname);
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#include <omp.h>
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
...@@ -33,6 +34,7 @@ void SetNumThreads(int num_threads) { ...@@ -33,6 +34,7 @@ void SetNumThreads(int num_threads) {
#elif defined(PADDLE_WITH_MKLML) #elif defined(PADDLE_WITH_MKLML)
int real_num_threads = num_threads > 1 ? num_threads : 1; int real_num_threads = num_threads > 1 ? num_threads : 1;
platform::dynload::MKL_Set_Num_Threads(real_num_threads); platform::dynload::MKL_Set_Num_Threads(real_num_threads);
omp_set_num_threads(num_threads);
#else #else
PADDLE_ENFORCE(false, "To be implemented."); PADDLE_ENFORCE(false, "To be implemented.");
#endif #endif
......
...@@ -59,13 +59,12 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { ...@@ -59,13 +59,12 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
#define CUDNN_VERSION_MIN(major, minor, patch) \ #define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
#define CUDNN_ENFORCE(condition) \ #define CUDNN_ENFORCE(condition) \
do { \ do { \
cudnnStatus_t status = condition; \ cudnnStatus_t status = condition; \
if (status != CUDNN_STATUS_SUCCESS) { \ if (UNLIKELY(status != CUDNN_STATUS_SUCCESS)) { \
VLOG(1) << ::paddle::platform::cudnnGetErrorString(status); \ PADDLE_THROW(::paddle::platform::cudnnGetErrorString(status)); \
PADDLE_THROW("cuDNN call failed"); \ } \
} \
} while (false) } while (false)
enum class DataLayout { // Not use enum class DataLayout { // Not use
......
...@@ -23,6 +23,9 @@ limitations under the License. */ ...@@ -23,6 +23,9 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/piece.h"
DEFINE_int32(paddle_num_threads, 1,
"Number of threads for each paddle instance.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -115,7 +118,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -115,7 +118,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
places.emplace_back(platform::CPUPlace()); places.emplace_back(platform::CPUPlace());
platform::DeviceContextPool::Init(places); platform::DeviceContextPool::Init(places);
#ifndef PADDLE_WITH_MKLDNN #ifndef PADDLE_WITH_MKLDNN
platform::SetNumThreads(1); platform::SetNumThreads(FLAGS_paddle_num_threads);
#endif #endif
} }
......
...@@ -547,6 +547,7 @@ function test_fluid_inference_lib() { ...@@ -547,6 +547,7 @@ function test_fluid_inference_lib() {
EOF EOF
cd ${PADDLE_ROOT}/paddle/fluid/inference/api/demo_ci cd ${PADDLE_ROOT}/paddle/fluid/inference/api/demo_ci
./run.sh ${PADDLE_ROOT} ${WITH_MKL:-ON} ${WITH_GPU:-OFF} ./run.sh ${PADDLE_ROOT} ${WITH_MKL:-ON} ${WITH_GPU:-OFF}
./clean.sh
fi fi
} }
......
...@@ -62,33 +62,33 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable ...@@ -62,33 +62,33 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable
Tensor = LoDTensor Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \ trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + [ parallel_executor.__all__ + lod_tensor.__all__ + [
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',
'contrib', 'contrib',
'transpiler', 'transpiler',
'nets', 'nets',
'optimizer', 'optimizer',
'learning_rate_decay', 'learning_rate_decay',
'backward', 'backward',
'regularizer', 'regularizer',
'LoDTensor', 'LoDTensor',
'LoDTensorArray', 'LoDTensorArray',
'CPUPlace', 'CPUPlace',
'CUDAPlace', 'CUDAPlace',
'CUDAPinnedPlace', 'CUDAPinnedPlace',
'Tensor', 'Tensor',
'ParamAttr', 'ParamAttr',
'WeightNormParamAttr', 'WeightNormParamAttr',
'DataFeeder', 'DataFeeder',
'clip', 'clip',
'profiler', 'profiler',
'unique_name', 'unique_name',
'recordio_writer', 'recordio_writer',
'Scope', 'Scope',
] ]
def __bootstrap__(): def __bootstrap__():
...@@ -123,7 +123,7 @@ def __bootstrap__(): ...@@ -123,7 +123,7 @@ def __bootstrap__():
read_env_flags = [ read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem' 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads'
] ]
if core.is_compiled_with_dist(): if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline') read_env_flags.append('rpc_deadline')
......
...@@ -1540,7 +1540,12 @@ class Program(object): ...@@ -1540,7 +1540,12 @@ class Program(object):
def inference_optimize(self): def inference_optimize(self):
""" """
This method will create a new program and change the :code:`is_test` This method will create a new program and do following adjustments on it:
1. Remove all reader variables and their creator ops if exist.
2. Remove the :code:`read_op` if exists.
3. change the :code:`is_test`
attribute of operators to :code:`True`. All the :code:`Parameter` attribute of operators to :code:`True`. All the :code:`Parameter`
information will be lost. information will be lost.
...@@ -1554,6 +1559,22 @@ class Program(object): ...@@ -1554,6 +1559,22 @@ class Program(object):
# core.inference_optimize being fixed. # core.inference_optimize being fixed.
res = Program() res = Program()
res.desc = core.ProgramDesc(self.desc) res.desc = core.ProgramDesc(self.desc)
# remove all readers and the read_op if exist
read_op_idx = 0
root_block = res.desc.block(0)
while True:
if read_op_idx >= root_block.op_size() or root_block.op(
read_op_idx).type() == 'read':
break
read_op_idx += 1
if read_op_idx < root_block.op_size():
root_block._remove_op(0, read_op_idx + 1)
for var in root_block.all_vars():
if var.type() == core.VarDesc.VarType.READER:
root_block._remove_var(var.name())
# change all `is_test` attributes to True
for i in xrange(res.desc.num_blocks()): for i in xrange(res.desc.num_blocks()):
block = res.desc.block(i) block = res.desc.block(i)
for j in xrange(block.op_size()): for j in xrange(block.op_size()):
......
...@@ -790,101 +790,3 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -790,101 +790,3 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program() program = default_main_program()
var = program.global_block().var(name) var = program.global_block().var(name)
return get_parameter_value(var, executor) return get_parameter_value(var, executor)
def get_test_program(filelist, program=None, startup_program=None):
"""
Transpile current train program to a program to read test dataset
if the program is using reader ops like "open_files_op".
"""
def _copy_reader_var_(block, var, new_name=None):
if new_name == None:
new_name = var.name
new_var = block.create_var(
name=str(new_name), type=core.VarDesc.VarType.READER)
new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True
return new_var
def _get_test_reader_name(train_reader_name):
return train_reader_name + "_test"
def _is_reader_op(op):
block = op.block
if "Out" in op.output_names:
reader_out = block.vars[op.output("Out")[0]]
if reader_out.type == core.VarDesc.VarType.READER:
return True
return False
if program == None:
program = default_main_program()
if startup_program == None:
startup_program = default_startup_program()
startup_block = startup_program.global_block()
# 1. find out the orignal reader var name
startup_reader_op_list = []
for op in startup_block.ops:
if _is_reader_op(op):
startup_reader_op_list.append(op)
if len(startup_reader_op_list) == 0:
return program
root_reader_op = startup_reader_op_list[0]
train_test_reader_map = {}
# 2. add operators to startup to read open and read test data files
for op in startup_reader_op_list:
assert (len(op.output("Out")) == 1)
train_reader_name = op.output("Out")[0]
train_reader = startup_block.vars[train_reader_name]
test_reader = _copy_reader_var_(
startup_block,
train_reader,
new_name=_get_test_reader_name(train_reader_name))
train_test_reader_map[train_reader.name] = test_reader
test_op_inputs = {}
for name in op.input_names:
train_arg_names = op.input(name)
test_arg_vars = []
for arg_name in train_arg_names:
arg_var = train_test_reader_map[
arg_name] if name == "UnderlyingReader" else startup_block.vars[
arg_name]
test_arg_vars.append(arg_var)
test_op_inputs[name] = test_arg_vars
test_op = startup_block.append_op(
type=op.type,
inputs=test_op_inputs,
outputs={'Out': [test_reader]},
attrs=op.attrs)
# root reader op's filelist attr for read test files
if op.type == root_reader_op.type:
test_op.set_attr("file_names", filelist)
if op.type == "create_multi_pass_reader":
test_op.set_attr("pass_num", 1)
# 3. rename reader vars in inference program to different name
# to avoid read from train data.
main_block = program.global_block()
for var in main_block.vars.values():
if var.type == core.VarDesc.VarType.READER:
main_block._rename_var(
str(var.name), str(_get_test_reader_name(var.name)))
for op in main_block.ops:
if op.type == root_reader_op.type:
test_op.set_attr("file_names", filelist)
if op.type == "create_multi_pass_reader":
test_op.set_attr("pass_num", 1)
startup_program._sync_with_cpp()
program._sync_with_cpp()
return program
...@@ -443,9 +443,6 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True): ...@@ -443,9 +443,6 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
......
...@@ -142,14 +142,20 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -142,14 +142,20 @@ class L2DecayRegularizer(WeightDecayRegularizer):
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype="float32", shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS: if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
idx = block.create_var(
dtype="int64",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
decay = block.create_var( decay = block.create_var(
dtype="float32", dtype="float32",
shape=param.shape, shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS) type=core.VarDesc.VarType.SELECTED_ROWS)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op( block.append_op(
type='lookup_table', type='lookup_table',
inputs={'W': param, inputs={'W': param,
'Ids': grad}, 'Ids': idx},
outputs={'Out': decay}, outputs={'Out': decay},
attrs={'is_sparse': True}) attrs={'is_sparse': True})
param = decay param = decay
...@@ -216,14 +222,20 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -216,14 +222,20 @@ class L1DecayRegularizer(WeightDecayRegularizer):
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype="float32", shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS: if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
idx = block.create_var(
dtype="int64",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
decay = block.create_var( decay = block.create_var(
dtype="float32", dtype="float32",
shape=param.shape, shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS) type=core.VarDesc.VarType.SELECTED_ROWS)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op( block.append_op(
type='lookup_table', type='lookup_table',
inputs={'W': param, inputs={'W': param,
'Ids': grad}, 'Ids': idx},
outputs={'Out': decay}, outputs={'Out': decay},
attrs={'is_sparse': True}) attrs={'is_sparse': True})
......
...@@ -35,7 +35,7 @@ if len(sys.argv) == 1: ...@@ -35,7 +35,7 @@ if len(sys.argv) == 1:
word_dict = paddle.dataset.imdb.word_dict() word_dict = paddle.dataset.imdb.word_dict()
else: else:
word_dict = load_vocab(sys.argv[1]) word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict) word_dict["<unk>"] = len(word_dict)
print "Dict dim = ", len(word_dict) print "Dict dim = ", len(word_dict)
# input text data # input text data
...@@ -50,7 +50,7 @@ feeder = fluid.DataFeeder(feed_list=[data, label], place=fluid.CPUPlace()) ...@@ -50,7 +50,7 @@ feeder = fluid.DataFeeder(feed_list=[data, label], place=fluid.CPUPlace())
BATCH_SIZE = 128 BATCH_SIZE = 128
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=10000), paddle.dataset.imdb.train(word_dict), buf_size=25000),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
test_reader = paddle.batch( test_reader = paddle.batch(
......
...@@ -19,7 +19,7 @@ import sys ...@@ -19,7 +19,7 @@ import sys
TRAIN_FILES = ['train.recordio'] TRAIN_FILES = ['train.recordio']
TEST_FILES = ['test.recordio'] TEST_FILES = ['test.recordio']
DICT_DIM = 89528 DICT_DIM = 5147
# embedding dim # embedding dim
emb_dim = 128 emb_dim = 128
...@@ -27,58 +27,46 @@ emb_dim = 128 ...@@ -27,58 +27,46 @@ emb_dim = 128
# hidden dim # hidden dim
hid_dim = 128 hid_dim = 128
# hidden dim2
hid_dim2 = 96
# class num # class num
class_dim = 2 class_dim = 2
# epoch num
epoch_num = 10
def network_cfg(is_train, pass_num=100):
with fluid.unique_name.guard():
train_file_obj = fluid.layers.open_files(
filenames=TRAIN_FILES,
pass_num=pass_num,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'])
test_file_obj = fluid.layers.open_files(
filenames=TEST_FILES,
pass_num=1,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'])
if is_train: def build_program(is_train):
file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000) file_obj_handle = fluid.layers.io.open_files(
else: filenames=TRAIN_FILES if is_train else TEST_FILES,
file_obj = test_file_obj shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'])
file_obj = fluid.layers.double_buffer( file_obj = fluid.layers.io.double_buffer(file_obj_handle)
file_obj,
name="train_double_buffer" if is_train else 'test_double_buffer') with fluid.unique_name.guard():
data, label = fluid.layers.read_file(file_obj) data, label = fluid.layers.read_file(file_obj)
emb = fluid.layers.embedding(input=data, size=[DICT_DIM, emb_dim]) emb = fluid.layers.embedding(input=data, size=[DICT_DIM, emb_dim])
# sequence conv with window size = 3
win_size = 3
conv_3 = fluid.nets.sequence_conv_pool( conv_3 = fluid.nets.sequence_conv_pool(
input=emb, input=emb,
num_filters=hid_dim, num_filters=hid_dim,
filter_size=win_size, filter_size=3,
act="tanh", act="tanh",
pool_type="max") pool_type="sqrt")
# fc layer after conv conv_4 = fluid.nets.sequence_conv_pool(
fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2) input=emb,
num_filters=hid_dim,
filter_size=4,
act="tanh",
pool_type="sqrt")
# probability of each class prediction = fluid.layers.fc(input=[conv_3, conv_4],
prediction = fluid.layers.fc(input=[fc_1],
size=class_dim, size=class_dim,
act="softmax") act="softmax")
# cross entropy loss # cross entropy loss
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
...@@ -88,58 +76,62 @@ def network_cfg(is_train, pass_num=100): ...@@ -88,58 +76,62 @@ def network_cfg(is_train, pass_num=100):
if is_train: if is_train:
# SGD optimizer # SGD optimizer
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.01) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return { return {'loss': avg_cost, 'log': [avg_cost, acc], 'file': file_obj_handle}
'loss': avg_cost,
'log': [avg_cost, acc],
'file': train_file_obj if is_train else test_file_obj
}
def main(): def main():
train = fluid.Program() train = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
test = fluid.Program()
with fluid.program_guard(train, startup): with fluid.program_guard(train, startup):
train_args = network_cfg(is_train=True) train_args = build_program(is_train=True)
test = fluid.Program()
with fluid.program_guard(test, fluid.Program()): with fluid.program_guard(test, startup):
test_args = network_cfg(is_train=False) test_args = build_program(is_train=False)
use_cuda = fluid.core.is_compiled_with_cuda()
# startup # startup
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
exe.run(startup) exe.run(startup)
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=train_args['loss'].name, main_program=train) use_cuda=use_cuda,
loss_name=train_args['loss'].name,
main_program=train)
test_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, main_program=test, share_vars_from=train_exe)
fetch_var_list = [var.name for var in train_args['log']] fetch_var_list = [var.name for var in train_args['log']]
for i in xrange(sys.maxint): for epoch_id in range(epoch_num):
result = map(numpy.array, # train
train_exe.run(fetch_list=fetch_var_list try:
if i % 1000 == 0 else [])) batch_id = 0
if len(result) != 0: while True:
print 'Train: ', result loss, acc = map(numpy.array,
train_exe.run(fetch_list=fetch_var_list))
if i % 1000 == 0: print 'Train epoch', epoch_id, 'batch', batch_id, 'loss:', loss, 'acc:', acc
test_exe = fluid.ParallelExecutor( batch_id += 1
use_cuda=True, main_program=test, share_vars_from=train_exe) except fluid.core.EOFException:
loss = [] print 'End of epoch', epoch_id
acc = [] train_args['file'].reset()
try:
while True: # test
loss_np, acc_np = map( loss = []
numpy.array, test_exe.run(fetch_list=fetch_var_list)) acc = []
loss.append(loss_np[0]) try:
acc.append(acc_np[0]) while True:
except: loss_np, acc_np = map(numpy.array,
test_args['file'].reset() test_exe.run(fetch_list=fetch_var_list))
print 'TEST: ', numpy.mean(loss), numpy.mean(acc) loss.append(loss_np[0])
acc.append(acc_np[0])
except:
test_args['file'].reset()
print 'Test loss:', numpy.mean(loss), 'acc:', numpy.mean(acc)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -278,7 +278,7 @@ class DistSeResneXt2x2: ...@@ -278,7 +278,7 @@ class DistSeResneXt2x2:
def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True): def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model( test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model(
batch_size=20) batch_size=2)
if is_dist: if is_dist:
t = get_transpiler(trainer_id, t = get_transpiler(trainer_id,
fluid.default_main_program(), endpoints, fluid.default_main_program(), endpoints,
...@@ -294,11 +294,7 @@ class DistSeResneXt2x2: ...@@ -294,11 +294,7 @@ class DistSeResneXt2x2:
strategy.num_threads = 1 strategy.num_threads = 1
strategy.allow_op_delay = False strategy.allow_op_delay = False
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
True, True, loss_name=avg_cost.name, exec_strategy=strategy)
loss_name=avg_cost.name,
exec_strategy=strategy,
num_trainers=trainers,
trainer_id=trainer_id)
feed_var_list = [ feed_var_list = [
var for var in trainer_prog.global_block().vars.itervalues() var for var in trainer_prog.global_block().vars.itervalues()
......
...@@ -56,7 +56,7 @@ class TestDistSeResneXt2x2(unittest.TestCase): ...@@ -56,7 +56,7 @@ class TestDistSeResneXt2x2(unittest.TestCase):
except os.error: except os.error:
retry_times -= 1 retry_times -= 1
def non_test_with_place(self): def test_with_place(self):
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN # *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
required_envs = { required_envs = {
"PATH": os.getenv("PATH"), "PATH": os.getenv("PATH"),
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
class TestExtractRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Variable
feature_len = 12
rows = [0, 4, 4, 7]
np_array = np.ones((len(rows), feature_len)).astype("float32")
in_x = scope.var('X').get_selected_rows()
in_x.set_height(len(rows))
in_x.set_rows(rows)
in_x_tensor = in_x.get_tensor()
in_x_tensor.set(np_array, place)
# create Out Variable
out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator
extract_rows_op = Operator("extract_rows", X='X', Out='Out')
extract_rows_op.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
result_array = [ele[0] for ele in result_array]
assert result_array == rows
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == '__main__':
unittest.main()
...@@ -49,53 +49,6 @@ class TestLookupTableOpWithPadding(TestLookupTableOp): ...@@ -49,53 +49,6 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
pass pass
class TestLookupTableIdsIsSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Variable
height = 10
rows = [0, 4, 4, 7]
row_numel = 12
# create and initialize W Variable
W = scope.var('W').get_tensor()
W_array = np.full((height, row_numel), 1.0).astype("float32")
for i in range(height):
W_array[i] *= i
W.set(W_array, place)
# create and initialize Ids Variable
ids_selected_rows = scope.var('Ids').get_selected_rows()
ids_selected_rows.set_height(len(rows))
ids_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
ids_tensor = ids_selected_rows.get_tensor()
ids_tensor.set(np_array, place)
# create Out Variable
Out = scope.var('Out').get_selected_rows()
# create and run lookup_table operator
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
concat_rows_op.run(scope, place)
# get result from Out
Out_tensor = Out.get_tensor()
result_array = np.array(Out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(rows):
assert (row == result_array[idx]).all()
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
class TestLookupTableWIsSelectedRows(OpTest): class TestLookupTableWIsSelectedRows(OpTest):
def check_with_place(self, place): def check_with_place(self, place):
scope = core.Scope() scope = core.Scope()
......
...@@ -107,44 +107,24 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -107,44 +107,24 @@ class TestMNIST(TestParallelExecutorBase):
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
return img, label return img, label
# simple_fc def _compare_reduce_and_allreduce(self, model, use_cuda, random_data=True):
def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True) model, use_cuda=use_cuda, use_reduce=True)
img, label = self._init_data()
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, model, use_cuda=use_cuda, allow_op_delay=True, use_reduce=True)
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=use_reduce)
def check_simple_fc_convergence_with_Reduce(self, use_cuda): img, label = self._init_data(random_data)
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, use_reduce=True)
self.check_network_convergence(
simple_fc_net,
use_cuda=use_cuda,
allow_op_delay=True,
use_reduce=True)
img, label = self._init_data()
all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence( all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
simple_fc_net, model,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda, use_cuda=use_cuda,
use_reduce=False) use_reduce=False)
reduce_first_loss, reduce_last_loss = self.check_network_convergence( reduce_first_loss, reduce_last_loss = self.check_network_convergence(
simple_fc_net, model,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda, use_cuda=use_cuda,
...@@ -153,7 +133,24 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -153,7 +133,24 @@ class TestMNIST(TestParallelExecutorBase):
for loss in zip(all_reduce_first_loss, reduce_first_loss): for loss in zip(all_reduce_first_loss, reduce_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(all_reduce_last_loss, reduce_last_loss): for loss in zip(all_reduce_last_loss, reduce_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) self.assertAlmostEquals(loss[0], loss[1], delta=1e-4)
# simple_fc
def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
img, label = self._init_data()
self.check_network_convergence(
simple_fc_net,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=use_reduce)
def test_simple_fc(self): def test_simple_fc(self):
# use_cuda # use_cuda
...@@ -162,8 +159,8 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -162,8 +159,8 @@ class TestMNIST(TestParallelExecutorBase):
def test_simple_fc_with_new_strategy(self): def test_simple_fc_with_new_strategy(self):
# use_cuda, use_reduce # use_cuda, use_reduce
self.check_simple_fc_convergence_with_Reduce(True) self._compare_reduce_and_allreduce(simple_fc_net, True)
self.check_simple_fc_convergence_with_Reduce(False) self._compare_reduce_and_allreduce(simple_fc_net, False)
def check_simple_fc_parallel_accuracy(self, use_cuda): def check_simple_fc_parallel_accuracy(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
...@@ -209,39 +206,13 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -209,39 +206,13 @@ class TestMNIST(TestParallelExecutorBase):
"label": label}, "label": label},
use_cuda=use_cuda) use_cuda=use_cuda)
def check_batchnorm_fc_convergence_use_reduce(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(
fc_with_batchnorm, use_cuda=use_cuda, use_reduce=True)
img, label = self._init_data()
all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=False)
reduce_first_loss, reduce_last_loss = self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=True)
for loss in zip(all_reduce_first_loss, reduce_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(all_reduce_last_loss, reduce_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-4)
def test_batchnorm_fc(self): def test_batchnorm_fc(self):
self.check_batchnorm_fc_convergence(True) self.check_batchnorm_fc_convergence(True)
self.check_batchnorm_fc_convergence(False) self.check_batchnorm_fc_convergence(False)
def test_batchnorm_fc_with_new_strategy(self): def test_batchnorm_fc_with_new_strategy(self):
self.check_batchnorm_fc_convergence_use_reduce(True) self._compare_reduce_and_allreduce(fc_with_batchnorm, True)
self.check_batchnorm_fc_convergence_use_reduce(False) self._compare_reduce_and_allreduce(fc_with_batchnorm, False)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -779,7 +779,9 @@ class DistributeTranspiler(object): ...@@ -779,7 +779,9 @@ class DistributeTranspiler(object):
outputs={"Out": prefetch_output_vars}, outputs={"Out": prefetch_output_vars},
attrs={ attrs={
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE # FIXME(qiao) temporarily disable this config because prefetch
# is not act as other rpc op, it's more like a forward op
# RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
# insert concat_op # insert concat_op
...@@ -887,7 +889,8 @@ class DistributeTranspiler(object): ...@@ -887,7 +889,8 @@ class DistributeTranspiler(object):
# create table optimize block in pserver program # create table optimize block in pserver program
table_opt_op = [ table_opt_op = [
op for op in self.optimize_ops op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name if 'Param' in op.input_names and op.input("Param")[0] ==
self.table_name
][0] ][0]
table_opt_block = pserver_program.create_block(pre_block_idx) table_opt_block = pserver_program.create_block(pre_block_idx)
# only support sgd now # only support sgd now
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册