Unverified commit 4f859408, authored by Zeng Jinle, committed by GitHub

Enhance inplace/mem-opt pass and enhance softmax_with_cross_entropy op inplace (#17225)

* add use_cuda to inplace pass,test=develop

* add test softmax_with_xe_inplace test,test=develop

* fix potential inplace bug
test=develop

* add more skip vars in mem opt pass,test=develop

* follow comment,test=develop

* follow comments,move duplicate out arg check to program->graph,test=develop
Parent 8b62f537
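For orientation, the sketch below is not part of the commit; it is a minimal, hypothetical example (assuming the paddle.fluid 1.x API of this era, with illustrative names and shapes) of the Python-side path these changes target: enabling the inplace pass through BuildStrategy so that softmax_with_cross_entropy may write its Softmax output into the Logits buffer. The pattern mirrors the unit test added at the bottom of this diff.

```python
# Minimal sketch (not from this commit): exercise the inplace pass from Python.
# Assumptions: paddle.fluid 1.x API; names 'x'/'label' are illustrative only.
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(
    name='x', shape=[8, 10], dtype='float32', append_batch_size=False)
label = fluid.layers.data(
    name='label', shape=[8, 1], dtype='int64', append_batch_size=False)
loss, softmax = fluid.layers.softmax_with_cross_entropy(
    x, label, return_softmax=True)

build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True  # turn on the inplace pass

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(
    build_strategy=build_strategy, places=place)

feed = {'x': np.random.random([8, 10]).astype('float32'),
        'label': np.random.randint(0, 10, [8, 1]).astype('int64')}
# If the op runs in-place, fetching the Logits variable is expected to return
# the softmax result; this is exactly what the new unit test below checks.
out = exe.run(prog, feed=feed, fetch_list=[loss.name, x.name])
```

With `enable_inplace` left off, the same fetch of `x.name` should presumably return the fed logits unchanged, which is how the test distinguishes the two execution modes.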
@@ -111,10 +111,14 @@ class InplacePass : public ir::Pass {
   // Check whether all `ops` is the preceding ops of `op`
   bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
 
-  // Find nodes whose name are equal to the given name
+  // Find nodes whose names are equal to the given name
   static std::unordered_set<ir::Node *> FindNodesByName(
       const std::string &name, const std::vector<ir::Node *> &nodes);
 
+  // Collect input arguments of op_desc
+  static void CollectInputArgsOfOpDesc(
+      const OpDesc *op_desc, std::unordered_multiset<std::string> *in_args);
+
   // Get all versions vars named var_name
   std::vector<ir::Node *> *AllVersionVars(const std::string &var_name) const;
@@ -201,37 +205,6 @@ void InplacePass::CollectSkipVars(ir::Graph *graph,
   for (const auto &var : mem_opt_whitelist) {
     skip_vars_.emplace(var);
   }
-
-  // 2. track the nodes which used by parameter server.
-  // these node can not be inplaced, otherwise trainer
-  // pserver can not find each other's name.
-  // Also check the ops which has sub-block
-  auto update_skip_set = [&](ir::Node *node) {
-    for (auto &in : node->inputs) {
-      if (in->IsVar() && in->Var() != nullptr) {
-        skip_vars_.emplace(in->Name());
-      }
-    }
-    for (auto &out : node->outputs) {
-      if (out->IsVar() && out->Var() != nullptr) {
-        skip_vars_.emplace(out->Name());
-      }
-    }
-  };
-
-  for (auto *node : ops) {
-    if (!node->IsOp()) continue;
-    // avoid optimizing the variable used in sub-blocks
-    if (OpHasSubBlock(node->Op())) {
-      update_skip_set(node);
-      continue;
-    }
-
-    auto node_name = node->Name();
-    if (node_name == "send" || node_name == "recv" || node_name == "prefetch") {
-      update_skip_set(node);
-    }
-  }
 }
 
 void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var,
@@ -301,6 +274,14 @@ std::unordered_set<ir::Node *> InplacePass::FindNodesByName(
   return ret;
 }
 
+void InplacePass::CollectInputArgsOfOpDesc(
+    const OpDesc *op_desc, std::unordered_multiset<std::string> *in_args) {
+  in_args->clear();
+  for (auto &in_name : op_desc->InputArgumentNames()) {
+    in_args->insert(in_name);
+  }
+}
+
 void InplacePass::ApplyImpl(ir::Graph *graph) const {
   // Step 1: topo sort ops, collect skip vars
   auto ops = ir::TopologySortOperations(*graph);
@@ -346,6 +327,11 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
     }
 
     auto in_to_outs = infer_inplace(*op_desc, use_cuda);
+    if (in_to_outs.empty()) continue;
+
+    std::unordered_multiset<std::string> all_in_args;
+    CollectInputArgsOfOpDesc(op_desc, &all_in_args);
+
     for (auto &pair : in_to_outs) {
       auto &in_param = pair.first;
       auto &out_param = pair.second;
@@ -387,6 +373,14 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
         continue;
       }
 
+      size_t in_arg_occur_times = all_in_args.count(in_arg);
+      if (in_arg_occur_times > 1) {
+        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
+                << " occurs " << in_arg_occur_times << " times in input of op "
+                << op_type;
+        continue;
+      }
+
       auto in_nodes = FindNodesByName(in_arg, op_node->inputs);
       PADDLE_ENFORCE(!in_nodes.empty(), "Input(%s)=%s cannot be found in op %s",
                      in_param, in_arg, op_type);
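To make the new duplicate-input guard concrete, here is a hypothetical Python illustration (not taken from this commit; the op and names are chosen only for illustration) of the situation it rejects: the same variable bound to two input slots of one op. If the op's output were allowed to reuse that variable's buffer, the second input slot would read partially overwritten data, so `CollectInputArgsOfOpDesc` counts every input argument name and any name occurring more than once is skipped.

```python
# Hypothetical illustration only: one variable feeds two input slots.
# An inplace pass must not let the op's output reuse x's memory here,
# because both Input("X") and Input("Y") read the same buffer.
import paddle.fluid as fluid

x = fluid.layers.data(
    name='x', shape=[4, 4], dtype='float32', append_batch_size=False)
y = fluid.layers.elementwise_add(x, x)  # Input("X") == Input("Y") == 'x'
```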
......
@@ -207,28 +207,8 @@ void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
   // fill skip_set_
   PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars));
   auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
-  for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
-
-  auto update_skip_set = [&](OpDesc* op_desc) {
-    auto inputs = op_desc->InputArgumentNames();
-    auto outputs = op_desc->OutputArgumentNames();
-    skip_set_.insert(inputs.begin(), inputs.end());
-    skip_set_.insert(outputs.begin(), outputs.end());
-  };
-
-  auto nodes = graph->Nodes();
-  for (auto& op : nodes) {
-    if (!op->IsOp() || op->Op() == nullptr) continue;
-    auto* op_desc = op->Op();
-    // NOTE(dzhwinter):
-    // current block can not reuse next level block vars.
-    if (OpHasSubBlock(op_desc)) update_skip_set(op_desc);
-    // NOTE(dzhwinter):
-    // distributed ops input/output name need to
-    // keep same bettwen trainer/pserver
-    if (op_desc->Type() == "send") update_skip_set(op_desc);
-    if (op_desc->Type() == "recv") update_skip_set(op_desc);
-    if (op_desc->Type() == "prefetch") update_skip_set(op_desc);
+  for (const auto& var : mem_opt_whitelist) {
+    skip_set_.emplace(var);
   }
 }
......
@@ -13,11 +13,14 @@
 // limitations under the License.
 
 #include <string>
+#include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/operator.h"
 
 namespace paddle {
 namespace framework {
@@ -30,26 +33,129 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
     graph->Set(kMemOptSkipVars, new MemOptSkipVars);
     auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
 
+    std::vector<ir::Node*> op_nodes;
+    for (auto& node : graph->Nodes()) {
+      PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
+      if (node->IsOp() && node->Op()) {
+        op_nodes.emplace_back(node);
+      }
+    }
+
+    // Insert kEmptyVarName to avoid optimizing empty variable
+    skip_vars.insert(framework::kEmptyVarName);
+
     // NOTE(zcd): Insert OpRoleVars to SkipVarSet to prevent the vars are rename
     // in memory optimize pass.
-    InsertOpRoleVarsToSkipVarSet(graph, &skip_vars);
+    InsertOpRoleVarsToSkipVarSet(op_nodes, &skip_vars);
+    InsertSkipMemOptOpInOutToSkipVarSet(op_nodes, &skip_vars);
   }
 
-  void InsertOpRoleVarsToSkipVarSet(const ir::Graph* graph,
-                                    MemOptSkipVars* skip_vars) const {
-    for (auto& node : graph->Nodes()) {
-      PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
-      if (node->IsOp() && node->Op()) {
-        try {
-          auto op_role_vars =
-              boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
-                  OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-          PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
-          for (size_t i = 0; i < op_role_vars.size(); i += 2) {
-            auto& g_name = op_role_vars[i + 1];
-            skip_vars->insert(g_name);
-          }
-        } catch (boost::bad_get e) {
-        }
-      }
-    }
-  }
+ private:
+  static void InsertOpRoleVarsToSkipVarSet(const std::vector<ir::Node*>& ops,
+                                           MemOptSkipVars* skip_vars) {
+    for (auto& node : ops) {
+      try {
+        auto op_role_vars =
+            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+        PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
+        for (size_t i = 0; i < op_role_vars.size(); i += 2) {
+          auto& g_name = op_role_vars[i + 1];
+          skip_vars->insert(g_name);
+        }
+      } catch (boost::bad_get& e) {
+      }
+    }
+  }
+
+  static void UpdateSkipVarSet(
+      MemOptSkipVars* skip_vars,
+      const std::vector<std::vector<std::string>>& var_names) {
+    for (auto& var_name : var_names) {
+      skip_vars->insert(var_name.begin(), var_name.end());
+    }
+  }
+
+  static std::vector<std::string> ToGradVarName(
+      const std::vector<std::string>& names) {
+    std::vector<std::string> ret;
+    ret.reserve(names.size());
+    for (auto& name : names) {
+      if (name != framework::kEmptyVarName) {
+        ret.emplace_back(framework::GradVarName(name));
+      }
+    }
+    return ret;
+  }
+
+  static void InsertSkipMemOptOpInOutToSkipVarSet(
+      const std::vector<ir::Node*>& ops, MemOptSkipVars* skip_vars) {
+    static std::unordered_set<std::string> kSkipMemOptOps{
+        "send", "recv", "prefetch", "send_barrier", "fetch_barrier"};
+
+    for (auto& node : ops) {
+      auto* op_desc = node->Op();
+      // Some ops (while, conditional_block, recurrent, etc.) have sub-blocks.
+      // These ops often use variables from their parent or forward blocks.
+      // Optimizing the in/out of such ops would make those variables
+      // impossible to find when the sub-block ops run.
+      if (OpHasSubBlock(op_desc)) {
+        UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
+                                     op_desc->OutputArgumentNames()});
+      }
+
+      // Skip ops that are related to the parameter server.
+      // In distributed mode, trainers and the parameter server use the same
+      // variable names to track the same variables. We cannot change these
+      // names, otherwise trainers or the parameter server would not find
+      // them.
+      if (kSkipMemOptOps.count(op_desc->Type()) > 0) {
+        UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
+                                     op_desc->OutputArgumentNames()});
+      }
+
+      // FIXME(zjl): some ops use variables that are not from their
+      // inputs or outputs. We do not have a nice method to solve this
+      // issue yet. Currently, we should skip these variables when
+      // memory optimization is enabled.
+      auto op_type = op_desc->Type();
+      if (op_type == "while_grad") {
+        // In while_grad, framework::GradVarName(Input("X")) is visited
+        // without being any in/out of while_grad. while_grad uses these
+        // variables to accumulate the gradient of X across time steps.
+        UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("X"))});
+      } else if (op_type == "conditional_block_grad") {
+        // In conditional_block_grad, framework::GradVarName(Input("Input",
+        // "Cond")) is visited without being any in/out of
+        // conditional_block_grad. conditional_block_grad uses these
+        // variables to accumulate the gradients of Input/Cond.
+        UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("Input")),
+                                     ToGradVarName(op_desc->Input("Cond"))});
+      } else if (op_type == "recurrent" || op_type == "recurrent_grad") {
+        // Recurrent and recurrent_grad ops are implemented in a very tricky
+        // way: Attr("states", "ex_states") is visited without being any
+        // in/out of the op. This is because these variables come from the
+        // sub-block, not the main block; adding them to the inputs would make
+        // recurrent fail since "states" and "ex_states" cannot be found in
+        // the main block. When memory optimization is enabled, "states",
+        // "ex_states" and their gradients should be skipped.
+        auto& ex_states =
+            boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
+        auto& states =
+            boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
+        if (op_type == "recurrent") {
+          UpdateSkipVarSet(skip_vars, {ex_states, states});
+        } else {
+          // In recurrent_grad, framework::GradVarName(Input("parameters",
+          // "input")) is visited without being any in/out of recurrent_grad.
+          // recurrent_grad uses these variables to accumulate the gradients
+          // of parameters/input across time steps.
+          UpdateSkipVarSet(
+              skip_vars,
+              {ToGradVarName(op_desc->Input("parameters")),
+               ToGradVarName(op_desc->Input("input")), ex_states, states,
+               ToGradVarName(ex_states), ToGradVarName(states)});
+        }
+      }
+    }
+  }
......
@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <algorithm>
+#include <memory>
+#include <string>
 #include <unordered_map>
+#include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
@@ -61,7 +66,16 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
       var->outputs.push_back(node);
     }
     // For output args, always create a new var.
+    std::unordered_set<std::string> out_arg_set;
     for (auto &each_var_name : op->OutputArgumentNames()) {
+      if (each_var_name != kEmptyVarName) {
+        PADDLE_ENFORCE(out_arg_set.count(each_var_name) == 0,
+                       "Program is wrong. %s occurs in output of %s several "
+                       "times.",
+                       each_var_name, op->Type());
+        out_arg_set.insert(each_var_name);
+      }
+
       ir::Node *var = nullptr;
       if (all_vars.count(each_var_name) != 0) {
         var = CreateVarNode(all_vars.at(each_var_name));
......
@@ -261,11 +261,7 @@ class SoftmaxWithCrossEntropyInplaceInference
  public:
   std::unordered_map<std::string, std::string> operator()(
       const framework::OpDesc& op_desc, bool use_cuda) const {
-    if (use_cuda && !boost::get<bool>(op_desc.GetAttr("soft_label"))) {
-      return {{"Logits", "Softmax"}};
-    } else {
-      return {};
-    }
+    return {{"Logits", "Softmax"}};
   }
 };
......
@@ -21,25 +21,39 @@ import unittest
 class TestSoftmaxWithXe(unittest.TestCase):
     def setUp(self):
+        self.initParameter()
         self.m, self.n = np.random.random_integers(
             low=100, high=2000, size=[2]).astype('int64')
 
-    def softmax_with_xe(self, x, y, place, inplace=True):
+    def initParameter(self):
+        self.dtype = 'float32'
+        self.soft_label = False
+
+    def softmax_with_xe(self,
+                        x,
+                        y,
+                        place,
+                        inplace=True,
+                        numeric_stable_mode=True):
         m, n = x.shape
         with fluid.program_guard(fluid.Program(), fluid.Program()):
             with fluid.scope_guard(fluid.Scope()):
                 x_d = fluid.layers.data(
                     name='x',
                     shape=[m, n],
-                    dtype='float32',
+                    dtype=self.dtype,
                     append_batch_size=False)
                 y_d = fluid.layers.data(
                     name='y',
-                    shape=[m, 1],
-                    dtype='int64',
+                    shape=[m, 1] if not self.soft_label else [m, n],
+                    dtype='int64' if not self.soft_label else self.dtype,
                     append_batch_size=False)
                 z_d, s_d = fluid.layers.softmax_with_cross_entropy(
-                    x_d, y_d, return_softmax=True)
+                    x_d,
+                    y_d,
+                    soft_label=self.soft_label,
+                    return_softmax=True,
+                    numeric_stable_mode=numeric_stable_mode)
 
                 exe = fluid.Executor(place)
@@ -51,7 +65,7 @@ class TestSoftmaxWithXe(unittest.TestCase):
                 )).with_data_parallel(
                     build_strategy=build_strategy, places=place)
 
-                if inplace and isinstance(place, fluid.CUDAPlace):
+                if inplace:
                     fetch_list = [z_d.name, x_d.name]
                 else:
                     fetch_list = [z_d.name, s_d.name]
@@ -63,16 +77,33 @@ class TestSoftmaxWithXe(unittest.TestCase):
             return z, s
 
     def main_with_place(self, place):
-        x = np.random.random(size=[self.m, self.n]).astype('float32')
+        x = np.random.random(size=[self.m, self.n]).astype(self.dtype)
         x_range = [(-30, 30), (10, 20), (-1, 1), (2, 3), (0, 0.3), (-200, -100)]
+
         for a, b in x_range:
-            x = ((b - a) * x + a).astype('float32')
-            y = np.random.random_integers(
-                size=[self.m, 1], low=0, high=self.n - 1).astype('int64')
-            z1, s1 = self.softmax_with_xe(x, y, place, False)
-            z2, s2 = self.softmax_with_xe(x, y, place, True)
+            x = ((b - a) * x + a).astype(self.dtype)
+            if not self.soft_label:
+                y = np.random.random_integers(
+                    size=[self.m, 1], low=0, high=self.n - 1).astype('int64')
+            else:
+                y = np.random.random(size=[self.m, self.n]).astype(self.dtype)
+                norm_y = np.broadcast_to(
+                    np.reshape(
+                        np.sum(y, axis=1), [-1, 1]), y.shape)
+                y = y / norm_y
+
+            z1, s1 = self.softmax_with_xe(
+                x, y, place, inplace=False, numeric_stable_mode=False)
+            z2, s2 = self.softmax_with_xe(
+                x, y, place, inplace=True, numeric_stable_mode=False)
+            self.assertTrue((z1 == z2).all())
+            self.assertTrue((s1 == s2).all())
 
+            z1, s1 = self.softmax_with_xe(
+                x, y, place, inplace=False, numeric_stable_mode=True)
+            z2, s2 = self.softmax_with_xe(
+                x, y, place, inplace=True, numeric_stable_mode=True)
             self.assertTrue((z1 == z2).all())
             self.assertTrue((s1 == s2).all())
@@ -82,5 +113,23 @@
             self.main_with_place(fluid.CUDAPlace(0))
 
 
+class TestSoftmaxWithXe1(TestSoftmaxWithXe):
+    def initParameter(self):
+        self.dtype = 'float32'
+        self.soft_label = True
+
+
+class TestSoftmaxWithXe2(TestSoftmaxWithXe):
+    def initParameter(self):
+        self.dtype = 'float64'
+        self.soft_label = False
+
+
+class TestSoftmaxWithXe3(TestSoftmaxWithXe):
+    def initParameter(self):
+        self.dtype = 'float64'
+        self.soft_label = True
+
+
 if __name__ == '__main__':
     unittest.main()