提交 27573ece 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: trailing spaces removed

上级 7f5c8a95
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle { namespace paddle {
...@@ -8,15 +24,14 @@ namespace patterns { ...@@ -8,15 +24,14 @@ namespace patterns {
struct Pattern : public PatternBase { struct Pattern : public PatternBase {
Pattern(PDPattern* pattern, const std::string& name_scope) Pattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase{pattern, name_scope, ""} : PatternBase{pattern, name_scope, ""} {}
{ }
private:
private:
std::string name_scope() { return name_scope_; } std::string name_scope() { return name_scope_; }
std::string repr() { return repr_; } std::string repr() { return repr_; }
size_t id() { return id_; } size_t id() { return id_; }
PDPattern* node_pattern() { return pattern; } PDPattern* node_pattern() { return pattern; }
public: public:
std::string node_name(std::string op_name) { std::string node_name(std::string op_name) {
return PDNodeName(name_scope(), repr(), id(), op_name); return PDNodeName(name_scope(), repr(), id(), op_name);
...@@ -37,22 +52,18 @@ struct Conv { ...@@ -37,22 +52,18 @@ struct Conv {
std::string filter_name() { return "Filter"; } std::string filter_name() { return "Filter"; }
std::string output_name() { return "Output"; } std::string output_name() { return "Output"; }
std::function<PDNode* ()> operator()(std::shared_ptr<Pattern> pattern) { std::function<PDNode*()> operator()(std::shared_ptr<Pattern> pattern) {
return [&]() -> PDNode* { return [&]() -> PDNode* {
auto conv_op = pattern->new_node(op_name()) auto conv_op = pattern->new_node(op_name())->assert_is_op("conv2d");
->assert_is_op("conv2d");
auto input_var = pattern->new_node(input_name()) auto input_var = pattern->new_node(input_name())
->assert_is_op_input(op_name(), ->assert_is_op_input(op_name(), input_name());
input_name());
auto filter_var = pattern->new_node(filter_name()) auto filter_var = pattern->new_node(filter_name())
->assert_is_op_input(op_name(), ->assert_is_op_input(op_name(), filter_name());
filter_name());
auto output_var = pattern->new_node(output_name()) auto output_var = pattern->new_node(output_name())
->assert_is_op_output(op_name(), ->assert_is_op_output(op_name(), output_name());
output_name());
conv_op->LinksFrom({input_var, filter_var}); conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var}); conv_op->LinksTo({output_var});
...@@ -68,22 +79,19 @@ struct ElementwiseAdd { ...@@ -68,22 +79,19 @@ struct ElementwiseAdd {
std::string y_name() { return "Y"; } std::string y_name() { return "Y"; }
std::string out_name() { return "Out"; } std::string out_name() { return "Out"; }
std::function<PDNode* (PDNode*)> operator()(std::shared_ptr<Pattern> pattern) { std::function<PDNode*(PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
return [&](PDNode* conv_output) -> PDNode* { return [&](PDNode* conv_output) -> PDNode* {
auto elementwise_add_op = pattern->new_node(op_name()) auto elementwise_add_op =
->assert_is_op("elementwise_add"); pattern->new_node(op_name())->assert_is_op("elementwise_add");
auto x_var =
pattern->new_node(x_name())->assert_is_op_input(op_name(), x_name());
auto x_var = pattern->new_node(x_name()) conv_output->assert_is_op_input(op_name(), y_name());
->assert_is_op_input(op_name(),
x_name());
conv_output->assert_is_op_input(op_name(),
y_name());
auto out_var = pattern->new_node(out_name()) auto out_var = pattern->new_node(out_name())
->AsOutput() ->AsOutput()
->assert_is_op_output(op_name(), ->assert_is_op_output(op_name(), out_name());
out_name());
elementwise_add_op->LinksFrom({x_var, conv_output}); elementwise_add_op->LinksFrom({x_var, conv_output});
elementwise_add_op->LinksTo({out_var}); elementwise_add_op->LinksTo({out_var});
...@@ -94,13 +102,13 @@ struct ElementwiseAdd { ...@@ -94,13 +102,13 @@ struct ElementwiseAdd {
}; };
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph, Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
std::shared_ptr<patterns::Pattern> pattern, std::shared_ptr<patterns::Pattern> pattern,
const std::string& op_name) { const std::string& op_name) {
PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)), PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)),
"Node not found for PDNode %s", pattern->node_name(op_name)); "Node not found for PDNode %s", pattern->node_name(op_name));
Node* var = subgraph.at(pattern->retrieve_node(op_name)); Node* var = subgraph.at(pattern->retrieve_node(op_name));
PADDLE_ENFORCE(var, "node %s not exists in the sub-graph"); PADDLE_ENFORCE(var, "node %s not exists in the sub-graph");
return var; return var;
} }
...@@ -109,10 +117,9 @@ void LinkNodes(Node* from, Node* to) { ...@@ -109,10 +117,9 @@ void LinkNodes(Node* from, Node* to) {
to->inputs.push_back(from); to->inputs.push_back(from);
} }
template<typename IT, typename FindFunc, typename ReplaceFunc> template <typename IT, typename FindFunc, typename ReplaceFunc>
void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) { void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
if (s == e) if (s == e) return;
return;
auto it = std::find_if(s, e, f); auto it = std::find_if(s, e, f);
...@@ -126,8 +133,7 @@ void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) { ...@@ -126,8 +133,7 @@ void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) { void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) { for (auto& node : GraphTraits::DFS(*graph)) {
auto same = std::find_if(std::begin(node.inputs), auto same = std::find_if(std::begin(node.inputs), std::end(node.inputs),
std::end(node.inputs),
[from](Node* n) { return n == from; }); [from](Node* n) { return n == from; });
if (same != std::end(node.inputs)) { if (same != std::end(node.inputs)) {
...@@ -137,17 +143,19 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) { ...@@ -137,17 +143,19 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
using input_type = VariableNameMap::value_type; using input_type = VariableNameMap::value_type;
ReplaceAllOccurances(std::begin(inputs), std::end(inputs), ReplaceAllOccurances(
[from](const input_type& i) -> bool { std::begin(inputs), std::end(inputs),
auto params = i.second; [from](const input_type& i) -> bool {
auto pi = std::find_if(std::begin(params), std::end(params), auto params = i.second;
std::bind(std::equal_to<std::string>(), auto pi =
from->Name(), std::placeholders::_1)); std::find_if(std::begin(params), std::end(params),
return pi != std::end(params); std::bind(std::equal_to<std::string>(),
}, from->Name(), std::placeholders::_1));
[to, &node](const input_type& i) { return pi != std::end(params);
node.Op()->SetInput(i.first, {to->Name()}); },
}); [to, &node](const input_type& i) {
node.Op()->SetInput(i.first, {to->Name()});
});
} }
} }
} }
...@@ -169,7 +177,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { ...@@ -169,7 +177,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output->AsIntermediate(); conv_output->AsIntermediate();
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* conv_output, Node* elementwise_add_x) { auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter,
Node* conv_output, Node* elementwise_add_x) {
OpDesc op_desc; OpDesc op_desc;
op_desc.SetType("conv2d"); op_desc.SetType("conv2d");
...@@ -189,22 +198,23 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { ...@@ -189,22 +198,23 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
patterns::LinkNodes(fused_conv_op, conv_output); patterns::LinkNodes(fused_conv_op, conv_output);
}; };
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.op_name()); conv_pattern.op_name());
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.input_name()); conv_pattern.input_name());
auto conv_filter = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto conv_filter = patterns::GetNodeFromSubgraph(
conv_pattern.filter_name()); subgraph, pattern_ptr, conv_pattern.filter_name());
auto conv_output = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto conv_output = patterns::GetNodeFromSubgraph(
conv_pattern.output_name()); subgraph, pattern_ptr, conv_pattern.output_name());
auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto elementwise_add_op = patterns::GetNodeFromSubgraph(
elementwise_add_pattern.op_name()); subgraph, pattern_ptr, elementwise_add_pattern.op_name());
auto elementwise_add_x = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto elementwise_add_x = patterns::GetNodeFromSubgraph(
elementwise_add_pattern.x_name()); subgraph, pattern_ptr, elementwise_add_pattern.x_name());
auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto elementwise_add_out = patterns::GetNodeFromSubgraph(
elementwise_add_pattern.out_name()); subgraph, pattern_ptr, elementwise_add_pattern.out_name());
fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x); fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output); patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
...@@ -219,4 +229,5 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { ...@@ -219,4 +229,5 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass); REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#include "paddle/fluid/framework/ir/graph_traits.h" //
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -33,10 +47,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, ...@@ -33,10 +47,11 @@ void SetOp(ProgramDesc* prog, const std::string& type,
} }
struct IsReachable { struct IsReachable {
using func = std::function<bool (const std::string&, const std::string&)>; using func = std::function<bool(const std::string&, const std::string&)>;
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func { auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
auto find_node = [](const std::unique_ptr<ir::Graph>& graph, const std::string& name) -> Node* { auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(*graph)) { for (auto& node : GraphTraits::DFS(*graph)) {
if (name == node.Name()) { if (name == node.Name()) {
return &node; return &node;
...@@ -47,8 +62,7 @@ struct IsReachable { ...@@ -47,8 +62,7 @@ struct IsReachable {
}; };
return [&](std::string from, const std::string to) -> bool { return [&](std::string from, const std::string to) -> bool {
if (from == to) if (from == to) return true;
return true;
std::map<std::string, bool> visited; std::map<std::string, bool> visited;
...@@ -61,16 +75,14 @@ struct IsReachable { ...@@ -61,16 +75,14 @@ struct IsReachable {
std::list<std::string> queue; std::list<std::string> queue;
queue.push_back(from); queue.push_back(from);
while(!queue.empty()) { while (!queue.empty()) {
auto cur = find_node(graph, queue.front()); auto cur = find_node(graph, queue.front());
queue.pop_front(); queue.pop_front();
if (cur == nullptr) if (cur == nullptr) return false;
return false;
for (auto n : cur->outputs) { for (auto n : cur->outputs) {
if (n->Name() == to) if (n->Name() == to) return true;
return true;
if (!visited[n->Name()]) { if (!visited[n->Name()]) {
visited[n->Name()] = true; visited[n->Name()] = true;
...@@ -87,14 +99,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { ...@@ -87,14 +99,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto build_program_desc = [&]() -> ProgramDesc { auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : for (auto& v :
std::vector<std::string>({"a", "b", "weights", "c", "d", "e"})) { std::vector<std::string>({"a", "b", "weights", "c", "d", "e"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR); var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights") { if (v == "weights") {
var->SetPersistable(true); var->SetPersistable(true);
} }
} }
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"}); SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}); SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
SetOp(&prog, "relu", {"d"}, {"e"}); SetOp(&prog, "relu", {"d"}, {"e"});
...@@ -109,14 +121,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { ...@@ -109,14 +121,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
EXPECT_TRUE(is_reachable(graph)("a", "relu")); EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size(); int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu")); EXPECT_TRUE(is_reachable(graph)("a", "relu"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
// Assert conv_relu op in newly generated graph // Assert conv_relu op in newly generated graph
int conv_count = 0; int conv_count = 0;
int elementwise_add_count = 0; int elementwise_add_count = 0;
...@@ -136,15 +150,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { ...@@ -136,15 +150,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
auto build_program_desc = [&]() -> ProgramDesc { auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : for (auto& v : std::vector<std::string>({"a", "b", "weights"})) {
std::vector<std::string>({"a", "b", "weights"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR); var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights" || v == "bias") { if (v == "weights" || v == "bias") {
var->SetPersistable(true); var->SetPersistable(true);
} }
} }
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"}); SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}); SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
...@@ -157,14 +170,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { ...@@ -157,14 +170,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
IsReachable is_reachable; IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "d")); EXPECT_TRUE(is_reachable(graph)("a", "d"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size(); int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
EXPECT_FALSE(is_reachable(graph)("a", "d")); EXPECT_FALSE(is_reachable(graph)("a", "d"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
// Assert conv_relu op in newly generated graph // Assert conv_relu op in newly generated graph
int conv_count = 0; int conv_count = 0;
int elementwise_add_count = 0; int elementwise_add_count = 0;
...@@ -185,14 +200,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { ...@@ -185,14 +200,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
auto build_program_desc = [&]() -> ProgramDesc { auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : for (auto& v :
std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f"})) { std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR); var->SetType(proto::VarType::LOD_TENSOR);
if (v.find("weights")) { if (v.find("weights")) {
var->SetPersistable(true); var->SetPersistable(true);
} }
} }
SetOp(&prog, "sigmoid", {"a"}, {"b"}); SetOp(&prog, "sigmoid", {"a"}, {"b"});
SetOp(&prog, "conv2d", {"b", "weights"}, {"c"}); SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"}); SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
...@@ -208,14 +223,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { ...@@ -208,14 +223,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
EXPECT_TRUE(is_reachable(graph)("a", "f")); EXPECT_TRUE(is_reachable(graph)("a", "f"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size(); int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "f")); EXPECT_TRUE(is_reachable(graph)("a", "f"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num); EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
// Assert conv_relu op in newly generated graph // Assert conv_relu op in newly generated graph
int conv_count = 0; int conv_count = 0;
int elementwise_add_count = 0; int elementwise_add_count = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册