未验证 提交 96980c22 编写于 作者: Y Yiqun Liu 提交者: GitHub

Polish the PADDLE_ENFORCE in fusion_group pass related codes. (#22144)

* Polish the PADDLE_ENFORCE in fusion_group pass related codes.
test=develop

* Correct the unittest because of the change relu_grad's formula.
test=develop
上级 4f7a2bd0
...@@ -36,6 +36,16 @@ std::string CodeGenerator::Generate(SubGraph* subgraph) { ...@@ -36,6 +36,16 @@ std::string CodeGenerator::Generate(SubGraph* subgraph) {
return Generate(subgraph->func_name, expressions); return Generate(subgraph->func_name, expressions);
} }
static bool HasInput(Node* n, std::string name) {
PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
platform::errors::InvalidArgument(
"Expected node %p to be an operator node.", n));
std::vector<std::string> input_names = n->Op()->InputNames();
std::unordered_set<std::string> input_names_set(input_names.begin(),
input_names.end());
return input_names_set.find(name) != input_names_set.end();
}
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions( std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
SubGraph* subgraph) { SubGraph* subgraph) {
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph); std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
...@@ -45,19 +55,20 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions( ...@@ -45,19 +55,20 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
auto* op = node->Op(); auto* op = node->Op();
// Input ids should be set in fixed order, like: // Input ids should be set in fixed order, like:
// - x, y in forward operations // - X, Y in forward operations
// - x, y, out, out@GRAD in backward operations // - X, Y, Out, out@GRAD in backward operations
std::vector<int> input_ids; std::vector<int> input_ids;
std::vector<std::string> input_names = std::vector<std::string> input_names =
OperationMap::Instance().Get(op->Type()).input_names; OperationMap::Instance().Get(op->Type()).input_names;
for (auto& name : input_names) { for (auto& name : input_names) {
// TODO(liuyiqun): support duplicated input. // Some input vars are not used in grad ops, such as
if (op->Input(name).size() >= 1U) { // "elementwise_add_grad", where "X", "Y" and "Out" are not used.
// Some input vars are not used in grad ops, such as if (HasInput(node, name) && op->Input(name).size() >= 1U) {
// "elementwise_add_grad", where "X", "Y" and "Out" are not used. // TODO(liuyiqun): support duplicated input.
PADDLE_ENFORCE_NE(var_ids.find(op->Input(name)[0]), var_ids.end(), PADDLE_ENFORCE_NE(
"Input(%s) of operation %s should be set.", name, var_ids.find(op->Input(name)[0]), var_ids.end(),
op->Type()); platform::errors::InvalidArgument(
"Input(%s) of operation %s is not set.", name, op->Type()));
input_ids.push_back(var_ids[op->Input(name)[0]]); input_ids.push_back(var_ids[op->Input(name)[0]]);
} else { } else {
input_ids.push_back(-1); input_ids.push_back(-1);
...@@ -69,12 +80,14 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions( ...@@ -69,12 +80,14 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
std::vector<std::string> output_names = std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names; OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) { for (auto& name : output_names) {
PADDLE_ENFORCE_EQ(op->Output(name).size(), 1U, PADDLE_ENFORCE_EQ(
"Output(%s) of operation %s should be set.", name, op->Output(name).size(), 1U,
op->Type()); platform::errors::InvalidArgument(
PADDLE_ENFORCE_NE(var_ids.find(op->Output(name)[0]), var_ids.end(), "Output(%s) of operation %s is not set.", name, op->Type()));
"Output(%s) of operation %s should be set.", name, PADDLE_ENFORCE_NE(
op->Type()); var_ids.find(op->Output(name)[0]), var_ids.end(),
platform::errors::InvalidArgument(
"Output(%s) of operation %s is not set.", name, op->Type()));
output_ids.push_back(var_ids[op->Output(name)[0]]); output_ids.push_back(var_ids[op->Output(name)[0]]);
} }
expressions.push_back( expressions.push_back(
...@@ -218,8 +231,9 @@ std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes( ...@@ -218,8 +231,9 @@ std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
is_found, true, is_found, true,
"Subgraph with internal var nodes (%s) is not supported yet.", platform::errors::Unimplemented(
node->Name()); "Subgraph with internal var nodes (%s) is not supported yet.",
node->Name()));
} }
} }
// Encoding output vars. // Encoding output vars.
......
...@@ -45,11 +45,16 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used, ...@@ -45,11 +45,16 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
} }
std::string index_str = rhs.substr(pos + 2, length); std::string index_str = rhs.substr(pos + 2, length);
int index = StringTo<int>(index_str); int index = StringTo<int>(index_str);
PADDLE_ENFORCE_LT(index, input_ids_.size(), PADDLE_ENFORCE_LT(
"Only %d inputs are provided, but need %d.", index, input_ids_.size(),
input_ids_.size(), index + 1); platform::errors::InvalidArgument(
PADDLE_ENFORCE_GE(input_ids_[index], 0, "Only %d inputs are provided, but need %d for operation < %s >.",
"Input id should be no less than 0."); input_ids_.size(), index + 1, op_type_));
PADDLE_ENFORCE_GE(
input_ids_[index], 0,
platform::errors::InvalidArgument(
"Expected %d-th input id > 0 for operation < %s >. Received %d.",
index, op_type_, input_ids_[index]));
rhs.replace(pos, length + 3, TmpName(input_ids_[index])); rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
used->insert(input_ids_[index]); used->insert(input_ids_[index]);
} }
......
...@@ -113,7 +113,8 @@ class CodeTemplate { ...@@ -113,7 +113,8 @@ class CodeTemplate {
for (auto iter : template_var.Get()) { for (auto iter : template_var.Get()) {
PADDLE_ENFORCE_NE(found.find(iter.first), found.end(), PADDLE_ENFORCE_NE(found.find(iter.first), found.end(),
"Keyword %s in template is not set.", iter.first); platform::errors::PreconditionNotMet(
"Keyword %s in template is not set.", iter.first));
} }
return EmitIndents(ret); return EmitIndents(ret);
......
...@@ -35,7 +35,7 @@ namespace fusion_group { ...@@ -35,7 +35,7 @@ namespace fusion_group {
inline float relu(float x) { return x > 0 ? x : 0.; } inline float relu(float x) { return x > 0 ? x : 0.; }
inline float relu_grad_dx(float x, float out, float dout) { inline float relu_grad_dx(float x, float out, float dout) {
return x > 0 ? dout : 0; return out > 0 ? dout : 0;
} }
// sigmoid // sigmoid
...@@ -117,7 +117,7 @@ void CheckOutput(const std::vector<OperationExpression>& expressions, ...@@ -117,7 +117,7 @@ void CheckOutput(const std::vector<OperationExpression>& expressions,
elementwise_mul(var[input_ids[0]], var[input_ids[1]]); elementwise_mul(var[input_ids[0]], var[input_ids[1]]);
} else if (op_type == "relu_grad") { } else if (op_type == "relu_grad") {
var[output_ids[0]] = var[output_ids[0]] =
relu_grad_dx(var[input_ids[0]], 0, var[input_ids[2]]); relu_grad_dx(0, var[input_ids[1]], var[input_ids[2]]);
} else if (op_type == "sigmoid_grad") { } else if (op_type == "sigmoid_grad") {
var[output_ids[0]] = var[output_ids[0]] =
sigmoid_grad_dx(0, var[input_ids[1]], var[input_ids[2]]); sigmoid_grad_dx(0, var[input_ids[1]], var[input_ids[2]]);
...@@ -138,8 +138,7 @@ void CheckOutput(const std::vector<OperationExpression>& expressions, ...@@ -138,8 +138,7 @@ void CheckOutput(const std::vector<OperationExpression>& expressions,
for (auto id : output_ids_of_subgraph) { for (auto id : output_ids_of_subgraph) {
float actual = cpu_tensors[id].data<float>()[i]; float actual = cpu_tensors[id].data<float>()[i];
float expect = var[id]; float expect = var[id];
PADDLE_ENFORCE_LT(fabs(actual - expect), 1.E-05, EXPECT_LT(fabs(actual - expect), 1.E-05);
"Get %f vs %f (actual vs expect).", actual, expect);
} }
} }
...@@ -150,8 +149,7 @@ void SetupRandomCPUTensor(LoDTensor* tensor) { ...@@ -150,8 +149,7 @@ void SetupRandomCPUTensor(LoDTensor* tensor) {
std::uniform_real_distribution<double> uniform_dist(0, 1); std::uniform_real_distribution<double> uniform_dist(0, 1);
T* ptr = tensor->data<T>(); T* ptr = tensor->data<T>();
PADDLE_ENFORCE_NOT_NULL( EXPECT_NE(ptr, nullptr);
ptr, "Call mutable_data to alloc memory for Tensor first.");
for (int64_t i = 0; i < tensor->numel(); ++i) { for (int64_t i = 0; i < tensor->numel(); ++i) {
ptr[i] = static_cast<T>(uniform_dist(rng)) - static_cast<T>(0.5); ptr[i] = static_cast<T>(uniform_dist(rng)) - static_cast<T>(0.5);
} }
...@@ -283,7 +281,7 @@ TEST(code_generator, elementwise_grad) { ...@@ -283,7 +281,7 @@ TEST(code_generator, elementwise_grad) {
// t3 = relu(t2) // t3 = relu(t2)
// t2' = relu_grad(t2, t3, t3') // t2' = relu_grad(t2, t3, t3')
// t0', t1' = elementwise_mul_grad(t0, t1, t2, t2') // t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
fusion_group::OperationExpression exp1("relu_grad", {2, -1, 7}, {6}); fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6});
fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6}, fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
{4, 5}); {4, 5});
std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2}; std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
...@@ -300,7 +298,7 @@ TEST(code_generator, elementwise_grad) { ...@@ -300,7 +298,7 @@ TEST(code_generator, elementwise_grad) {
// Op(relu_grad), inputs:{2,3,7}, outputs:{6} // Op(relu_grad), inputs:{2,3,7}, outputs:{6}
// Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5} // Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5}
int n = cpu_tensors[0].numel(); int n = cpu_tensors[0].numel();
std::vector<int> input_ids = {0, 1, 2, -1, 7}; std::vector<int> input_ids = {0, 1, 2, 3, 7};
std::vector<int> output_ids = {4, 5, 6}; std::vector<int> output_ids = {4, 5, 6};
TestMain("elementwise_grad_kernel_0", expressions, cpu_tensors, n, input_ids, TestMain("elementwise_grad_kernel_0", expressions, cpu_tensors, n, input_ids,
output_ids); output_ids);
...@@ -332,22 +330,25 @@ std::unique_ptr<paddle::framework::ir::Graph> BuildGraph( ...@@ -332,22 +330,25 @@ std::unique_ptr<paddle::framework::ir::Graph> BuildGraph(
// tmp_2@GRAD(13), x2@GRAD(14), x0@GRAD(15), // tmp_2@GRAD(13), x2@GRAD(14), x0@GRAD(15),
// x3@GRAD(16), x1@GRAD(17) // x3@GRAD(16), x1@GRAD(17)
paddle::framework::ir::Layers layers; paddle::framework::ir::Layers layers;
auto* x0 = layers.data("x0", {16, 32}); std::vector<int64_t> shape = {16, 32};
auto* x0 = layers.data("x0", shape);
auto* tmp_0 = layers.sigmoid(x0); auto* tmp_0 = layers.sigmoid(x0);
tmp_0->SetShape({16, 32}); auto* x1 = layers.data("x1", shape);
auto* x1 = layers.data("x1", {16, 32});
auto* tmp_1 = layers.elementwise_mul(tmp_0, x1); auto* tmp_1 = layers.elementwise_mul(tmp_0, x1);
tmp_1->SetShape({16, 32}); auto* x2 = layers.data("x2", shape);
auto* x2 = layers.data("x2", {16, 32});
auto* tmp_2 = layers.tanh(x2); auto* tmp_2 = layers.tanh(x2);
tmp_2->SetShape({16, 32}); auto* x3 = layers.data("x3", shape);
auto* x3 = layers.data("x3", {16, 32});
auto* tmp_3 = layers.elementwise_mul(x3, tmp_2); auto* tmp_3 = layers.elementwise_mul(x3, tmp_2);
tmp_3->SetShape({16, 32}); auto* tmp_4 = layers.elementwise_add(tmp_1, tmp_3);
layers.elementwise_add(tmp_1, tmp_3);
std::vector<paddle::framework::VarDesc*> elementwise_vars = {
tmp_0, tmp_1, tmp_2, tmp_3, tmp_4};
for (auto* var : elementwise_vars) {
var->SetShape(shape);
}
if (backward) { if (backward) {
layers.backward(); layers.backward({tmp_4});
} }
std::unique_ptr<paddle::framework::ir::Graph> graph( std::unique_ptr<paddle::framework::ir::Graph> graph(
......
...@@ -22,17 +22,15 @@ namespace paddle { ...@@ -22,17 +22,15 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void VisualizeGraph(std::unique_ptr<Graph> graph, std::string graph_viz_path) { void VisualizeGraph(std::unique_ptr<Graph>* graph, std::string graph_viz_path) {
// Insert a graph_viz_pass to transform the graph to a .dot file. // Insert a graph_viz_pass to transform the graph to a .dot file.
// It can be used for debug. // It can be used for debug.
auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass"); auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
graph_viz_pass->Set("graph_viz_path", new std::string(graph_viz_path)); graph_viz_pass->Set("graph_viz_path", new std::string(graph_viz_path));
graph.reset(graph_viz_pass->Apply(graph.release())); graph->reset(graph_viz_pass->Apply(graph->release()));
} }
TEST(FusionGroupPass, elementwise_list) { std::unique_ptr<Graph> BuildElementwiseListGraph(bool backward = false) {
fusion_group::OperationMap::Init();
// inputs operator output // inputs operator output
// -------------------------------------------------------- // --------------------------------------------------------
// (x, y) mul -> tmp_0 // (x, y) mul -> tmp_0
...@@ -42,34 +40,33 @@ TEST(FusionGroupPass, elementwise_list) { ...@@ -42,34 +40,33 @@ TEST(FusionGroupPass, elementwise_list) {
// //
// Expression: tmp_3 = relu(mul(x, y) + z) + w // Expression: tmp_3 = relu(mul(x, y) + z) + w
Layers layers; Layers layers;
std::vector<int64_t> shape = {16, 32};
auto* x = layers.data("x", {16, 16}); auto* x = layers.data("x", {16, 16});
auto* y = layers.data("y", {16, 32}); auto* y = layers.data("y", {16, 32});
auto* tmp_0 = layers.mul(x, y); auto* tmp_0 = layers.mul(x, y);
tmp_0->SetShape({16, 32}); auto* z = layers.data("z", shape);
auto* z = layers.data("z", {16, 32});
auto* tmp_1 = layers.elementwise_add(tmp_0, z); auto* tmp_1 = layers.elementwise_add(tmp_0, z);
auto* tmp_2 = layers.relu(tmp_1); auto* tmp_2 = layers.relu(tmp_1);
tmp_2->SetShape({16, 32}); auto* w = layers.data("w", shape);
auto* w = layers.data("w", {16, 32}); auto* tmp_3 = layers.elementwise_add(tmp_2, w);
layers.elementwise_add(tmp_2, w); std::vector<VarDesc*> elementwise_vars = {tmp_0, z, tmp_1, tmp_2, w, tmp_3};
for (auto* var : elementwise_vars) {
std::unique_ptr<Graph> graph(new Graph(layers.main_program())); var->SetShape(shape);
// VisualizeGraph(graph, "00_elementwise_list.dot"); }
auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass");
VLOG(3) << DebugString(graph);
graph.reset(fusion_group_pass->Apply(graph.release())); if (backward) {
// VisualizeGraph(graph, "01_elementwise_list.fusion_group.dot"); layers.backward({tmp_3});
int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group"); }
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_fusion_group_ops, 1); std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
#ifdef __clang__
return graph;
#else
return std::move(graph);
#endif
} }
TEST(FusionGroupPass, elementwise_tree) { std::unique_ptr<Graph> BuildElementwiseTreeGraph(bool backward = false) {
fusion_group::OperationMap::Init();
// inputs operator output // inputs operator output
// -------------------------------------------------------- // --------------------------------------------------------
// (x0, y0) mul -> tmp_0 // (x0, y0) mul -> tmp_0
...@@ -88,53 +85,72 @@ TEST(FusionGroupPass, elementwise_tree) { ...@@ -88,53 +85,72 @@ TEST(FusionGroupPass, elementwise_tree) {
// tmp_9 = tanh(x4) * sigmoid(x5) // tmp_9 = tanh(x4) * sigmoid(x5)
// tmp_10 = mul(tmp_6, tmp_9) // tmp_10 = mul(tmp_6, tmp_9)
Layers layers; Layers layers;
std::vector<int64_t> shape = {16, 32};
auto* x0 = layers.data("x0", {16, 16}); auto* x0 = layers.data("x0", {16, 16});
auto* y0 = layers.data("y0", {16, 32}); auto* y0 = layers.data("y0", {16, 32});
auto* tmp_0 = layers.mul(x0, y0); auto* tmp_0 = layers.mul(x0, y0);
tmp_0->SetShape({16, 32}); auto* x1 = layers.data("x1", shape);
auto* x1 = layers.data("x1", {16, 32});
auto* tmp_1 = layers.sigmoid(x1); auto* tmp_1 = layers.sigmoid(x1);
tmp_1->SetShape({16, 32});
auto* tmp_2 = layers.elementwise_mul(tmp_0, tmp_1); auto* tmp_2 = layers.elementwise_mul(tmp_0, tmp_1);
tmp_2->SetShape({16, 32}); auto* x2 = layers.data("x2", shape);
auto* x2 = layers.data("x2", {16, 32});
auto* tmp_3 = layers.sigmoid(x2); auto* tmp_3 = layers.sigmoid(x2);
tmp_3->SetShape({16, 32}); auto* x3 = layers.data("x3", shape);
auto* x3 = layers.data("x3", {16, 32});
auto* tmp_4 = layers.tanh(x3); auto* tmp_4 = layers.tanh(x3);
tmp_4->SetShape({16, 32});
auto* tmp_5 = layers.elementwise_mul(tmp_3, tmp_4); auto* tmp_5 = layers.elementwise_mul(tmp_3, tmp_4);
tmp_5->SetShape({16, 32});
auto* tmp_6 = layers.elementwise_add(tmp_2, tmp_5); auto* tmp_6 = layers.elementwise_add(tmp_2, tmp_5);
tmp_6->SetShape({16, 32}); auto* x4 = layers.data("x4", shape);
auto* x4 = layers.data("x4", {16, 32});
auto* tmp_7 = layers.tanh(x4); auto* tmp_7 = layers.tanh(x4);
tmp_7->SetShape({16, 32}); auto* x5 = layers.data("x5", shape);
auto* x5 = layers.data("x5", {16, 32});
auto* tmp_8 = layers.sigmoid(x5); auto* tmp_8 = layers.sigmoid(x5);
tmp_8->SetShape({16, 32});
auto* tmp_9 = layers.elementwise_mul(tmp_7, tmp_8); auto* tmp_9 = layers.elementwise_mul(tmp_7, tmp_8);
tmp_9->SetShape({16, 32}); auto* tmp_10 = layers.mul(tmp_6, tmp_9);
layers.mul(tmp_6, tmp_9);
std::vector<VarDesc*> elementwise_vars = {tmp_0, tmp_1, tmp_2, tmp_3, tmp_4,
tmp_5, tmp_6, tmp_7, tmp_8, tmp_9};
for (auto* var : elementwise_vars) {
var->SetShape(shape);
}
if (backward) {
layers.backward({tmp_10});
}
std::unique_ptr<Graph> graph(new Graph(layers.main_program())); std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
// VisualizeGraph(graph, "00_elementwise_tree.dot"); #ifdef __clang__
return graph;
#else
return std::move(graph);
#endif
}
auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass"); int TestMain(std::unique_ptr<Graph> graph, std::string prefix) {
// VisualizeGraph(&graph, prefix + ".dot");
auto pass = PassRegistry::Instance().Get("fusion_group_pass");
pass->Set("use_gpu", new bool(true));
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
graph.reset(fusion_group_pass->Apply(graph.release())); graph.reset(pass->Apply(graph.release()));
// VisualizeGraph(graph, "01_elementwise_tree.fusion_group.dot"); // VisualizeGraph(&graph, prefix + ".fusion_group.dot");
int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group"); int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_fusion_group_ops, 2); return num_fusion_group_ops;
}
TEST(FusionGroupPass, elementwise_list) {
fusion_group::OperationMap::Init();
std::unique_ptr<Graph> graph = BuildElementwiseListGraph(false);
int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_list");
EXPECT_EQ(num_fusion_group_ops, 1);
}
TEST(FusionGroupPass, elementwise_tree) {
fusion_group::OperationMap::Init();
std::unique_ptr<Graph> graph = BuildElementwiseTreeGraph(false);
int num_fusion_group_ops = TestMain(std::move(graph), "elementwise_tree");
EXPECT_EQ(num_fusion_group_ops, 2);
} }
} // namespace ir } // namespace ir
......
...@@ -43,7 +43,11 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type, ...@@ -43,7 +43,11 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type,
std::vector<std::string> input_names, std::vector<std::string> input_names,
std::vector<std::string> output_names) { std::vector<std::string> output_names) {
Operation op(type, num_operands, op_type, {expr}, input_names, output_names); Operation op(type, num_operands, op_type, {expr}, input_names, output_names);
PADDLE_ENFORCE_EQ(op.IsValid(), true, "Operation %s is invalid.", op_type); PADDLE_ENFORCE_EQ(op.IsValid(), true,
platform::errors::InvalidArgument(
"Operation %s is invalid. Please set correct "
"expression for forward calculation.",
op_type));
operations_[op_type] = op; operations_[op_type] = op;
if (grad_exprs.size() > 0U) { if (grad_exprs.size() > 0U) {
...@@ -63,8 +67,11 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type, ...@@ -63,8 +67,11 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type,
} }
Operation grad_op(type, num_operands, grad_op_type, grad_exprs, Operation grad_op(type, num_operands, grad_op_type, grad_exprs,
grad_input_names, grad_output_names); grad_input_names, grad_output_names);
PADDLE_ENFORCE_EQ(grad_op.IsValid(), true, "Operation %s is invalid.", PADDLE_ENFORCE_EQ(grad_op.IsValid(), true,
grad_op_type); platform::errors::InvalidArgument(
"Operation %s is invalid. Please set correct "
"expression for backward calculation.",
grad_op_type));
operations_[grad_op_type] = grad_op; operations_[grad_op_type] = grad_op;
} }
} }
...@@ -83,8 +90,8 @@ void OperationMap::InsertUnaryElementwiseOperations() { ...@@ -83,8 +90,8 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// relu: // relu:
// out = f(x) = x > 0 ? x : 0 // out = f(x) = x > 0 ? x : 0
// dx = dout * (out > 0 ? 1 : 0) = dout * (x > 0 ? 1 : 0) // dx = dout * (out > 0 ? 1 : 0)
insert_handler("relu", "real_max(${0}, 0)", {"${0} > 0 ? ${2} : 0"}); insert_handler("relu", "real_max(${0}, 0)", {"${1} > 0 ? ${2} : 0"});
// sigmoid: // sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x)) // out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out) // dx = dout * out * (1 - out)
......
...@@ -70,7 +70,10 @@ class OperationMap { ...@@ -70,7 +70,10 @@ class OperationMap {
OperationMap(); OperationMap();
static OperationMap& Instance() { static OperationMap& Instance() {
PADDLE_ENFORCE_NOT_NULL(map, "Need to initialize OperationMap first!"); PADDLE_ENFORCE_NOT_NULL(
map, platform::errors::PreconditionNotMet(
"Please initialize OperationMap first, by calling "
"framework::fusion_group::OperationMap::Init()!"));
return *map; return *map;
} }
......
...@@ -270,9 +270,19 @@ struct Layers { ...@@ -270,9 +270,19 @@ struct Layers {
return outs; return outs;
} }
void backward() { void backward(std::vector<VarDesc*> targets) {
// This function is designed to simulate the structure of training program,
// but is constructed differently as the actual program.
BlockDesc* block = program_.MutableBlock(0); BlockDesc* block = program_.MutableBlock(0);
std::vector<OpDesc*> forward_ops = block->AllOps(); std::vector<OpDesc*> forward_ops = block->AllOps();
for (auto* var : targets) {
OpDesc* none_op = block->AppendOp();
none_op->SetType("none");
none_op->SetInput("X", {var->Name()});
VarDesc* grad_var =
lod_tensor(GradVarName(var->Name()), var->GetShape(), false);
none_op->SetOutput("Out", {grad_var->Name()});
}
for (int i = forward_ops.size() - 1; i >= 0; --i) { for (int i = forward_ops.size() - 1; i >= 0; --i) {
OpDesc* op = forward_ops[i]; OpDesc* op = forward_ops[i];
OpDesc* grad_op = block->AppendOp(); OpDesc* grad_op = block->AppendOp();
...@@ -428,8 +438,21 @@ static std::string DebugString(Node* node) { ...@@ -428,8 +438,21 @@ static std::string DebugString(Node* node) {
is_first = false; is_first = false;
} }
os << "}."; os << "}.";
} else if (node->IsVar() && node->Var()) { } else {
os << "Node(" << node->Name() << "), inputs:{"; os << "Node(" << node->Name();
if (node->IsVar() && node->Var()) {
os << "{";
bool is_first = true;
for (auto dim : node->Var()->GetShape()) {
if (!is_first) {
os << "x";
}
os << dim;
is_first = false;
}
os << "}";
}
os << "), inputs:{";
bool is_first = true; bool is_first = true;
for (auto* in : node->inputs) { for (auto* in : node->inputs) {
if (!is_first) { if (!is_first) {
...@@ -477,12 +500,16 @@ static std::string DebugString(const std::unordered_set<Node*>& nodes) { ...@@ -477,12 +500,16 @@ static std::string DebugString(const std::unordered_set<Node*>& nodes) {
return DebugString(vec); return DebugString(vec);
} }
static std::string DebugString(const std::unique_ptr<Graph>& graph) { static std::string DebugString(Graph* graph) {
std::ostringstream os; std::ostringstream os;
os << "Graph: {\n" << DebugString(graph->Nodes()) << "}\n"; os << "Graph: {\n" << DebugString(graph->Nodes()) << "}\n";
return os.str(); return os.str();
} }
static std::string DebugString(const std::unique_ptr<Graph>& graph) {
return DebugString(graph.get());
}
static int GetNumOpNodes(const std::unique_ptr<Graph>& graph, static int GetNumOpNodes(const std::unique_ptr<Graph>& graph,
std::string op_type) { std::string op_type) {
int num_nodes = 0; int num_nodes = 0;
......
...@@ -85,7 +85,8 @@ TEST(DeviceCode, cuda) { ...@@ -85,7 +85,8 @@ TEST(DeviceCode, cuda) {
} }
TEST(DeviceCodePool, cuda) { TEST(DeviceCodePool, cuda) {
if (!paddle::platform::dynload::HasNVRTC()) { if (!paddle::platform::dynload::HasNVRTC() ||
!paddle::platform::dynload::HasCUDADriver()) {
return; return;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册