diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 09bec347dbd569203103eccc7dbc0521c291bc0a..fb899e3d7cd4224acd25a559d0e18a09f552ad7d 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -44,9 +44,9 @@ if(WIN32) set(CUDNN_LIB_NAME "cudnn.lib" "cudnn64_7.dll") endif(WIN32) -if(Apple) +if(APPLE) set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so") -endif(Apple) +endif(APPLE) find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist} diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e3b44499258d288fe5692ca23efe1c4ec234f75c..b6974c6af290438f827c16bb478eb43e3cf42247 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -351,6 +351,23 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) +paddle.fluid.contrib.build_compressor ArgSpec(args=['place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'config'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None)) +paddle.fluid.contrib.CompressPass.__init__ ArgSpec(args=['self', 'place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'program_exe'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None)) +paddle.fluid.contrib.CompressPass.add_strategy ArgSpec(args=['self', 'strategy'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.CompressPass.apply ArgSpec(args=['self', 'graph'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.ImitationGraph.__init__ ArgSpec(args=['self', 'program'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.contrib.ImitationGraph.all_parameters ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.SensitivePruneStrategy.__init__ ArgSpec(args=['self', 'pruner', 'start_epoch', 'end_epoch', 'delta_rate', 'acc_loss_threshold', 'sensitivities'], varargs=None, keywords=None, defaults=(None, 0, 10, 0.2, 0.2, None)) +paddle.fluid.contrib.SensitivePruneStrategy.on_batch_begin ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.SensitivePruneStrategy.on_batch_end ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.SensitivePruneStrategy.on_compress_begin ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.SensitivePruneStrategy.on_compress_end ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.SensitivePruneStrategy.on_epoch_begin ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.SensitivePruneStrategy.on_epoch_end ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.MagnitudePruner.__init__ ArgSpec(args=['self', 'threshold'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.MagnitudePruner.prune ArgSpec(args=['self', 'param', 'threshold'], 
varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.contrib.RatioPruner.__init__ ArgSpec(args=['self', 'ratios'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.contrib.RatioPruner.prune ArgSpec(args=['self', 'param', 'ratio'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.contrib.load_persistables_for_increment ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None) paddle.fluid.contrib.load_persistables_for_inference ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var_name'], varargs=None, keywords=None, defaults=None) paddle.fluid.contrib.convert_dist_to_sparse_program ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc index 5fcb17b9f3ac390548aba33db7d0b8350cde7e00..42190b52289bfc6fc510f13cb5190a0d3e03b836 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -31,10 +31,12 @@ std::map<std::string, std::function<void(const std::shared_ptr<OperatorBase>&, std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>)>> NgraphBridge::NG_NODE_MAP = { + {"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode}, {"mul", paddle::operators::ngraphs::BuildMulNode}, {"mul_grad", paddle::operators::ngraphs::BuildMulGradNode}, {"relu", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Relu>}, - {"tanh", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Tanh>}}; + {"tanh", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Tanh>}, + {"top_k", paddle::operators::ngraphs::BuildTopKNode}}; void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) { auto& op_type = op->Type(); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index c751e8515829d06970c55f097f50de8bf33ee2a4..3937884ce4a5a16a1093ac8977033eaa98b2678e 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -231,11 +231,14 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs, inputs[i].data.length()); } else { #ifdef PADDLE_WITH_CUDA + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto *dev_ctx = + static_cast<const platform::CUDADeviceContext *>(pool.Get(place_)); auto dst_gpu_place = boost::get<platform::CUDAPlace>(place_); memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr), platform::CPUPlace(), inputs[i].data.data(), - inputs[i].data.length(), - 0); // stream 0 for sync copy + inputs[i].data.length(), dev_ctx->stream()); #else PADDLE_THROW("Not compile with CUDA, should not reach here."); #endif diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 3d121e046004dfe6fc6953e0b23852b9ecda5c1b..102147a493ed1454db1a78124200f163f68e555b 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -208,11 +208,14 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs, inputs[i].data.length()); } else { #ifdef PADDLE_WITH_CUDA + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto *dev_ctx = + static_cast<const platform::CUDADeviceContext *>(pool.Get(place_)); auto dst_gpu_place = boost::get<platform::CUDAPlace>(place_); memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr), platform::CPUPlace(), inputs[i].data.data(), - inputs[i].data.length(), - 0); // stream 0 for sync copy + inputs[i].data.length(), dev_ctx->stream()); #else PADDLE_THROW("Not compile with CUDA, should not reach here."); #endif diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index
46ce61b73611d05369f90e7d8f97e9b6724b860f..95bbc74a5961eb28a0d8fbd7c680c0740fc68d8a 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -75,6 +75,11 @@ set(LAC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lac") download_model_and_data(${LAC_INSTALL_DIR} "lac_model.tar.gz" "lac_data.txt.tar.gz") inference_analysis_api_test(test_analyzer_lac ${LAC_INSTALL_DIR} analyzer_lac_tester.cc) +# MM DNN +set(MM_DNN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mm_dnn") +download_model_and_data(${MM_DNN_INSTALL_DIR} "MM_DNN_model.tar.gz" "MM_DNN_data.txt.tar.gz") +inference_analysis_api_test(test_analyzer_mm_dnn ${MM_DNN_INSTALL_DIR} analyzer_mm_dnn_tester.cc) + # text_classification set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification") download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz") diff --git a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..8aaab6d6649e1d4b6db7695df0e9dd219c89422c --- /dev/null +++ b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc @@ -0,0 +1,178 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { +using contrib::AnalysisConfig; + +struct DataRecord { + std::vector<std::vector<int64_t>> query_data_all, title_data_all; + std::vector<size_t> lod1, lod2; + size_t batch_iter{0}; + size_t batch_size{1}; + size_t num_samples; // total number of samples + DataRecord() = default; + explicit DataRecord(const std::string &path, int batch_size = 1) + : batch_size(batch_size) { + Load(path); + } + DataRecord NextBatch() { + DataRecord data; + size_t batch_end = batch_iter + batch_size; + // NOTE skip the final batch, if not enough data is provided.
+ if (batch_end <= query_data_all.size()) { + data.query_data_all.assign(query_data_all.begin() + batch_iter, + query_data_all.begin() + batch_end); + data.title_data_all.assign(title_data_all.begin() + batch_iter, + title_data_all.begin() + batch_end); + // Prepare LoDs + data.lod1.push_back(0); + data.lod2.push_back(0); + CHECK(!data.query_data_all.empty()); + CHECK(!data.title_data_all.empty()); + CHECK_EQ(data.query_data_all.size(), data.title_data_all.size()); + for (size_t j = 0; j < data.query_data_all.size(); j++) { + // calculate lod + data.lod1.push_back(data.lod1.back() + data.query_data_all[j].size()); + data.lod2.push_back(data.lod2.back() + data.title_data_all[j].size()); + } + } + batch_iter += batch_size; + return data; + } + void Load(const std::string &path) { + std::ifstream file(path); + std::string line; + int num_lines = 0; + while (std::getline(file, line)) { + num_lines++; + std::vector<std::string> data; + split(line, '\t', &data); + // load query data + std::vector<int64_t> query_data; + split_to_int64(data[0], ' ', &query_data); + // load title data + std::vector<int64_t> title_data; + split_to_int64(data[1], ' ', &title_data); + query_data_all.push_back(std::move(query_data)); + title_data_all.push_back(std::move(title_data)); + } + num_samples = num_lines; + } +}; + +void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data, + int batch_size) { + PaddleTensor lod_query_tensor, lod_title_tensor; + lod_query_tensor.name = "left"; + lod_title_tensor.name = "right"; + auto one_batch = data->NextBatch(); + int size1 = one_batch.lod1[one_batch.lod1.size() - 1]; // token batch size + int size2 = one_batch.lod2[one_batch.lod2.size() - 1]; // token batch size + lod_query_tensor.shape.assign({size1, 1}); + lod_query_tensor.lod.assign({one_batch.lod1}); + lod_title_tensor.shape.assign({size2, 1}); + lod_title_tensor.lod.assign({one_batch.lod2}); + // assign data + TensorAssignData<int64_t>(&lod_query_tensor, one_batch.query_data_all); + TensorAssignData<int64_t>(&lod_title_tensor, one_batch.title_data_all); + // Set inputs. + input_slots->assign({lod_query_tensor, lod_title_tensor}); + for (auto &tensor : *input_slots) { + tensor.dtype = PaddleDType::INT64; + } +} + +void SetConfig(contrib::AnalysisConfig *cfg) { + cfg->model_dir = FLAGS_infer_model; + cfg->use_gpu = false; + cfg->device = 0; + cfg->specify_input_name = true; + cfg->enable_ir_optim = true; +} + +void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { + DataRecord data(FLAGS_infer_data, FLAGS_batch_size); + std::vector<PaddleTensor> input_slots; + int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1; + LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size; + for (int bid = 0; bid < epoch; ++bid) { + PrepareInputs(&input_slots, &data, FLAGS_batch_size); + (*inputs).emplace_back(input_slots); + } +} + +// Easy for profiling independently. +TEST(Analyzer_MM_DNN, profile) { + contrib::AnalysisConfig cfg; + SetConfig(&cfg); + std::vector<PaddleTensor> outputs; + + std::vector<std::vector<PaddleTensor>> input_slots_all; + SetInput(&input_slots_all); + TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg), + input_slots_all, &outputs, FLAGS_num_threads); + + if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { + PADDLE_ENFORCE_EQ(outputs.size(), 2UL); + for (auto &output : outputs) { + size_t size = GetSize(output); + PADDLE_ENFORCE_GT(size, 0); + float *result = static_cast<float *>(output.data.data()); + // output is probability, which is in (-1, 1).
+ for (size_t i = 0; i < size; i++) { + EXPECT_GT(result[i], -1); + EXPECT_LT(result[i], 1); + } + } + } +} + +// Check the fuse status +TEST(Analyzer_MM_DNN, fuse_statis) { + contrib::AnalysisConfig cfg; + SetConfig(&cfg); + + int num_ops; + auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg); + auto fuse_statis = GetFuseStatis( + static_cast<AnalysisPredictor *>(predictor.get()), &num_ops); +} + +// Compare result of NativeConfig and AnalysisConfig +TEST(Analyzer_MM_DNN, compare) { + contrib::AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector<std::vector<PaddleTensor>> input_slots_all; + SetInput(&input_slots_all); + CompareNativeAndAnalysis( + reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); +} + +// Compare Deterministic result +TEST(Analyzer_MM_DNN, compare_determine) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector<std::vector<PaddleTensor>> input_slots_all; + SetInput(&input_slots_all); + CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg), + input_slots_all); +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/dequantize_mkldnn_op.cc b/paddle/fluid/operators/dequantize_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..262b7408a7f5f65c4d97120914c16f38ce5fdbe7 --- /dev/null +++ b/paddle/fluid/operators/dequantize_mkldnn_op.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#include "mkldnn.hpp" +#include "paddle/fluid/framework/data_layout_transform.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/dequantize_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using mkldnn::memory; +using mkldnn::primitive; +using mkldnn::reorder; +using platform::to_void_cast; +using Tensor = framework::Tensor; +using framework::DataLayout; +using mkldnn::stream; +using platform::GetMKLDNNFormat; + +template +class DeQuantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto scale_data = ctx.Attr("Scale"); + auto* output = ctx.Output("Output"); + auto& dev_ctx = + ctx.template device_context(); + const auto& engine = dev_ctx.GetEngine(); + + const T* input_data = input->data(); + float* output_data = output->mutable_data(ctx.GetPlace()); + std::vector reorder_scale = {1.0f / scale_data}; + + std::vector pipeline; + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + mkldnn::memory::data_type src_dt = + paddle::framework::ToMKLDNNDataType(input->type()); + mkldnn::memory::format src_fmt = input->format(); + + mkldnn::primitive_attr attri; + int mask = 0; + attri.set_output_scales(mask, reorder_scale); + + auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt); + auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); + auto src_memory = + std::make_shared(src_pd, to_void_cast(input_data)); + std::shared_ptr src_memory_p = + std::shared_ptr(new primitive::at(*src_memory)); + + auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32, + memory::format::nchw); + auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); + auto dst_memory = mkldnn::memory(dst_pd, to_void_cast(output_data)); + + auto reorder_pd = std::shared_ptr( + new reorder::primitive_desc(src_pd, dst_pd, attri)); + auto reorder_p = std::shared_ptr( + new reorder(*reorder_pd, *src_memory_p, dst_memory)); + pipeline.push_back(*reorder_p); + stream(stream::kind::eager).submit(pipeline).wait(); + + output->set_format(GetMKLDNNFormat(dst_memory)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, + ops::DeQuantOpKernel, ops::DeQuantOpKernel); diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..38159f84a0d56f45cfef233a3c70c3c6cef17d9f --- /dev/null +++ b/paddle/fluid/operators/dequantize_op.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/dequantize_op.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +framework::OpKernelType DeQuantOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library_ = framework::LibraryType::kMKLDNN; + framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; + + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout_, library_); +} + +void DeQuantOpMaker::Make() { + AddInput("Input", "input data"); + AddOutput("Output", "output data"); + AddAttr("Scale", "scale data").SetDefault({1.0f}); + AddComment(R"DOC(This op will dequantize data from INT8 to FP32)DOC"); +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, + paddle::framework::DefaultGradOpDescMaker); diff --git a/paddle/fluid/operators/dequantize_op.h b/paddle/fluid/operators/dequantize_op.h new file mode 100644 index 0000000000000000000000000000000000000000..75c27a06c210f2d0e4d7cf52aa16f4c123f8ad8e --- /dev/null +++ b/paddle/fluid/operators/dequantize_op.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class DeQuantOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class DeQuantOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +class DeQuantGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cu b/paddle/fluid/operators/detection/density_prior_box_op.cu index 6a92762896b89a06a91cd11fb38587f7df69e6c3..acd5993154ed03f206f20082231feb5059ef32e1 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cu +++ b/paddle/fluid/operators/detection/density_prior_box_op.cu @@ -142,12 +142,13 @@ class DensityPriorBoxOpCUDAKernel : public framework::OpKernel { vars->mutable_data(ctx.GetPlace()); framework::Tensor d_temp; - framework::TensorCopySync(h_temp, ctx.GetPlace(), &d_temp); + framework::TensorCopy(h_temp, ctx.GetPlace(), &d_temp); // At least use 32 threads, at most 512 threads. // blockx is multiple of 32. 
int blockx = std::min( - static_cast<int64_t>(((feature_width * num_priors + 31) >> 5) << 5), 512L); + static_cast<int64_t>(((feature_width * num_priors + 31) >> 5) << 5), + 512L); int gridx = (feature_width * num_priors + blockx - 1) / blockx; dim3 threads(blockx, 1); dim3 grids(gridx, feature_height); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 3eba268cfa9712e4bc5475dd44076bc768552bce..1a11b584e2bab7eeb395bf391da080ec0ba62ae4 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include <algorithm> #include <set> #include <vector> @@ -252,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas, template <typename T> struct MergeAdd<platform::CPUDeviceContext, T> { framework::SelectedRows operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input) { + const framework::SelectedRows& input, + const bool sorted_result = false) { framework::SelectedRows out; - (*this)(context, input, &out); + (*this)(context, input, &out, sorted_result); return out; } void operator()(const platform::CPUDeviceContext& context, const framework::SelectedRows& input, - framework::SelectedRows* output) { + framework::SelectedRows* output, + const bool sorted_result = false) { std::vector<const framework::SelectedRows*> inputs; inputs.push_back(&input); - (*this)(context, inputs, output); + (*this)(context, inputs, output, sorted_result); } void operator()(const platform::CPUDeviceContext& context, const std::vector<const framework::SelectedRows*>& inputs, - framework::SelectedRows* output) { + framework::SelectedRows* output, + const bool sorted_result = false) { if (inputs.size() == 0) { VLOG(3) << "no input! return"; return; @@ -301,6 +305,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> { } std::vector<int64_t> merge_rows(merged_row_set.begin(), merged_row_set.end()); + if (sorted_result) { + std::sort(merge_rows.begin(), merge_rows.end()); + } std::unordered_map<int64_t, size_t> rows_to_id; for (size_t i = 0; i < merge_rows.size(); ++i) { rows_to_id[merge_rows[i]] = i; diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index c4fccdbf862fda8a599869c30ae598573ca367aa..0d63f641c8670f8629c52b9e5fc380a250d80dd7 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows, template <typename T> struct MergeAdd<platform::CUDADeviceContext, T> { framework::SelectedRows operator()(const platform::CUDADeviceContext& context, - const framework::SelectedRows& input) { + const framework::SelectedRows& input, + const bool sorted_result = false) { framework::SelectedRows out; (*this)(context, input, &out); return out; @@ -274,7 +275,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> { void operator()(const platform::CUDADeviceContext& context, const framework::SelectedRows& input, - framework::SelectedRows* output) { + framework::SelectedRows* output, + const bool sorted_result = false) { framework::Vector<int64_t> input_rows(input.rows()); if (input_rows.size() == 0) { return; @@ -312,7 +314,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> { void operator()(const platform::CUDADeviceContext& context, const std::vector<const framework::SelectedRows*>& inputs, - framework::SelectedRows* output) { + framework::SelectedRows* output, + const bool sorted_result = false) { if (inputs.size() == 0) { VLOG(3) << "no input!
return"; return; diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index 6d146d39d6d07678e859b82b25ba60ed7661546d..222d761ef91d8aee4843d717dabba7edf131f8dc 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -81,13 +81,16 @@ struct MergeAdd { // unary functor, merge by adding duplicated rows in // the input SelectedRows object. framework::SelectedRows operator()(const DeviceContext& context, - const framework::SelectedRows& input); + const framework::SelectedRows& input, + const bool sorted_result = false); void operator()(const DeviceContext& context, const framework::SelectedRows& input, - framework::SelectedRows* output); + framework::SelectedRows* output, + const bool sorted_result = false); void operator()(const DeviceContext& context, const std::vector& inputs, - framework::SelectedRows* output); + framework::SelectedRows* output, + const bool sorted_result = false); }; enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; diff --git a/paddle/fluid/operators/ngraph/ngraph_ops.h b/paddle/fluid/operators/ngraph/ngraph_ops.h index 0ed77ff5577cf4f45a8865db9b42e8bda9839478..8e7457dd56c2413f84008ce467537e07b3e80cc7 100644 --- a/paddle/fluid/operators/ngraph/ngraph_ops.h +++ b/paddle/fluid/operators/ngraph/ngraph_ops.h @@ -22,4 +22,6 @@ limitations under the License. */ #pragma once #include "ops/binary_unnary_op.h" +#include "ops/fill_constant_op.h" #include "ops/mul_op.h" +#include "ops/top_k_op.h" diff --git a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h index 4e2f5e231c16cd0fad6db287aa19430c56b534fd..6610380fcf432d0019f7e844fa9304e151b20efd 100644 --- a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h +++ b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h @@ -45,7 +45,6 @@ static void BuildUnaryNode( auto out = std::make_shared(input); paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); } - } // namespace ngraphs } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5eff69e7b165fa19c775926914b7b3e8fcb043e5 --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h @@ -0,0 +1,61 @@ +/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef PADDLE_WITH_NGRAPH +#pragma once + +#include <string> +#include <vector> +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +void BuildFillConstantNode( + const std::shared_ptr<paddle::framework::OperatorBase>& op, + std::shared_ptr< + std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> + ngb_node_map) { + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + auto vsp = op_attrs.Get<std::vector<int64_t>>("shape"); + ngraph::Shape shape; + for (auto& sp : vsp) { + shape.push_back(sp); + } + float value = op_attrs.Get<float>("value"); + ngraph::element::Type ng_dtype; + auto data_type = static_cast<paddle::framework::proto::VarType::Type>( + op_attrs.Get<int>("dtype")); + if (data_type == paddle::framework::proto::VarType::FP32) { + ng_dtype = ngraph::element::f32; + } else if (data_type == paddle::framework::proto::VarType::FP64) { + ng_dtype = ngraph::element::f64; + } else if (data_type == paddle::framework::proto::VarType::INT64) { + ng_dtype = ngraph::element::i64; + } else if (data_type == paddle::framework::proto::VarType::INT32) { + ng_dtype = ngraph::element::i32; + } else if (data_type == paddle::framework::proto::VarType::BOOL) { + ng_dtype = ngraph::element::boolean; + } else { + PADDLE_THROW("unsupported data type: %s", data_type); + } + auto out = ngraph::op::Constant::create(ng_dtype, shape, {value}); + paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle +#endif diff --git a/paddle/fluid/operators/ngraph/ops/top_k_op.h b/paddle/fluid/operators/ngraph/ops/top_k_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2b7254497c0e1aab2e653e69e6461f262b929703 --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/top_k_op.h @@ -0,0 +1,51 @@ +/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#ifdef PADDLE_WITH_NGRAPH +#pragma once + +#include <string> +#include "ngraph/ngraph.hpp" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +void BuildTopKNode( + const std::shared_ptr<paddle::framework::OperatorBase>& op, + std::shared_ptr< + std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> + ngb_node_map) { + auto op_attrs = paddle::framework::AttrReader(op->Attrs()); + int k = op_attrs.Get<int>("k"); + auto input = paddle::platform::GetInputNode(op, "X", ngb_node_map); + auto top_k = std::make_shared<ngraph::op::TopK>( + input, input->get_shape().size() - 1, ngraph::element::i64, k); + std::shared_ptr<ngraph::Node> indices = + std::make_shared<ngraph::op::GetOutputElement>(top_k, 0); + std::shared_ptr<ngraph::Node> out = + std::make_shared<ngraph::op::GetOutputElement>(top_k, 1); + auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map); + if (dummy_out && dummy_out->get_element_type() != out->get_element_type()) { + out = std::make_shared<ngraph::op::Convert>(out, + dummy_out->get_element_type()); + } + paddle::platform::SetOutputNode(op, "Indices", indices, ngb_node_map); + paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle +#endif diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index f214d8272f5cc5f1cb2e32c9bb59ca60a1066500..1138bb7400e0e7a00983e7bfaad2b2d9704b77ab 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -157,8 +157,11 @@ struct AdamFunctor { } }; +template <typename T, typename Flag> +struct SparseAdamFunctor; + template <typename T> -struct SparseAdamFunctor { +struct SparseAdamFunctor<T, GPUAdam> { T beta1_; T beta2_; T epsilon_; @@ -236,6 +239,106 @@ struct SparseAdamFunctor { } }; +template <typename T> +struct SparseAdamFunctor<T, CPUAdam> { + T beta1_; + T beta2_; + T epsilon_; + + const T* beta1_pow_; + const T* beta2_pow_; + const T* moment1_; + T* moment1_out_; + const T* moment2_; + T* moment2_out_; + const T* lr_; + const T* grad_; + const T* param_; + T* param_out_; + + const int64_t* rows_; + int64_t row_numel_; + int64_t row_count_; + + SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, + const T* beta2_pow, const T* mom1, T* mom1_out, + const T* mom2, T* mom2_out, const T* lr, const T* grad, + const T* param, T* param_out, const int64_t* rows, + int64_t row_numel, int64_t row_count, bool lazy_mode) + : beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + lr_(lr), + grad_(grad), + param_(param), + param_out_(param_out), + rows_(rows), + row_numel_(row_numel), + row_count_(row_count) {} + + inline HOSTDEVICE void adam_update(size_t i, T g) const { + // The following code is the same as dense + T mom1 = moment1_[i]; + T mom2 = moment2_[i]; + T lr = *lr_; + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; + T p = param_[i]; + + // Calculation + lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); + + mom1 = beta1_ * mom1 + (1 - beta1_) * g; + mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + + // Write back to global memory + moment1_out_[i] = mom1; + moment2_out_[i] = mom2; + param_out_[i] = p; + } + + inline void operator()(size_t numel) const { + // lr could be reused + T lr = *lr_; + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; + lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); + size_t row_count = numel / row_numel_; + + for (size_t i = 0U, j = 0U; i != row_count; ++i) { + if (i == *(rows_ + j)) { + for (size_t k = 0U; k != row_numel_; ++k)
{ + T g = grad_[j * row_numel_ + k]; + adam_update(i * row_numel_ + k, g); + } + ++j; + } else { + for (size_t k = 0U; k != row_numel_; ++k) { + T mom1 = moment1_[i * row_numel_ + k]; + T mom2 = moment2_[i * row_numel_ + k]; + T p = param_[i * row_numel_ + k]; + + mom1 = beta1_ * mom1; + mom2 = beta2_ * mom2; + + p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + // Write back to global memory + moment1_out_[i * row_numel_ + k] = mom1; + moment2_out_[i * row_numel_ + k] = mom2; + param_out_[i * row_numel_ + k] = p; + } + } + } + } +}; + template <typename DeviceContext, typename T> class AdamOpKernel : public framework::OpKernel<T> { public: @@ -331,7 +434,7 @@ class AdamOpKernel : public framework::OpKernel<T> { .Var() ->GetMutable<framework::SelectedRows>(); merge_func(ctx.template device_context<DeviceContext>(), grad, - grad_merge_var); + grad_merge_var, true); grad_merge_ptr = grad_merge_var; } @@ -347,32 +450,46 @@ class AdamOpKernel : public framework::OpKernel<T> { } else { #endif rows = grad_merge.rows().data(); - #if defined(PADDLE_WITH_CUDA) } #endif auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); - SparseAdamFunctor<T> functor( - beta1, beta2, epsilon, beta1_pow.template data<T>(), - beta2_pow.template data<T>(), mom1.template data<T>(), - mom1_out.template mutable_data<T>(ctx.GetPlace()), - mom2.template data<T>(), - mom2_out.template mutable_data<T>(ctx.GetPlace()), - lr.template data<T>(), grad_data, param.template data<T>(), - param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel, - grad_merge.rows().size(), lazy_mode); - VLOG(3) << "lazy_mode :" << lazy_mode; - if (lazy_mode && platform::is_cpu_place(ctx.GetPlace())) { - size_t row_count = grad_merge.rows().size(); - std::vector<int64_t> cpu_rows(grad_merge.rows()); - for (size_t row_index = 0; row_index < row_count; ++row_index) { - for (size_t offset = 0; offset < row_numel; ++offset) { - size_t i = cpu_rows[row_index] * row_numel + offset; - functor.adam_update(i, grad_data[row_index * row_numel + offset]); + if (platform::is_cpu_place(ctx.GetPlace())) { + SparseAdamFunctor<T, CPUAdam> functor( + beta1, beta2, epsilon, beta1_pow.template data<T>(), + beta2_pow.template data<T>(), mom1.template data<T>(), + mom1_out.template mutable_data<T>(ctx.GetPlace()), + mom2.template data<T>(), + mom2_out.template mutable_data<T>(ctx.GetPlace()), + lr.template data<T>(), grad_data, param.template data<T>(), + param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel, + grad_merge.rows().size(), lazy_mode); + + if (lazy_mode) { + size_t row_count = grad_merge.rows().size(); + std::vector<int64_t> cpu_rows(grad_merge.rows()); + for (size_t row_index = 0; row_index < row_count; ++row_index) { + for (size_t offset = 0; offset < row_numel; ++offset) { + size_t i = cpu_rows[row_index] * row_numel + offset; + functor.adam_update(i, grad_data[row_index * row_numel + offset]); + } } + } else { + functor(param.numel()); } - } else { + } else if (platform::is_gpu_place(ctx.GetPlace())) { + SparseAdamFunctor<T, GPUAdam> functor( + beta1, beta2, epsilon, beta1_pow.template data<T>(), + beta2_pow.template data<T>(), mom1.template data<T>(), + mom1_out.template mutable_data<T>(ctx.GetPlace()), + mom2.template data<T>(), + mom2_out.template mutable_data<T>(ctx.GetPlace()), + lr.template data<T>(), grad_data, param.template data<T>(), + param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel, + grad_merge.rows().size(), lazy_mode); + + // FIXME(minqiyang): remove BinarySearch in GPU later platform::ForRange<DeviceContext> for_range( static_cast<const DeviceContext&>(ctx.device_context()), param.numel()); diff --git a/paddle/fluid/operators/quantize_mkldnn_op.cc b/paddle/fluid/operators/quantize_mkldnn_op.cc new file mode 100644 index
0000000000000000000000000000000000000000..0638e42873376bcec6e4de61494da46d1f0073d1 --- /dev/null +++ b/paddle/fluid/operators/quantize_mkldnn_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "mkldnn.hpp" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/quantize_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using mkldnn::memory; +using mkldnn::primitive; +using mkldnn::reorder; +using platform::to_void_cast; +using Tensor = framework::Tensor; +using framework::DataLayout; +using mkldnn::stream; +using platform::GetMKLDNNFormat; + +template <typename T> +class QuantOpKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input<Tensor>("Input"); + auto scale_data = ctx.Attr<float>("Scale"); + auto* output = ctx.Output<Tensor>("Output"); + auto& dev_ctx = + ctx.template device_context<platform::MKLDNNDeviceContext>(); + const auto& engine = dev_ctx.GetEngine(); + + std::vector<primitive> pipeline; + std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); + + const T* input_data = input->data<T>(); + + mkldnn::primitive_attr attri; + int mask = 0; + attri.set_output_scales(mask, {scale_data}); + + auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32, + input->format()); + auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); + auto src_memory = + std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data)); + std::shared_ptr<primitive::at> src_memory_p = + std::shared_ptr<primitive::at>(new primitive::at(*src_memory)); + + bool is_negative = ctx.Attr<bool>("is_negative_input"); + std::shared_ptr<mkldnn::memory::primitive_desc> dst_pd; + std::shared_ptr<mkldnn::memory> dst_memory; + if (is_negative) { + platform::ConvMKLDNNHandler::SetDstMemory<int8_t>( + ctx, output, dst_tz, engine, dst_pd, dst_memory); + } else { + platform::ConvMKLDNNHandler::SetDstMemory<uint8_t>( + ctx, output, dst_tz, engine, dst_pd, dst_memory); + } + auto reorder_pd = std::shared_ptr<reorder::primitive_desc>( + new reorder::primitive_desc(src_pd, *dst_pd, attri)); + auto reorder_p = std::shared_ptr<reorder>( + new reorder(*reorder_pd, *src_memory_p, *dst_memory)); + pipeline.push_back(*reorder_p); + stream(stream::kind::eager).submit(pipeline).wait(); + output->set_layout(DataLayout::kMKLDNN); + output->set_format(GetMKLDNNFormat(*dst_memory)); + } +}; +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +// TODO(Xiaoli) Support FP32->S8 quantization. + +REGISTER_OP_KERNEL(quantize, MKLDNN, ::paddle::platform::CPUPlace, + ops::QuantOpKernel<float>); diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf70c08bdb82218a2d0f63f3e70a2a1093e6a542 --- /dev/null +++ b/paddle/fluid/operators/quantize_op.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/quantize_op.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +framework::OpKernelType QuantOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library_ = framework::LibraryType::kMKLDNN; + framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; + + return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), + ctx.GetPlace(), layout_, library_); +} + +void QuantOpMaker::Make() { + AddInput("Input", "input data"); + AddOutput("Output", "output data"); + AddAttr<bool>("is_negative_input", + "(bool, default false) Only used in mkldnn INT8 kernel") + .SetDefault(false); + AddAttr<float>("Scale", "scale data").SetDefault({1.0f}); + AddComment(R"DOC(This op will quantize data from FP32 to INT8)DOC"); +} + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, + paddle::framework::DefaultGradOpDescMaker<true>); diff --git a/paddle/fluid/operators/quantize_op.h b/paddle/fluid/operators/quantize_op.h new file mode 100644 index 0000000000000000000000000000000000000000..091306e4637c7e2393b6736f0e1edf9dd7fd2c8a --- /dev/null +++ b/paddle/fluid/operators/quantize_op.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#pragma once + +#include <string> +#include <vector> +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class QuantOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class QuantOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 23f00406de6190dfef91e259d6af358b5dac1713..584df85e80203c383a89954aac73dd1dcd723f7c 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -15,6 +15,7 @@ limitations under the License. */ #include <string> #include <vector> +#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/place.h" @@ -181,6 +182,21 @@ class MKLDNNHandler { return dims2str(operand_dims) + suffix; } + template <typename M> + static void SetDstMemory( + const framework::ExecutionContext& ctx, framework::Tensor* output, + std::vector<int> dst_tz, const mkldnn::engine& engine, + std::shared_ptr<mkldnn::memory::primitive_desc>& dst_pd, // NOLINT + std::shared_ptr<mkldnn::memory>& dst_memory) { // NOLINT + M* output_data = output->mutable_data<M>(ctx.GetPlace()); + auto dst_md = platform::MKLDNNMemDesc( + {dst_tz}, paddle::framework::ToMKLDNNDataType( + framework::DataTypeTrait<M>::DataType), + mkldnn::memory::format::nhwc); + dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine)); + dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<M>(output_data))); + } + protected: static std::string dims2str(const mkldnn::memory::dims& operand_dims) { std::string dstr = ""; diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index ece97b661fd7d60f8822439a84ee4403b9e3d81c..24621110b18f63779da14edc42765aae3bf4abd6 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -22,6 +22,8 @@ from . import op_frequence from .op_frequence import * from . import quantize from .quantize import * +from . import slim +from .slim import * from . import utils from .utils import * @@ -30,4 +32,5 @@ __all__ += decoder.__all__ __all__ += memory_usage_calc.__all__ __all__ += op_frequence.__all__ __all__ += quantize.__all__ +__all__ += slim.__all__ __all__ += utils.__all__ diff --git a/python/paddle/fluid/contrib/slim/__init__.py b/python/paddle/fluid/contrib/slim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22dbf7c8b6bb2da7c310a20bdcbaffca248575b0 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .core import * +from .graph import * +from .prune import * +__all__ = [ + 'build_compressor', + 'CompressPass', + 'ImitationGraph', + 'SensitivePruneStrategy', + 'MagnitudePruner', + 'RatioPruner', +] diff --git a/python/paddle/fluid/contrib/slim/core/__init__.py b/python/paddle/fluid/contrib/slim/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7826d5830a6f7f6d42cb1275c2289695c080e52f --- /dev/null +++ b/python/paddle/fluid/contrib/slim/core/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import config +from .config import * +from . import compress_pass +from .compress_pass import * +from . import strategy +from .strategy import * +from . import pass_builder +from .pass_builder import * + +__all__ = config.__all__ + compress_pass.__all__ + strategy.__all__ + pass_builder.__all__ diff --git a/python/paddle/fluid/contrib/slim/core/compress_pass.py b/python/paddle/fluid/contrib/slim/core/compress_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c348b878a1df43d7fb909f506c8cf65366866f --- /dev/null +++ b/python/paddle/fluid/contrib/slim/core/compress_pass.py @@ -0,0 +1,129 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ....core import CPUPlace +from ..graph import get_executor + +__all__ = ['Context', 'CompressPass'] + + +class Context(object): + """ + The context in the process of compression. + Args: + exe: The executor used to execute graph. + graph: The graph to be compressed. + scope: The scope used to execute graph. + program_exe: The program_exe is used to execute the program + created for modifying the variables in scope. + """ + + def __init__(self, exe, graph, scope, program_exe=None): + # The total number of epochs to be trained.
+ self.epoch = 0 + # Current epoch + self.epoch_id = 0 + # Current batch + self.batch_id = 0 + self.exe = exe + self.graph = graph + self.scope = scope + self.program_exe = program_exe + + +class CompressPass(object): + """ + The pass used to compress a model. + Args: + place: The device used in compression. + data_reader: The data_reader used to run graph. + data_feeder: The data_feeder used to run graph. + scope: The scope used to run graph. + metrics: The metrics for evaluating model. + epoch: The total epochs of training in compression. + program_exe: The program_exe is used to execute the program + created for modifying the variables in scope. + """ + + def __init__(self, + place=None, + data_reader=None, + data_feeder=None, + scope=None, + metrics=None, + epoch=None, + program_exe=None): + self.strategies = [] + self.place = CPUPlace() if place is None else place + self.data_reader = data_reader + self.data_feeder = data_feeder + self.scope = scope + self.metrics = metrics + self.epoch = epoch + self.program_exe = program_exe + + def add_strategy(self, strategy): + """ + Add a strategy to the current compress pass. + Args: + strategy: The strategy to be added into the current compress pass. + """ + self.strategies.append(strategy) + self.epoch = max(strategy.end_epoch, self.epoch) + + def apply(self, graph): + """ + Compress a model. + Args: + graph: The target graph to be compressed. + """ + self.executor = get_executor(graph, self.place) + context = Context( + self.executor, graph, self.scope, program_exe=self.program_exe) + + for strategy in self.strategies: + strategy.on_compress_begin(context) + + for epoch in range(self.epoch): + + for strategy in self.strategies: + strategy.on_epoch_begin(context) + + for data in self.data_reader(): + + for strategy in self.strategies: + strategy.on_batch_begin(context) + fetches = None + if self.metrics: + fetches = self.metrics.values() + feed = None + if self.data_feeder: + feed = self.data_feeder.feed(data) + results = self.executor.run(graph, + fetches=fetches, + scope=self.scope, + feed=feed) + if results: + print("results: {}".format( + zip(self.metrics.keys(), results))) + for strategy in self.strategies: + strategy.on_batch_end(context) + context.batch_id += 1 + + for strategy in self.strategies: + strategy.on_epoch_end(context) + context.epoch_id += 1 + + for strategy in self.strategies: + strategy.on_compress_end(context) diff --git a/python/paddle/fluid/contrib/slim/core/config.py b/python/paddle/fluid/contrib/slim/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..811c45700376aff9883fe197007b582f63817f03 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/core/config.py @@ -0,0 +1,111 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import inspect +import funcsigs +import yaml +from collections import OrderedDict +from ..prune import * +from .compress_pass import * +from .strategy import * + +__all__ = ['ConfigFactory'] +"""This factory is used to create instances by loading and parsing a configuration file in YAML format. +""" + + +class ConfigFactory(object): + def __init__(self, config): + """Init a factory from a configuration file.""" + self.instances = {} + self.version = None + self._parse_config(config) + + def get_compress_pass(self): + """ + Get compress pass from factory. + """ + return self.instance('compress_pass') + + def instance(self, name): + """ + Get instance from factory. + """ + if name in self.instances: + return self.instances[name] + else: + return None + + def _new_instance(self, name, attrs): + if name not in self.instances: + class_ = globals()[attrs['class']] + sig = funcsigs.signature(class_.__init__) + keys = [ + param.name for param in sig.parameters.values() + if (param.kind == param.POSITIONAL_OR_KEYWORD) + ][1:] + keys = set(attrs.keys()).intersection(set(keys)) + args = {} + for key in keys: + value = attrs[key] + if isinstance(value, str) and value in self.instances: + value = self.instances[value] + args[key] = value + self.instances[name] = class_(**args) + return self.instances.get(name) + + def _parse_config(self, config): + assert config + with open(config, 'r') as config_file: + key_values = self._ordered_load(config_file) + for key in key_values: + # parse version + if key == 'version' and self.version is None: + self.version = int(key_values['version']) + assert self.version == int(key_values['version']) + + # parse pruners + if key == 'pruners' or key == 'strategies': + instances = key_values[key] + for name in instances: + self._new_instance(name, instances[name]) + + if key == 'compress_pass': + compress_pass = self._new_instance(key, key_values[key]) + for name in key_values[key]['strategies']: + strategy = self.instance(name) + compress_pass.add_strategy(strategy) + + if key == 'include': + for config_file in key_values[key]: + self._parse_config(config_file.strip()) + + def _ordered_load(self, + stream, + Loader=yaml.Loader, + object_pairs_hook=OrderedDict): + """ + See: https://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts + """ + + class OrderedLoader(Loader): + pass + + def construct_mapping(loader, node): + loader.flatten_mapping(node) + return object_pairs_hook(loader.construct_pairs(node)) + + OrderedLoader.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping) + return yaml.load(stream, OrderedLoader) diff --git a/python/paddle/fluid/contrib/slim/core/pass_builder.py b/python/paddle/fluid/contrib/slim/core/pass_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1ddc94e04f1d606292071ba7e5cc74fedd5d36 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/core/pass_builder.py @@ -0,0 +1,39 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .compress_pass import CompressPass
+from .config import ConfigFactory
+
+__all__ = ['build_compressor']
+
+
+def build_compressor(place=None,
+                     data_reader=None,
+                     data_feeder=None,
+                     scope=None,
+                     metrics=None,
+                     epoch=None,
+                     config=None):
+    if config is not None:
+        factory = ConfigFactory(config)
+        comp_pass = factory.get_compress_pass()
+    else:
+        comp_pass = CompressPass()
+    comp_pass.place = place
+    comp_pass.data_reader = data_reader
+    comp_pass.data_feeder = data_feeder
+    comp_pass.scope = scope
+    comp_pass.metrics = metrics
+    comp_pass.epoch = epoch
+    return comp_pass
diff --git a/python/paddle/fluid/contrib/slim/core/strategy.py b/python/paddle/fluid/contrib/slim/core/strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d98e98b0c390599acfaefeb0636a599b46d391
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/core/strategy.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['Strategy']
+
+
+class Strategy(object):
+    """
+    Base class for all strategies.
+    """
+
+    def __init__(self, start_epoch=0, end_epoch=10):
+        """
+        Args:
+            start_epoch: The first epoch to apply the strategy.
+            end_epoch: The last epoch to apply the strategy.
+        """
+        self.start_epoch = start_epoch
+        self.end_epoch = end_epoch
+
+    def on_compress_begin(self, context):
+        pass
+
+    def on_epoch_begin(self, context):
+        pass
+
+    def on_epoch_end(self, context):
+        pass
+
+    def on_batch_begin(self, context):
+        pass
+
+    def on_batch_end(self, context):
+        pass
+
+    def on_compress_end(self, context):
+        pass
diff --git a/python/paddle/fluid/contrib/slim/demo/filter_prune/config.yaml b/python/paddle/fluid/contrib/slim/demo/filter_prune/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ea888fa2c74a23b4769f75dce6a776afcca41a51
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/demo/filter_prune/config.yaml
@@ -0,0 +1,28 @@
+version: 1.0
+pruners:
+    pruner_1:
+        class: 'RatioPruner'
+        ratios:
+            'conv1_1.w': 0.3
+            'conv1_2.w': 0.4
+            '*': 0.9
+        group_dims:
+            '*': [1, 2, 3]
+        criterions:
+            '*': 'l1-norm'
+strategies:
+    strategy_1:
+        class: 'SensitivePruneStrategy'
+        pruner: 'pruner_1'
+        start_epoch: 0
+        end_epoch: 10
+        delta_rate: 0.20
+        acc_loss_threshold: 0.2
+        sensitivities:
+            'conv1_1.w': 0.4
+
+compress_pass:
+    class: 'CompressPass'
+    epoch: 100
+    strategies:
+        - strategy_1
diff --git a/python/paddle/fluid/contrib/slim/demo/filter_prune/demo.py b/python/paddle/fluid/contrib/slim/demo/filter_prune/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..21c59c0c9d2d9b76932ab6eeff73754940a3bfa0
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/demo/filter_prune/demo.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import paddle
+import os
+import sys
+from paddle.fluid.contrib.slim import CompressPass
+from paddle.fluid.contrib.slim import build_compressor
+from paddle.fluid.contrib.slim import ImitationGraph
+
+
+class LinearModel(object):
+    def __init__(self):
+        pass
+
+    def train(self):
+        train_program = fluid.Program()
+        startup_program = fluid.Program()
+        startup_program.random_seed = 10
+        with fluid.program_guard(train_program, startup_program):
+            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+            predict = fluid.layers.fc(input=x, size=1, act=None)
+            cost = fluid.layers.square_error_cost(input=predict, label=y)
+            avg_cost = fluid.layers.mean(cost)
+            eval_program = train_program.clone()
+            sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+            sgd_optimizer.minimize(avg_cost)
+
+        train_reader = paddle.batch(
+            paddle.dataset.uci_housing.train(), batch_size=1)
+        eval_reader = paddle.batch(
+            paddle.dataset.uci_housing.test(), batch_size=1)
+        place = fluid.CPUPlace()
+        train_feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+        eval_feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+        exe = fluid.Executor(place)
+        exe.run(startup_program)
+        train_metrics = {"loss": avg_cost.name}
+        eval_metrics = {"loss": avg_cost.name}
+
+        graph = ImitationGraph(train_program)
+        config = './config.yaml'
+        comp_pass = build_compressor(
+            place,
+            data_reader=train_reader,
+            data_feeder=train_feeder,
+            scope=fluid.global_scope(),
+            metrics=train_metrics,
+            epoch=1,
+            config=config)
+        comp_pass.apply(graph)
+
+
+if __name__ == "__main__":
+    model = LinearModel()
+    model.train()
diff --git a/python/paddle/fluid/contrib/slim/graph/__init__.py b/python/paddle/fluid/contrib/slim/graph/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d65472d193b639f0766e278ec14b5dc36c5d62bc
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/graph/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import executor
+from .executor import *
+from . import graph
+from .graph import *
+from . import graph_pass
+from .graph_pass import *
+__all__ = executor.__all__
+__all__ += graph.__all__
+__all__ += graph_pass.__all__
diff --git a/python/paddle/fluid/contrib/slim/graph/executor.py b/python/paddle/fluid/contrib/slim/graph/executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02c3af82013287bf19e1869cb60dc65239b720a
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/graph/executor.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from abc import abstractmethod
+from .... import executor
+from .graph import IRGraph, ImitationGraph
+
+__all__ = ['get_executor']
+
+
+class GraphExecutor(object):
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, place):
+        self.place = place
+
+    @abstractmethod
+    def run(self, graph, fetches=None, feed=None):
+        pass
+
+
+class IRGraphExecutor(GraphExecutor):
+    def run(self, graph, fetches, feed=None):
+        pass
+
+
+class ImitationGraphExecutor(GraphExecutor):
+    def __init__(self, place):
+        super(ImitationGraphExecutor, self).__init__(place)
+        self.exe = executor.Executor(place)
+
+    def run(self, graph, scope=None, fetches=None, feed=None):
+        assert isinstance(graph, ImitationGraph)
+        # Resolve fetch names into variables of the underlying program.
+        fetch_list = None
+        if fetches:
+            fetch_list = [
+                graph.program.global_block().var(name) for name in fetches
+            ]
+        results = self.exe.run(graph.program,
+                               scope=scope,
+                               fetch_list=fetch_list,
+                               feed=feed)
+        return results
+
+
+def get_executor(graph, place):
+    if isinstance(graph, ImitationGraph):
+        return ImitationGraphExecutor(place)
+    if isinstance(graph, IRGraph):
+        return IRGraphExecutor(place)
diff --git a/python/paddle/fluid/contrib/slim/graph/graph.py b/python/paddle/fluid/contrib/slim/graph/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d6b0702035d49189c0919f976ea3c0c52663547
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/graph/graph.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ....framework import Program
+
+__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
+
+
+class Graph(object):
+    """
+    Base class for all graphs.
+    """
+
+    def __init__(self):
+        pass
+
+    def all_parameters(self):
+        """
+        Return all the parameters in the current graph.
+ """ + pass + + +class ImitationGraph(Graph): + def __init__(self, program=None): + super(ImitationGraph, self).__init__() + self.program = Program() if program is None else program + + def all_parameters(self): + return self.program.global_block().all_parameters() + + +class IRGraph(Graph): + pass diff --git a/python/paddle/fluid/contrib/slim/graph/graph_pass.py b/python/paddle/fluid/contrib/slim/graph/graph_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..1db6c4f110daa44be7fcbcc36f47224797b6dc88 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/graph/graph_pass.py @@ -0,0 +1,42 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['GraphPass', 'PruneParameterPass'] + + +class GraphPass(object): + """ + Base class for all graph pass. + """ + + def __init__(self): + pass + + def apply(self, graph): + pass + + +class PruneParameterPass(GraphPass): + """ + Generate a graph for pruning parameters from target graph. + """ + + def __init__(self, pruned_params, thresholds): + super(PruneParameterPass, self).__init__() + self.pruned_params = pruned_params + self.thresholds = thresholds + self.default_threshold = thresholds['*'] + + def apply(self, graph): + pass diff --git a/python/paddle/fluid/contrib/slim/prune/__init__.py b/python/paddle/fluid/contrib/slim/prune/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..764a45bb130a9993015858f1cbdbc9f3b864bd5e --- /dev/null +++ b/python/paddle/fluid/contrib/slim/prune/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import pruner +from .pruner import * +from . import prune_strategy +from .prune_strategy import * + +__all__ = pruner.__all__ +__all__ += prune_strategy.__all__ diff --git a/python/paddle/fluid/contrib/slim/prune/prune_strategy.py b/python/paddle/fluid/contrib/slim/prune/prune_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..34c5107daa3cde10e7995902be37e34e19664da8 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/prune/prune_strategy.py @@ -0,0 +1,66 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..core.strategy import Strategy
+from ....framework import Program, program_guard
+from .... import layers
+import numpy as np
+
+__all__ = ['SensitivePruneStrategy', 'PruneStrategy']
+
+
+class SensitivePruneStrategy(Strategy):
+    def __init__(self,
+                 pruner=None,
+                 start_epoch=0,
+                 end_epoch=10,
+                 delta_rate=0.20,
+                 acc_loss_threshold=0.2,
+                 sensitivities=None):
+        super(SensitivePruneStrategy, self).__init__(start_epoch, end_epoch)
+        self.pruner = pruner
+        self.delta_rate = delta_rate
+        self.acc_loss_threshold = acc_loss_threshold
+        self.sensitivities = sensitivities
+
+
+class PruneStrategy(Strategy):
+    """
+    A strategy that iteratively prunes weights by a threshold or a ratio.
+    """
+
+    def __init__(self,
+                 pruner,
+                 mini_batch_pruning_frequency=1,
+                 start_epoch=0,
+                 end_epoch=10):
+        super(PruneStrategy, self).__init__(start_epoch, end_epoch)
+        self.pruner = pruner
+        self.mini_batch_pruning_frequency = mini_batch_pruning_frequency
+
+    def _trigger(self, context):
+        return (context.batch_id % self.mini_batch_pruning_frequency == 0 and
+                self.start_epoch <= context.epoch_id < self.end_epoch)
+
+    def on_batch_end(self, context):
+        if self._trigger(context):
+            prune_program = Program()
+            with program_guard(prune_program):
+                for param in context.graph.all_parameters():
+                    # Clone each parameter into the pruning program, compute
+                    # its mask, and write the pruned values back.
+                    prune_program.global_block().clone_variable(param)
+                    p = prune_program.global_block().var(param.name)
+                    zeros_mask = self.pruner.prune(p)
+                    pruned_param = p * zeros_mask
+                    layers.assign(input=pruned_param, output=param)
+            context.program_exe.run(prune_program, scope=context.scope)
diff --git a/python/paddle/fluid/contrib/slim/prune/pruner.py b/python/paddle/fluid/contrib/slim/prune/pruner.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca72bcb6f6004c18f3ec794850e0aeaecb92d7ac
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/prune/pruner.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from .... import layers
+
+__all__ = ['Pruner', 'MagnitudePruner', 'RatioPruner']
+
+
+class Pruner(object):
+    """
+    Base class of all pruners.
+    """
+
+    def __init__(self):
+        pass
+
+    def prune(self, param):
+        pass
+
+
+class MagnitudePruner(Pruner):
+    """
+    Pruner that prunes a parameter by a magnitude threshold.
+ """ + + def __init__(self, threshold): + self.threshold = threshold + + def prune(self, param, threshold=None): + if threshold is None: + thres = layers.fill_constant( + shape=[1], dtype='float32', value=self.threshold) + else: + thres = threshold + zeros_mask = layers.less_than(x=param, y=thres) + return zeros_mask + + +class RatioPruner(Pruner): + """ + Pruner used to pruning a parameter by ratio. + """ + + def __init__(self, ratios=None): + """ + Args: + ratios: dict with pair (paramer_name, pruned_ratio). + """ + self.ratios = ratios + + def prune(self, param, ratio=None): + """ + Args: + ratio: `ratio=40%` means pruning (1 - 40%) weights to zero. + """ + if ratio is None: + rat = self.ratios[ + param.name] if param.name in self.ratios else self.ratios['*'] + else: + rat = ratio + if rat < 1.0: + k = max(int(rat * np.prod(param.shape)), 1) + param_vec = layers.reshape(x=param, shape=[1, -1]) + param_topk, _ = layers.topk(param_vec, k=k) + threshold = layers.slice( + param_topk, axes=[1], starts=[-1], ends=[k]) + threshold = layers.reshape(x=threshold, shape=[1]) + zeros_mask = layers.less_than(x=param, y=threshold) + else: + zeros_mask = layers.ones(param.shape) + return zeros_mask diff --git a/python/paddle/fluid/contrib/slim/unitest/__init__.py b/python/paddle/fluid/contrib/slim/unitest/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d41233e227dc7bab94ee4284cc25e12b45bf469 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/unitest/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/python/paddle/fluid/contrib/slim/unitest/configs/config.yaml b/python/paddle/fluid/contrib/slim/unitest/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..db488b96330210df15b02b19d90abd5c9101f844
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/unitest/configs/config.yaml
@@ -0,0 +1,29 @@
+version: 1.0
+include: ["./unitest/configs/pruners.yaml", "./unitest/configs/pruners_0.yaml"]
+pruners:
+    pruner_1:
+        class: 'RatioPruner'
+        ratios:
+            'conv1_1.w': 0.3
+            'conv1_2.w': 0.4
+            '*': 0.9
+        group_dims:
+            '*': [1, 2, 3]
+        criterions:
+            '*': 'l1-norm'
+strategies:
+    strategy_1:
+        class: 'SensitivePruneStrategy'
+        pruner: 'pruner_2'
+        start_epoch: 0
+        end_epoch: 10
+        delta_rate: 0.20
+        acc_loss_threshold: 0.2
+        sensitivities:
+            'conv1_1.w': 0.4
+
+compress_pass:
+    class: 'CompressPass'
+    epoch: 100
+    strategies:
+        - strategy_1
diff --git a/python/paddle/fluid/contrib/slim/unitest/configs/pruners.yaml b/python/paddle/fluid/contrib/slim/unitest/configs/pruners.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..235092c595bf7c653221c7fe2b381fecf487fa49
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/unitest/configs/pruners.yaml
@@ -0,0 +1,12 @@
+version: 1.0
+pruners:
+    pruner_2:
+        class: 'RatioPruner'
+        ratios:
+            'conv1_1.w': 0.5
+            'conv1_2.w': 0.2
+            '*': 0.7
+        group_dims:
+            '*': [1, 2, 3]
+        criterions:
+            '*': 'l1-norm'
diff --git a/python/paddle/fluid/contrib/slim/unitest/configs/pruners_0.yaml b/python/paddle/fluid/contrib/slim/unitest/configs/pruners_0.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cd2ef9eb56ddbc1367ce2e3b413372fbcd542bde
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/unitest/configs/pruners_0.yaml
@@ -0,0 +1,12 @@
+version: 1.0
+pruners:
+    pruner_3:
+        class: 'RatioPruner'
+        ratios:
+            'conv1_1.w': 0.5
+            'conv1_2.w': 0.2
+            '*': 0.7
+        group_dims:
+            '*': [1, 2, 3]
+        criterions:
+            '*': 'l1-norm'
diff --git a/python/paddle/fluid/contrib/slim/unitest/test_factory.py b/python/paddle/fluid/contrib/slim/unitest/test_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f28aac905d1a2813dbde6143235c7916fd9278
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/unitest/test_factory.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.fluid.contrib.slim import ConfigFactory
+import unittest
+
+
+class TestFactory(unittest.TestCase):
+    def test_parse(self):
+        factory = ConfigFactory('./unitest/configs/config.yaml')
+
+        pruner = factory.instance('pruner_1')
+        self.assertEqual(pruner.ratios['conv1_1.w'], 0.3)
+
+        pruner = factory.instance('pruner_2')
+        self.assertEqual(pruner.ratios['*'], 0.7)
+
+        strategy = factory.instance('strategy_1')
+        pruner = strategy.pruner
+        self.assertEqual(pruner.ratios['*'], 0.7)
+
+        compress_pass = factory.get_compress_pass()
+        self.assertEqual(compress_pass.epoch, 100)
+
+        strategy = compress_pass.strategies[0]
+        self.assertEqual(strategy.delta_rate, 0.2)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_fill_constant_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_fill_constant_ngraph_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..835376ffe78f9119a9be6c379998e3a3b50aab43
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_fill_constant_ngraph_op.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+from paddle.fluid.tests.unittests.test_fill_constant_op import TestFillConstantOp1, TestFillConstantOp2, TestFillConstantOpWithSelectedRows
+
+
+class TestNGRAPHFillConstantOp1(TestFillConstantOp1):
+    def setUp(self):
+        super(TestNGRAPHFillConstantOp1, self).setUp()
+
+
+class TestNGRAPHFillConstantOp2(TestFillConstantOp2):
+    def setUp(self):
+        super(TestNGRAPHFillConstantOp2, self).setUp()
+
+
+class TestNGRAPHFillConstantOpWithSelectedRows(
+        TestFillConstantOpWithSelectedRows):
+    def setUp(self):
+        super(TestNGRAPHFillConstantOpWithSelectedRows, self).setUp()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0171087dce5d4c7b72eca7f7e4fb955af94812
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import unittest
+from paddle.fluid.tests.unittests.test_top_k_op import TestTopkOp, TestTopkOp3d, TestTopkOp2, TestTopkOp3, TestTopkOp4
+
+
+class TestNGRAPHTopkOp(TestTopkOp):
+    def setUp(self):
+        super(TestNGRAPHTopkOp, self).setUp()
+
+
+class TestNGRAPHTopkOp2(TestTopkOp2):
+    def setUp(self):
+        super(TestNGRAPHTopkOp2, self).setUp()
+
+
+class TestNGRAPHTopkOp3(TestTopkOp3):
+    def setUp(self):
+        super(TestNGRAPHTopkOp3, self).setUp()
+
+
+class TestNGRAPHTopkOp4(TestTopkOp4):
+    def setUp(self):
+        super(TestNGRAPHTopkOp4, self).setUp()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dequantize_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_dequantize_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5e1abd7c8fb010357998c0ceaebaf21619fda9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dequantize_mkldnn_op.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestDeQuantizeOp(OpTest):
+    def setUp(self):
+        self.op_type = 'dequantize'
+        self.scale = 2.0
+        self.input_size = [1, 1, 5, 5]  # Naive nChw16c
+        self.data_type = 'int8'
+        self.set_scale()
+        self.set_data_type()
+
+        if self.data_type == 'int8':
+            input = (np.random.randint(0, 100, self.input_size) - 50
+                     ).astype(self.data_type)
+            output = (input * (1 / self.scale)).astype('float')
+        else:
+            input = (np.random.randint(0, 100,
+                                       self.input_size)).astype(self.data_type)
+            output = (input * (1 / self.scale)).astype('float')
+
+        self.inputs = {'Input': OpTest.np_dtype_to_fluid_dtype(input)}
+
+        self.outputs = {'Output': output}
+
+        self.attrs = {'Scale': self.scale, }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def set_scale(self):
+        pass
+
+    def set_data_type(self):
+        pass
+
+
+class TestDeQuantizeOp1(TestDeQuantizeOp):
+    def set_scale(self):
+        self.scale = 1.5
+
+    def set_data_type(self):
+        self.data_type = 'int8'
+
+
+class TestDeQuantizeOp2(TestDeQuantizeOp):
+    def set_scale(self):
+        self.scale = 0.8
+
+    def set_data_type(self):
+        self.data_type = 'uint8'
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_quantize_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_quantize_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..99607928648be437b7f944f86a0c28b99d1775c4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_quantize_mkldnn_op.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestQuantizeOp(OpTest):
+    def setUp(self):
+        self.op_type = 'quantize'
+        self.scale = 2.0
+        self.input_size = [1, 1, 5, 5]  # Naive nChw16c
+        self.is_negative = False
+        self.set_scale()
+        self.set_is_negative()
+
+        if self.is_negative:
+            input = (100 * np.random.random_sample(self.input_size) - 50
+                     ).astype('float32')
+            output = np.round(input * self.scale).astype('int8')
+        else:
+            input = (100 *
+                     np.random.random_sample(self.input_size)).astype('float32')
+            output = np.round(input * self.scale).astype('uint8')
+
+        self.inputs = {'Input': OpTest.np_dtype_to_fluid_dtype(input)}
+
+        self.outputs = {'Output': output}
+
+        self.attrs = {
+            'Scale': self.scale,
+            'is_negative_input': self.is_negative
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def set_scale(self):
+        pass
+
+    def set_is_negative(self):
+        pass
+
+
+class TestQuantizeOp1(TestQuantizeOp):
+    def set_scale(self):
+        self.scale = 1.5
+
+    def set_is_negative(self):
+        self.is_negative = True
+
+
+class TestQuantizeOp2(TestQuantizeOp):
+    def set_scale(self):
+        self.scale = 0.1
+
+    def set_is_negative(self):
+        self.is_negative = False
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/requirements.txt b/python/requirements.txt
index 2f81d85df0626b294f4d861706b5c1b7ec9841d5..03d5e33e88cd5f1138ca8f6a6e885d6acfbc260e 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -9,3 +9,5 @@ Pillow
 nltk>=3.2.2
 graphviz
 six
+funcsigs
+pyyaml
diff --git a/python/setup.py.in b/python/setup.py.in
index bfcd47f5b0a6c2ede8d3e4d25e97ed809360cbca..c9afe6c885658b88ac520aad2e7b13facda02a92 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -109,6 +109,10 @@ packages=['paddle',
           'paddle.fluid.contrib',
           'paddle.fluid.contrib.decoder',
           'paddle.fluid.contrib.quantize',
+          'paddle.fluid.contrib.slim',
+          'paddle.fluid.contrib.slim.core',
+          'paddle.fluid.contrib.slim.graph',
+          'paddle.fluid.contrib.slim.prune',
           'paddle.fluid.contrib.utils',
           'paddle.fluid.transpiler',
           'paddle.fluid.transpiler.details']
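
Reviewer note, not part of the diff: a minimal NumPy illustration of the scale convention asserted by the new quantize/dequantize MKL-DNN tests above; quantization multiplies by Scale and rounds, dequantization multiplies by 1/Scale. All values here are illustrative.

import numpy as np

scale = 2.0
x = np.array([-3.7, 0.2, 12.5], dtype='float32')
q = np.round(x * scale).astype('int8')  # quantize, signed-input path
x_hat = q * (1.0 / scale)               # dequantize
print(q, x_hat)  # round-trip error is at most 0.5 / scale per element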