提交 46677fb0 编写于 作者: W Wojciech Uss

Move cpu_quantize_* passes into mkldnn subfolder

test=develop
上级 de3b70a1
......@@ -46,9 +46,6 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(cpu_quantize_placement_pass base)
pass_library(cpu_quantize_pass inference)
pass_library(cpu_quantize_squash_pass inference)
pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference)
......@@ -87,6 +84,9 @@ if(WITH_MKLDNN)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
......@@ -105,9 +105,6 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_cpu_quantize_placement_pass SRCS cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
if(NOT WIN32)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
endif()
......@@ -117,4 +114,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
endif ()
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/cpu_quantize_placement_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h"
#include <string>
#include <unordered_set>
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_placement_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h"
#include <gtest/gtest.h>
#include <boost/logic/tribool.hpp>
......
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
......
......@@ -41,8 +41,11 @@ namespace inference {
namespace analysis {
using framework::ir::Graph;
#ifdef PADDLE_WITH_MKLDNN
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, framework::LoDTensor>>;
#endif
/*
* The argument definition of both Pass and PassManagers.
......@@ -132,6 +135,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
std::unordered_set<std::string>);
#ifdef PADDLE_WITH_MKLDNN
// A set of op types to enable their quantized kernels
DECL_ARGUMENT_FIELD(quantize_enabled_op_types, QuantizeEnabledOpTypes,
std::unordered_set<std::string>);
......@@ -142,6 +146,7 @@ struct Argument {
// Scales for variables to be quantized
DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
#endif
// Passed from config.
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
......
......@@ -61,6 +61,7 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(
argument->mkldnn_enabled_op_types()));
#ifdef PADDLE_WITH_MKLDNN
} else if (pass_name == "cpu_quantize_placement_pass") {
pass->Set("quantize_enabled_op_types",
new std::unordered_set<std::string>(
......@@ -71,6 +72,7 @@ void IRPassManager::CreatePasses(Argument *argument,
} else if (pass_name == "cpu_quantize_pass") {
pass->Set("quant_var_scales",
new VarQuantScale(argument->quant_var_scales()));
#endif
} else if (pass_name == "tensorrt_subgraph_pass") {
pass->Set("workspace_size", new int(argument->tensorrt_workspace_size()));
pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size()));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册